summaryrefslogtreecommitdiff
path: root/assignment1.py
diff options
context:
space:
mode:
authorMike Vink <mike1994vink@gmail.com>2020-09-23 08:42:50 +0200
committerMike Vink <mike1994vink@gmail.com>2020-09-23 08:42:50 +0200
commit0cbb9da5d206817f857391b1965467945c25c056 (patch)
tree8dd5f1e67415900b4a0e7e7c0a8185615aa4a067 /assignment1.py
parent86b1c76e6bd6652ea667922e46d88d145d00d19e (diff)
Minor things
Diffstat (limited to 'assignment1.py')
-rw-r--r--assignment1.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/assignment1.py b/assignment1.py
index 8c5c746..4b264d7 100644
--- a/assignment1.py
+++ b/assignment1.py
@@ -55,6 +55,7 @@ class Node:
leaf node
"""
self.col = None
+ # This weird numpy line gives the majority vote, which is 1 or 0
self.split_value_or_rows = np.argmax(
np.bincount(node_classes.astype(int)))
@@ -198,9 +199,6 @@ def bestsplit(x, y, min_leaf) -> None:
# twee "x rows" arrays we moeten returnen, en welke split value het beste
# was natuurlijk.
- # Nodig voor de delta i formule
- impurity_parent, n_classes_parent = impurity(y), len(y)
-
# Hieren stoppen we (delta_i, split_value, rows_left, rows_right)
best_list = []
# Stop wanneer de array met split points leeg is
@@ -220,10 +218,12 @@ def bestsplit(x, y, min_leaf) -> None:
split_points = split_points[:-1]
continue
- # delta_i formule
+ # delta_i formule, improved by not taking the parent impurity into
+ # account, and not making the weighted average but the weigthed sum
+ # only (thanks Lonnie)
delta_i = (
impurity(classes_left) * len(classes_left) +
- impurity(classes_right) * len(classes_right)) / n_classes_parent
+ impurity(classes_right) * len(classes_right))
# stop huidige splits in de lijst om best split te berekenen
best_list.append((delta_i, mask_left, mask_right, split_value))
# Haal de huidige split_point uit split_points