summary refs log tree commit diff
path: root/tree.py
diff options
context:
space:
mode:
author Mike Vink <mike1994vink@gmail.com> 2020-09-24 17:52:55 +0200
committer Mike Vink <mike1994vink@gmail.com> 2020-09-24 17:52:55 +0200
commit 35c613a547bb209cc6e9a289019c1ad3feff5fd4 (patch)
tree bad67252b0beadc2db9e357491ea517ab223e013 /tree.py
parent e1ad1dcd2ea5c7663ae0a867cc53b7cb13f3c91a (diff)
Concatenate list of arrays instead of vstack
Diffstat (limited to 'tree.py')
-rw-r--r-- tree.py 7
1 files changed, 4 insertions, 3 deletions
diff --git a/tree.py b/tree.py
index 64e9d21..9f67b0b 100644
--- a/tree.py
+++ b/tree.py
@@ -202,7 +202,7 @@ def bestsplit(x, y, min_leaf) -> None:
# twee "x rows" arrays we moeten returnen, en welke split value het beste
# was natuurlijk.
- # Hieren stoppen we (rows for children, delta_i, split value, col will be added later)
+ # Hieren stacken we arrays onder (rows for children, delta_i, split value, col will be added later)
best_array = np.zeros((split_points.shape[0], x.shape[0] + 3), dtype='float64')
# print(f"{best_array=}")
# Stop wanneer de array met split points leeg is
@@ -255,7 +255,7 @@ def exhaustive_split_search(rows, classes, min_leaf):
print("\t\t->entering exhaustive split search")
# We hebben enumerate nodig, want we willen weten op welke col
# (age,married,house,income,gender) we een split doen
- exhaustive_best_array = np.zeros((1, rows.shape[0] + 3), dtype='float64')
+ exhaustive_best_array = []
print(f"Rows:\n{rows},\n Classes:\n{classes}")
for i, col in enumerate(rows.transpose()):
# calculate the best split for the col that satisfies the min_leaf
@@ -263,7 +263,8 @@ def exhaustive_split_search(rows, classes, min_leaf):
best_array_for_col = bestsplit(col, classes, min_leaf)
# add col number to rows
best_array_for_col[:,-1] = i
- exhaustive_best_array = np.vstack((exhaustive_best_array, best_array_for_col))
+ exhaustive_best_array.append(best_array_for_col)
+ exhaustive_best_array = np.concatenate(exhaustive_best_array)
print("The array with exhaustive splits is\n", exhaustive_best_array)
print("\t\t->returning from exhaustive split search")
return exhaustive_best_array