summaryrefslogtreecommitdiff
path: root/tree.py
diff options
context:
space:
mode:
Diffstat (limited to 'tree.py')
-rw-r--r--tree.py36
1 files changed, 18 insertions, 18 deletions
diff --git a/tree.py b/tree.py
index e971bd4..75ea633 100644
--- a/tree.py
+++ b/tree.py
@@ -82,26 +82,26 @@ class Tree:
"""
# Maak een lijst van nodes, wiens indexes overeen komen met de rows in
# x die we willen droppen
- nodes = [self.tree] * len(x)
+ rows_to_predict = len(x)
+ nodes = np.array([self.tree] * rows_to_predict)
+ predictions = np.zeros(rows_to_predict)
- # De index van de row van x die we in de boom willen droppen
+ # # De index van de row van x die we in de boom willen droppen
drop = 0
- while not all(pred_class in {0, 1} for pred_class in nodes):
- # Als de col None is dan is het een leaf node, dus dan is de row
- # van x hier beland
- if nodes[drop].col is None:
- nodes[drop] = nodes[drop].split_value_or_rows
+ node = nodes[0]
+ while nodes.size != 0:
+ node = nodes[0]
+ if node.col is None:
+ node = node.split_value_or_rows
+ predictions[drop] = node
+ nodes = nodes[1:]
drop += 1
continue
-
- # Vergelijk de x col (age,married,house,income,gender,class), in de
- # gedropte row met het split value van de node. Op basis hiervan
- # drop naar links of rechts
- if x[drop, nodes[drop].col] > nodes[drop].split_value_or_rows:
- nodes[drop] = nodes[drop].left
+ elif x[drop, node.col] > node.split_value_or_rows:
+ nodes[0] = node.left
else:
- nodes[drop] = nodes[drop].right
- return np.array(nodes)
+ nodes[0] = node.right
+ return predictions
# Work in progress tree printer
#
@@ -368,7 +368,7 @@ if __name__ == '__main__':
delimiter=',',
skip_header=True)
- print("Dataset: credit data")
+ print("\nDataset: credit data")
tree_pred(x=credit_data[:, :5],
tr=tree_grow(x=credit_data[:, 0:5],
y=credit_data[:, 5],
@@ -377,7 +377,7 @@ if __name__ == '__main__':
nfeat=5),
training=credit_data[:, 5])
- print("Dataset: credit data")
+ print("\nDataset: credit data")
tree_pred_b(x=credit_data[:, :5],
tr=tree_grow_b(x=credit_data[:, 0:5],
y=credit_data[:, 5],
@@ -417,7 +417,7 @@ if __name__ == '__main__':
# Time profile of pima indians data prediction with single tree
# print("prediction metrics single tree pima indians:")
# cProfile.run(
- # "tree_pred_b(x=pima_indians[:, :8], tr=tree_grow_b(x=pima_indians[:, :8], y=pima_indians[:, 8], nmin=20, minleaf=5, nfeat=pima_indians.shape[1] - 1, m=50), training=pima_indians[:, 8])",
+ # "tree_pred_b(x=pima_indians[:, :8], tr=tree_grow_b(x=pima_indians[:, :8], y=pima_indians[:, 8], nmin=20, minleaf=5, nfeat=4, m=5), training=pima_indians[:, 8])",
# 'restats')
# p = pstats.Stats('restats')