diff options
| -rw-r--r-- | tree.py | 103 |
1 files changed, 85 insertions, 18 deletions
@@ -171,6 +171,7 @@ # the classes vector. Note that when the number of 1 and 0 elements are # equal, it returns 0. +# EXAMPLE: # >>> y # array([0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]) # >>> major_vote(y) @@ -178,6 +179,85 @@ # """ +# def impurity(array): +# """ +# /array/ numpy.array, 1D numpy array of zeroes and ones + +# Returns -> float + +# Computes the gini index impurity based on the relative frequency of ones +# in the vector. +# +# EXAMPLE: +# >>> array=np.array([1,0,1,1,1,0,0,1,1,0,1]) +# >>> array +# array([1,0,1,1,1,0,0,1,1,0,1]) +# >>> impurity(array) +# 0.23140495867768596 + +# """ + +# def bestsplit(x,y,minleaf): +# """ +# /x/ numpy.array, 1D numpy array corresponding to a feature column of the +# 2D data array corresponding to some node +# /y/ numpy.array, 1D numpy array of binary labels corresponding to the +# rows in the 2D data array corresponding to some node +# /minleaf/ @todo + +# Returns -> tuple + +# """ + + + +def bestsplit(x, y, minleaf) -> None: + """ + x = vector of single col + y = vector of classes (last col in x) + + Consider splits of type "x <= c" where "c" is the average of two consecutive + values of x in the sorted order. + + x and y must be of the same length + + y[i] must be the class label of the i-th observation, and x[i] is the + correspnding value of attribute x + + Example (best split on income): + + >>> bestsplit(credit_data[:,3],credit_data[:,5]) + 36 + """ + x_sorted = np.sort(np.unique(x)) + split_points = (x_sorted[:len(x_sorted) - 1] + x_sorted[1:]) / 2 + + # Hieren stoppen we (delta_i, split_value, rows_left, rows_right) + best_list = [] + while split_points.size != 0: + split_value = split_points[-1] + + mask_left, mask_right = x > split_value, x <= split_value + classes_left, classes_right = y[mask_left], y[mask_right] + + if len(classes_left) < minleaf or len(classes_right) < minleaf: + split_points = split_points[:-1] + continue + + delta_i = (impurity(classes_left) * len(classes_left) + + impurity(classes_right) * len(classes_right)) + + best_list.append((delta_i, mask_left, mask_right, split_value)) + + split_points = split_points[:-1] + + # Bereken de best split voor deze x col, als er ten minste 1 bestaat die + # voldoet aan min leaf + if best_list: + return min(best_list, key=lambda x: x[0]) + else: + return False + import numpy as np @@ -371,18 +451,7 @@ def major_vote(classes): def impurity(array) -> int: """ - Assumes the argument array is a one dimensional vector of zeroes and ones. - Computes the gini index impurity based on the relative frequency of ones in - the vector. - - Example: - - >>> array=np.array([1,0,1,1,1,0,0,1,1,0,1]) - >>> array - array([1,0,1,1,1,0,0,1,1,0,1]) - - >>> impurity(array) - 0.23140495867768596 + Calculates the impurity of the labels in a node. """ n_labels = len(array) n_labels_1 = array.sum() @@ -415,23 +484,21 @@ def bestsplit(x, y, minleaf) -> None: # Hieren stoppen we (delta_i, split_value, rows_left, rows_right) best_list = [] - # Stop wanneer de array met split points leeg is while split_points.size != 0: split_value = split_points[-1] + mask_left, mask_right = x > split_value, x <= split_value classes_left, classes_right = y[mask_left], y[mask_right] - # Kijk of er genoeg rows in de gesplitte nodes terechtkomen, anders - # mogen we de split niet toelaten vanwege de minleaf constraint if len(classes_left) < minleaf or len(classes_right) < minleaf: split_points = split_points[:-1] continue delta_i = (impurity(classes_left) * len(classes_left) + impurity(classes_right) * len(classes_right)) - # stop huidige splits in de lijst om best split te berekenen + best_list.append((delta_i, mask_left, mask_right, split_value)) - # Haal de huidige split_point uit split_points + split_points = split_points[:-1] # Bereken de best split voor deze x col, als er ten minste 1 bestaat die @@ -444,7 +511,7 @@ def bestsplit(x, y, minleaf) -> None: def exhaustive_split_search(rows, classes, minleaf): """ - The nfeat repeated application of best_split. + The nfeat repeated application of bestsplit. """ # We hebben enumerate nodig, want we willen weten op welke col (i) # (age,married,house,income,gender) we een split doen |
