diff options
Diffstat (limited to 'tree.py')
| -rw-r--r-- | tree.py | 46 |
1 files changed, 42 insertions, 4 deletions
@@ -1,3 +1,5 @@ +import numpy as np +from sklearn import metrics #- Names and student no.: # Hunter Sterk 6981046 @@ -359,9 +361,17 @@ def tree_grow(x=None, y=None, nmin=None, minleaf=None, nfeat=None) -> Tree: continue if impurity(node_classes) > 0: - node_rows = x[node.split_value_or_rows] + # FIX: feature choice that was lost in versioning + # OLD: node_rows = x[node.split_value_or_rows] + # node_rows = x[node.split_value_or_rows] + # print(node.split_value_or_rows) + + nfeat_col_choice = np.random.choice(range(x.shape[1]), nfeat, replace=False) + feat_select = np.sort(nfeat_col_choice) + node_rows = x[node.split_value_or_rows][:, feat_select] + exhaustive_best_list = exhaustive_split_search( - node_rows, node_classes, minleaf) + node_rows, node_classes, feat_select, minleaf) if not exhaustive_best_list: node.is_leaf_node(node_classes) continue @@ -510,7 +520,7 @@ def bestsplit(x, y, minleaf) -> None: return False -def exhaustive_split_search(rows, classes, minleaf): +def exhaustive_split_search(rows, classes, feat_select, minleaf): """ The nfeat repeated application of bestsplit. """ @@ -521,7 +531,7 @@ def exhaustive_split_search(rows, classes, minleaf): col_best_split = bestsplit(col, classes, minleaf) if col_best_split: # add for which row we calculated the best split - col_best_split += (i, ) + col_best_split += (feat_select[i], ) exhaustive_best_list.append(col_best_split) return exhaustive_best_list @@ -555,3 +565,31 @@ def update_mask(mask, current_mask): copy = np.array(current_mask, copy=True) copy[current_mask == True] = mask return copy + +if __name__ == '__main__': + c = np.loadtxt('./data/credit_score.txt', delimiter=',', skiprows=1) + x, y = c[:,0:5], c[:,5] + tr = tree_grow(x=x, y=y, nmin=2, minleaf=1, nfeat=5) + tree_pred(x, tr, true=y) + + c = np.loadtxt('./data/credit_score.txt', delimiter=',', skiprows=1) + x, y = c[:,0:5], c[:,5] + + trs = tree_grow_b(x=x, y=y, nmin=2, minleaf=1, nfeat=4, m=50) + tree_pred_b(x, trs, true=y) + + + c = np.loadtxt('./data/pima_indians.csv', delimiter=',') + x, y = c[:,0:8], c[:,8].astype(int) + + tr = tree_grow(x=x, y=y, nmin=20, minleaf=5, nfeat=8) + tree_pred(x, tr, true=y) + + + c = np.loadtxt('./data/pima_indians.csv', delimiter=',') + x, y = c[:,0:8], c[:,8].astype(int) + + trs = tree_grow_b(x=x, y=y, nmin=20, minleaf=5, nfeat=4, m=5) + tree_pred_b(x, trs, true=y) + + |
