summaryrefslogtreecommitdiff
path: root/tree.py
diff options
context:
space:
mode:
authorMike Vink <mike1994vink@gmail.com>2020-10-04 15:29:51 +0200
committerMike Vink <mike1994vink@gmail.com>2020-10-04 15:29:51 +0200
commit5eeb3766fa3b94a58c23485664b18e0cfa92023b (patch)
treee5ad2f3e999ed3e34657e201db4d1b6a1f33dc43 /tree.py
parentaeda214879ee4858eb93169257aa7b87c0973789 (diff)
Almost finished with writing the docs
Diffstat (limited to 'tree.py')
-rw-r--r--tree.py103
1 files changed, 85 insertions, 18 deletions
diff --git a/tree.py b/tree.py
index db40611..0351e3c 100644
--- a/tree.py
+++ b/tree.py
@@ -171,6 +171,7 @@
# the classes vector. Note that when the number of 1 and 0 elements are
# equal, it returns 0.
+# EXAMPLE:
# >>> y
# array([0., 0., 0., 0., 0., 1., 1., 1., 1., 1.])
# >>> major_vote(y)
@@ -178,6 +179,85 @@
# """
+# def impurity(array):
+# """
+# /array/ numpy.array, 1D numpy array of zeroes and ones
+
+# Returns -> float
+
+# Computes the gini index impurity based on the relative frequency of ones
+# in the vector.
+#
+# EXAMPLE:
+# >>> array=np.array([1,0,1,1,1,0,0,1,1,0,1])
+# >>> array
+# array([1,0,1,1,1,0,0,1,1,0,1])
+# >>> impurity(array)
+# 0.23140495867768596
+
+# """
+
+# def bestsplit(x,y,minleaf):
+# """
+# /x/ numpy.array, 1D numpy array corresponding to a feature column of the
+# 2D data array corresponding to some node
+# /y/ numpy.array, 1D numpy array of binary labels corresponding to the
+# rows in the 2D data array corresponding to some node
+# /minleaf/ @todo
+
+# Returns -> tuple
+
+# """
+
+
+
+def bestsplit(x, y, minleaf) -> None:
+ """
+ x = vector of single col
+ y = vector of classes (last col in x)
+
+ Consider splits of type "x <= c" where "c" is the average of two consecutive
+ values of x in the sorted order.
+
+ x and y must be of the same length
+
+ y[i] must be the class label of the i-th observation, and x[i] is the
+ correspnding value of attribute x
+
+ Example (best split on income):
+
+ >>> bestsplit(credit_data[:,3],credit_data[:,5])
+ 36
+ """
+ x_sorted = np.sort(np.unique(x))
+ split_points = (x_sorted[:len(x_sorted) - 1] + x_sorted[1:]) / 2
+
+ # Hieren stoppen we (delta_i, split_value, rows_left, rows_right)
+ best_list = []
+ while split_points.size != 0:
+ split_value = split_points[-1]
+
+ mask_left, mask_right = x > split_value, x <= split_value
+ classes_left, classes_right = y[mask_left], y[mask_right]
+
+ if len(classes_left) < minleaf or len(classes_right) < minleaf:
+ split_points = split_points[:-1]
+ continue
+
+ delta_i = (impurity(classes_left) * len(classes_left) +
+ impurity(classes_right) * len(classes_right))
+
+ best_list.append((delta_i, mask_left, mask_right, split_value))
+
+ split_points = split_points[:-1]
+
+ # Bereken de best split voor deze x col, als er ten minste 1 bestaat die
+ # voldoet aan min leaf
+ if best_list:
+ return min(best_list, key=lambda x: x[0])
+ else:
+ return False
+
import numpy as np
@@ -371,18 +451,7 @@ def major_vote(classes):
def impurity(array) -> int:
"""
- Assumes the argument array is a one dimensional vector of zeroes and ones.
- Computes the gini index impurity based on the relative frequency of ones in
- the vector.
-
- Example:
-
- >>> array=np.array([1,0,1,1,1,0,0,1,1,0,1])
- >>> array
- array([1,0,1,1,1,0,0,1,1,0,1])
-
- >>> impurity(array)
- 0.23140495867768596
+ Calculates the impurity of the labels in a node.
"""
n_labels = len(array)
n_labels_1 = array.sum()
@@ -415,23 +484,21 @@ def bestsplit(x, y, minleaf) -> None:
# Hieren stoppen we (delta_i, split_value, rows_left, rows_right)
best_list = []
- # Stop wanneer de array met split points leeg is
while split_points.size != 0:
split_value = split_points[-1]
+
mask_left, mask_right = x > split_value, x <= split_value
classes_left, classes_right = y[mask_left], y[mask_right]
- # Kijk of er genoeg rows in de gesplitte nodes terechtkomen, anders
- # mogen we de split niet toelaten vanwege de minleaf constraint
if len(classes_left) < minleaf or len(classes_right) < minleaf:
split_points = split_points[:-1]
continue
delta_i = (impurity(classes_left) * len(classes_left) +
impurity(classes_right) * len(classes_right))
- # stop huidige splits in de lijst om best split te berekenen
+
best_list.append((delta_i, mask_left, mask_right, split_value))
- # Haal de huidige split_point uit split_points
+
split_points = split_points[:-1]
# Bereken de best split voor deze x col, als er ten minste 1 bestaat die
@@ -444,7 +511,7 @@ def bestsplit(x, y, minleaf) -> None:
def exhaustive_split_search(rows, classes, minleaf):
"""
- The nfeat repeated application of best_split.
+ The nfeat repeated application of bestsplit.
"""
# We hebben enumerate nodig, want we willen weten op welke col (i)
# (age,married,house,income,gender) we een split doen