From d2ad4ed22fe13e74622768cab4c9b17245515e9a Mon Sep 17 00:00:00 2001 From: Mike Vink Date: Fri, 25 Sep 2020 09:15:40 +0200 Subject: Fix: bottleneck in the prediction --- tree.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tree.py b/tree.py index e971bd4..75ea633 100644 --- a/tree.py +++ b/tree.py @@ -82,26 +82,26 @@ class Tree: """ # Maak een lijst van nodes, wiens indexes overeen komen met de rows in # x die we willen droppen - nodes = [self.tree] * len(x) + rows_to_predict = len(x) + nodes = np.array([self.tree] * rows_to_predict) + predictions = np.zeros(rows_to_predict) - # De index van de row van x die we in de boom willen droppen + # # De index van de row van x die we in de boom willen droppen drop = 0 - while not all(pred_class in {0, 1} for pred_class in nodes): - # Als de col None is dan is het een leaf node, dus dan is de row - # van x hier beland - if nodes[drop].col is None: - nodes[drop] = nodes[drop].split_value_or_rows + node = nodes[0] + while nodes.size != 0: + node = nodes[0] + if node.col is None: + node = node.split_value_or_rows + predictions[drop] = node + nodes = nodes[1:] drop += 1 continue - - # Vergelijk de x col (age,married,house,income,gender,class), in de - # gedropte row met het split value van de node. Op basis hiervan - # drop naar links of rechts - if x[drop, nodes[drop].col] > nodes[drop].split_value_or_rows: - nodes[drop] = nodes[drop].left + elif x[drop, node.col] > node.split_value_or_rows: + nodes[0] = node.left else: - nodes[drop] = nodes[drop].right - return np.array(nodes) + nodes[0] = node.right + return predictions # Work in progress tree printer # @@ -368,7 +368,7 @@ if __name__ == '__main__': delimiter=',', skip_header=True) - print("Dataset: credit data") + print("\nDataset: credit data") tree_pred(x=credit_data[:, :5], tr=tree_grow(x=credit_data[:, 0:5], y=credit_data[:, 5], @@ -377,7 +377,7 @@ if __name__ == '__main__': nfeat=5), training=credit_data[:, 5]) - print("Dataset: credit data") + print("\nDataset: credit data") tree_pred_b(x=credit_data[:, :5], tr=tree_grow_b(x=credit_data[:, 0:5], y=credit_data[:, 5], @@ -417,7 +417,7 @@ if __name__ == '__main__': # Time profile of pima indians data prediction with single tree # print("prediction metrics single tree pima indians:") # cProfile.run( - # "tree_pred_b(x=pima_indians[:, :8], tr=tree_grow_b(x=pima_indians[:, :8], y=pima_indians[:, 8], nmin=20, minleaf=5, nfeat=pima_indians.shape[1] - 1, m=50), training=pima_indians[:, 8])", + # "tree_pred_b(x=pima_indians[:, :8], tr=tree_grow_b(x=pima_indians[:, :8], y=pima_indians[:, 8], nmin=20, minleaf=5, nfeat=4, m=5), training=pima_indians[:, 8])", # 'restats') # p = pstats.Stats('restats') -- cgit v1.2.3