diff options
Diffstat (limited to 'tree.py')
| -rw-r--r-- | tree.py | 81 |
1 files changed, 41 insertions, 40 deletions
@@ -203,61 +203,62 @@ # 2D data array corresponding to some node # /y/ numpy.array, 1D numpy array of binary labels corresponding to the # rows in the 2D data array corresponding to some node -# /minleaf/ @todo +# /minleaf/ int, number of x-rows a child must have before splitting # Returns -> tuple -# """ - - +# Computes the best split based on the given features using the impurity +# function. -def bestsplit(x, y, minleaf) -> None: - """ - x = vector of single col - y = vector of classes (last col in x) - - Consider splits of type "x <= c" where "c" is the average of two consecutive - values of x in the sorted order. +# EXAMPLE: +# >>> x +# array([28., 32., 24., 27., 32., 30., 58., 52., 40., 28.]) +# >>> y +# array([0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]) +# >>> bestsplit(x,y,minleaf=1) +# (1.4285714285714286, array([False, False, False, False, False, False, True, True, True, +# False]), array([ True, True, True, True, True, True, False, False, False, +# True]), 36.0) +# - x and y must be of the same length +# """ - y[i] must be the class label of the i-th observation, and x[i] is the - correspnding value of attribute x +# def exhaustive_split_search(rows, classes, minleaf): +# """ +# /rows/ numpy.array, 2D data array corresponding to some node +# /classes/ numpy.array, 1D numpy array of binary labels corresponding to the +# rows in the 2D data array corresponding to some node +# /minleaf/ int, number of x-rows a child must have before splitting - Example (best split on income): +# Returns -> List - >>> bestsplit(credit_data[:,3],credit_data[:,5]) - 36 - """ - x_sorted = np.sort(np.unique(x)) - split_points = (x_sorted[:len(x_sorted) - 1] + x_sorted[1:]) / 2 +# Stores the best splits computed with the bestsplit function for the +# considered columns in a list, if the list is empty then there are no +# splits and the node becomes a leaf node. - # Hieren stoppen we (delta_i, split_value, rows_left, rows_right) - best_list = [] - while split_points.size != 0: - split_value = split_points[-1] - - mask_left, mask_right = x > split_value, x <= split_value - classes_left, classes_right = y[mask_left], y[mask_right] +# """ - if len(classes_left) < minleaf or len(classes_right) < minleaf: - split_points = split_points[:-1] - continue +# def add_children(node, best_split): +# """ +# /node/ Node object, the current node in the main tree growing loop +# /best_split/ tuple, tuple containing (delta_i, rows_left, rows_right, splitvalue) - delta_i = (impurity(classes_left) * len(classes_left) + - impurity(classes_right) * len(classes_right)) +# Processes the splits into the tree data structure and returns children yet +# to be splitted to the main nodelist in tree_grow. - best_list.append((delta_i, mask_left, mask_right, split_value)) +# """ - split_points = split_points[:-1] +# def update_mask(mask, current_mask): +# """ +# /mask/ np.array, 1D boolean vector corresponding to the rows in the new +# child node that might have a length that is incompatible with the rows in +# the main 2D data array x of tree_grow +# /current_mask/ np.array, 1D boolean vector corresponding to the rows in the current node - # Bereken de best split voor deze x col, als er ten minste 1 bestaat die - # voldoet aan min leaf - if best_list: - return min(best_list, key=lambda x: x[0]) - else: - return False +# Updates the spit bool array from any dimension to an array with length +# equal to the total number of rows in dataset x of tree_grow. +# """ import numpy as np |
