summaryrefslogtreecommitdiff
path: root/tree.py
diff options
context:
space:
mode:
Diffstat (limited to 'tree.py')
-rw-r--r--tree.py81
1 files changed, 41 insertions, 40 deletions
diff --git a/tree.py b/tree.py
index 692d5e3..9ca8eae 100644
--- a/tree.py
+++ b/tree.py
@@ -203,61 +203,62 @@
# 2D data array corresponding to some node
# /y/ numpy.array, 1D numpy array of binary labels corresponding to the
# rows in the 2D data array corresponding to some node
-# /minleaf/ @todo
+# /minleaf/ int, number of x-rows a child must have before splitting
# Returns -> tuple
-# """
-
-
+# Computes the best split based on the given features using the impurity
+# function.
-def bestsplit(x, y, minleaf) -> None:
- """
- x = vector of single col
- y = vector of classes (last col in x)
-
- Consider splits of type "x <= c" where "c" is the average of two consecutive
- values of x in the sorted order.
+# EXAMPLE:
+# >>> x
+# array([28., 32., 24., 27., 32., 30., 58., 52., 40., 28.])
+# >>> y
+# array([0., 0., 0., 0., 0., 1., 1., 1., 1., 1.])
+# >>> bestsplit(x,y,minleaf=1)
+# (1.4285714285714286, array([False, False, False, False, False, False, True, True, True,
+# False]), array([ True, True, True, True, True, True, False, False, False,
+# True]), 36.0)
+#
- x and y must be of the same length
+# """
- y[i] must be the class label of the i-th observation, and x[i] is the
- correspnding value of attribute x
+# def exhaustive_split_search(rows, classes, minleaf):
+# """
+# /rows/ numpy.array, 2D data array corresponding to some node
+# /classes/ numpy.array, 1D numpy array of binary labels corresponding to the
+# rows in the 2D data array corresponding to some node
+# /minleaf/ int, number of x-rows a child must have before splitting
- Example (best split on income):
+# Returns -> List
- >>> bestsplit(credit_data[:,3],credit_data[:,5])
- 36
- """
- x_sorted = np.sort(np.unique(x))
- split_points = (x_sorted[:len(x_sorted) - 1] + x_sorted[1:]) / 2
+# Stores the best splits computed with the bestsplit function for the
+# considered columns in a list, if the list is empty then there are no
+# splits and the node becomes a leaf node.
- # Hieren stoppen we (delta_i, split_value, rows_left, rows_right)
- best_list = []
- while split_points.size != 0:
- split_value = split_points[-1]
-
- mask_left, mask_right = x > split_value, x <= split_value
- classes_left, classes_right = y[mask_left], y[mask_right]
+# """
- if len(classes_left) < minleaf or len(classes_right) < minleaf:
- split_points = split_points[:-1]
- continue
+# def add_children(node, best_split):
+# """
+# /node/ Node object, the current node in the main tree growing loop
+# /best_split/ tuple, tuple containing (delta_i, rows_left, rows_right, splitvalue)
- delta_i = (impurity(classes_left) * len(classes_left) +
- impurity(classes_right) * len(classes_right))
+# Processes the splits into the tree data structure and returns children yet
+# to be splitted to the main nodelist in tree_grow.
- best_list.append((delta_i, mask_left, mask_right, split_value))
+# """
- split_points = split_points[:-1]
+# def update_mask(mask, current_mask):
+# """
+# /mask/ np.array, 1D boolean vector corresponding to the rows in the new
+# child node that might have a length that is incompatible with the rows in
+# the main 2D data array x of tree_grow
+# /current_mask/ np.array, 1D boolean vector corresponding to the rows in the current node
- # Bereken de best split voor deze x col, als er ten minste 1 bestaat die
- # voldoet aan min leaf
- if best_list:
- return min(best_list, key=lambda x: x[0])
- else:
- return False
+# Updates the spit bool array from any dimension to an array with length
+# equal to the total number of rows in dataset x of tree_grow.
+# """
import numpy as np