summaryrefslogtreecommitdiff
path: root/tree.py
diff options
context:
space:
mode:
Diffstat (limited to 'tree.py')
-rw-r--r--tree.py46
1 files changed, 42 insertions, 4 deletions
diff --git a/tree.py b/tree.py
index 0852d15..dcd5551 100644
--- a/tree.py
+++ b/tree.py
@@ -1,3 +1,5 @@
+import numpy as np
+from sklearn import metrics
#- Names and student no.:
# Hunter Sterk 6981046
@@ -359,9 +361,17 @@ def tree_grow(x=None, y=None, nmin=None, minleaf=None, nfeat=None) -> Tree:
continue
if impurity(node_classes) > 0:
- node_rows = x[node.split_value_or_rows]
+ # FIX: feature choice that was lost in versioning
+ # OLD: node_rows = x[node.split_value_or_rows]
+ # node_rows = x[node.split_value_or_rows]
+ # print(node.split_value_or_rows)
+
+ nfeat_col_choice = np.random.choice(range(x.shape[1]), nfeat, replace=False)
+ feat_select = np.sort(nfeat_col_choice)
+ node_rows = x[node.split_value_or_rows][:, feat_select]
+
exhaustive_best_list = exhaustive_split_search(
- node_rows, node_classes, minleaf)
+ node_rows, node_classes, feat_select, minleaf)
if not exhaustive_best_list:
node.is_leaf_node(node_classes)
continue
@@ -510,7 +520,7 @@ def bestsplit(x, y, minleaf) -> None:
return False
-def exhaustive_split_search(rows, classes, minleaf):
+def exhaustive_split_search(rows, classes, feat_select, minleaf):
"""
The nfeat repeated application of bestsplit.
"""
@@ -521,7 +531,7 @@ def exhaustive_split_search(rows, classes, minleaf):
col_best_split = bestsplit(col, classes, minleaf)
if col_best_split:
# add for which row we calculated the best split
- col_best_split += (i, )
+ col_best_split += (feat_select[i], )
exhaustive_best_list.append(col_best_split)
return exhaustive_best_list
@@ -555,3 +565,31 @@ def update_mask(mask, current_mask):
copy = np.array(current_mask, copy=True)
copy[current_mask == True] = mask
return copy
+
+if __name__ == '__main__':
+ c = np.loadtxt('./data/credit_score.txt', delimiter=',', skiprows=1)
+ x, y = c[:,0:5], c[:,5]
+ tr = tree_grow(x=x, y=y, nmin=2, minleaf=1, nfeat=5)
+ tree_pred(x, tr, true=y)
+
+ c = np.loadtxt('./data/credit_score.txt', delimiter=',', skiprows=1)
+ x, y = c[:,0:5], c[:,5]
+
+ trs = tree_grow_b(x=x, y=y, nmin=2, minleaf=1, nfeat=4, m=50)
+ tree_pred_b(x, trs, true=y)
+
+
+ c = np.loadtxt('./data/pima_indians.csv', delimiter=',')
+ x, y = c[:,0:8], c[:,8].astype(int)
+
+ tr = tree_grow(x=x, y=y, nmin=20, minleaf=5, nfeat=8)
+ tree_pred(x, tr, true=y)
+
+
+ c = np.loadtxt('./data/pima_indians.csv', delimiter=',')
+ x, y = c[:,0:8], c[:,8].astype(int)
+
+ trs = tree_grow_b(x=x, y=y, nmin=20, minleaf=5, nfeat=4, m=5)
+ tree_pred_b(x, trs, true=y)
+
+