summaryrefslogtreecommitdiff
path: root/assignment1.py
diff options
context:
space:
mode:
author Mike Vink <mike1994vink@gmail.com> 2020-09-18 02:26:01 +0200
committer Mike Vink <mike1994vink@gmail.com> 2020-09-18 02:26:01 +0200
commit 21284cbb4b9086678092e0a51630d3b589f91084 (patch)
tree 4bb738b51f2da2769aa02053c82d6bf2e6526e3e /assignment1.py
parent ca2e889d514684c861fbf8e1f21acdce4ceef730 (diff)
Major: improvements on readability, factoring
Diffstat (limited to 'assignment1.py')
-rw-r--r-- assignment1.py 126
1 files changed, 89 insertions, 37 deletions
diff --git a/assignment1.py b/assignment1.py
index c1a4370..2a670c2 100644
--- a/assignment1.py
+++ b/assignment1.py
@@ -44,8 +44,6 @@ class Node:
def add_split(self, left, right):
"""
- Method that is called in the main loop of tree_grow.
-
Lets the node object point to two other objects that can be either Leaf
or Node.
"""
@@ -54,7 +52,8 @@ class Node:
def is_leaf_node(self, node_classes):
"""
- @todo: Docstring for is_leaf_node
+ is_leaf_node is used to set the col attribute to None to indicate a
+ leaf node
"""
self.col = None
self.split_value_or_rows = np.argmax(
@@ -85,22 +84,31 @@ class Tree:
"""
self.tree = root_node_obj
-
- # def __repr__(self):
- # nodelist = [self.tree]
- # tree_str = ''
- # while nodelist:
- # current_node = nodelist.pop()
- # # print(current_node.value)
- # try:
- # childs = [current_node.right, current_node.left]
- # nodelist += childs
- # except AttributeError:
- # pass
- # col, c = current_node.value
- # tree_str += f"{col=}, {c=}"
- # return tree_str
-
+ def __repr__(self):
+ tree_string = ''
+ node = self.tree
+ depth = 0
+ nodelist = [node]
+ while nodelist:
+ node = nodelist.pop()
+ depth += 1
+ if node.col is not None:
+ left, right = node.left, node.right
+ nodelist += [left, right]
+ else:
+ continue
+ tree_string += '\n' + depth * ' '
+ tree_string += (depth + 4) * ' ' + '/' + ' ' * 2 + '\\'
+ tree_string += '\n' + ' ' * 2 * depth
+ for direc in left, right:
+ if not direc.split_value_or_rows%10:
+ tree_string += ' ' * 4
+ else:
+ tree_string += ' ' * 3
+ tree_string += str(int(direc.split_value_or_rows))
+
+ tree_string = depth * ' ' + str(int(self.tree.split_value_or_rows)) + tree_string
+ return tree_string
def impurity(array) -> int:
"""
@@ -171,7 +179,7 @@ def bestsplit(x, y, min_leaf) -> None:
# was natuurlijk.
# Nodig voor de delta i formule
- impurity_parent, n_rows_parent = impurity(y), len(y)
+ impurity_parent, n_classes_parent = impurity(y), len(y)
# Hieren stoppen we (delta_i, split_value, rows_left, rows_right)
best_list = []
@@ -195,19 +203,24 @@ def bestsplit(x, y, min_leaf) -> None:
# delta_i formule
delta_i = impurity_parent - (
impurity(classes_left) * len(classes_left) +
- impurity(classes_right) * len(classes_right)) / n_rows_parent
+ impurity(classes_right) * len(classes_right)) / n_classes_parent
# stop huidige splits in de lijst om best split te berekenen
best_list.append((delta_i, mask_left, mask_right, split_value))
# Haal de huidige split_point uit split_points
split_points = split_points[:-1]
# Bereken de best split voor deze x col
- return max(best_list, key=lambda x: x[0])
+ if best_list:
+ return max(best_list, key=lambda x: x[0])
+ else:
+ return False
+
def exhaustive_split_search(rows, classes, min_leaf):
"""
@todo: Docstring for exhaustive_split_search
"""
+ print("\t\t->entering exhaustive split search")
# We hebben enumerate nodig, want we willen weten op welke col
# (age,married,house,income,gender) we een split doen
exhaustive_best_list = []
@@ -216,15 +229,19 @@ def exhaustive_split_search(rows, classes, min_leaf):
# calculate the best split for the col that satisfies the min_leaf
# constraint
col_best_split = bestsplit(col, classes, min_leaf)
- # add for which row we calculated the best split
- col_best_split += (i,)
- exhaustive_best_list.append(col_best_split)
+ # Check if there was a split fulfilling the min_leaf rule
+ if col_best_split:
+ # add the index of the col for which we calculated the best split
+ col_best_split += (i,)
+ exhaustive_best_list.append(col_best_split)
+ print("\t\t->returning from exhaustive split search")
return exhaustive_best_list
def add_children(node, x, best_split):
"""
@todo: Docstring for add_children
"""
+ print("\t\t\t->entering add children")
# The mask that was used to get the rows for the current node from x, we
# need this to update the rows for the children
current_mask = node.split_value_or_rows
@@ -236,19 +253,47 @@ def add_children(node, x, best_split):
# Give the current node the split_value and col it needs for predictions
node.split_value_or_rows, node.col = node_split_value, node_col
- mask_left, mask_right = update_mask(x, mask_left, mask_right, current_mask)
- return [Node(split_value_or_rows=mask_left), Node(split_value_or_rows=mask_right)]
+ # Update the row masks handed to the children so each has the same length as the parent mask
+ mask_left, mask_right = update_mask(mask_left, current_mask), update_mask(mask_right, current_mask)
+
+ # Adding the link between parent and children
+ node.left, node.right = Node(split_value_or_rows=mask_left), Node(split_value_or_rows=mask_right)
+ print("\t\t\t->children added to node and node list\n")
+ return [node.left, node.right]
-def update_mask(x, mask_left, mask_right, current_mask):
+def update_mask(mask, current_mask):
"""
@todo: Docstring for update_mask
"""
- print(f"{current_mask=}")
- print(f"{mask_left=}")
- print(f"{mask_right=}")
- current_row_no = np.where(current_mask)
- print(f"{current_rows=}")
- return mask_left, mask_right
+ print("\t\t\t\t->entering update mask to calculate which rows belong to child")
+ # current_mask = np.array([ True, True, True, True, True, True, False, False, False, True])
+ # print(f"Parent mask: {current_mask=}")
+ # print(f"The result of bestsplit: {mask=}")
+ # mask_left=np.array([False, True, False, False, False, True, True])
+ # mask_right=np.array([ True, False, True, True, True, False, False])
+ copy = np.array(current_mask, copy=True)
+ copy[np.where(current_mask)] = mask
+ # print(f"Child mask:{current_mask=}")
+
+ # print(f"{current_mask=}")
+ # print(f"{mask_right=}")
+ # current_mask[np.where(current_mask)] = mask_left
+ # print(f"{current_mask=}")
+ # mask_right = np.where(current_mask, False, True)
+ # print(f"{mask_right=}")
+
+ # current_row_indexs = np.where(current_mask, mask_left, mask_right)
+ # print(f"{current_mask[current_row_indexs]=}")
+ # current_mask[current_row_indexs] = mask_left
+
+
+ # print(f"{current_row_indexs=}")
+ # print(f"{mask_left=}")
+ # print(f"{len(mask_left)=}")
+ # print(f"{mask_right=}")
+ # print(f"{len(mask_right)=}")
+ print("\t\t\t\t->updated row mask for child node")
+ return copy
#
@@ -283,13 +328,13 @@ def tree_grow(x=None,
# etc. totdat alle splits gemaakt zijn en de lijst leeg is.
nodelist = [root]
- while None not in nodelist:
- print(nodelist)
+ while nodelist:
+ print("->Taking new node from the node list")
# Pop de current node uit de nodelist
node = nodelist.pop()
- print(nodelist)
# Gebruik de boolean mask van de node om de rows in de node uit x te halen
node_rows = x[node.split_value_or_rows]
+ # print(node_rows)
# Gebruik de boolean mask van de node om de classes in de node uit y te halen
node_classes = y[node.split_value_or_rows]
@@ -298,10 +343,13 @@ def tree_grow(x=None,
# Test of de node een leaf node is met n_min
if len(node_rows) < n_min:
node.is_leaf_node(node_classes)
+ print("\t->Node has less rows than n_min, it is a leaf node, continueing to next node")
continue
+ print("\t->Node has more rows than n_min, it is not a leaf node")
# Als de node niet puur is, probeer dan te splitten
if impurity(node_classes) > 0:
+ print("\t->Node is not pure yet starting exhaustive split search")
# We gaan exhaustively voor de rows in de node over de cols
# (age,married,house,income,gender) om de bestsplit te
# bepalen
@@ -316,7 +364,9 @@ def tree_grow(x=None,
# an de min leaf constraint
if not exhaustive_best_list:
node.is_leaf_node(node_classes)
+ print("\t\t->No split that fullfils min_leaf was found continueing to next node")
continue
+ print("\t->Exhaustive search found a split fulfilling the min_leaf rule!")
# Hier halen we de beste split, en rows voor de child/split nodes
# uit de exhaustive best list
best_split = max(exhaustive_best_list, key=lambda z: z[0])
@@ -325,7 +375,9 @@ def tree_grow(x=None,
# node.add_split(left_child_node, right_child_node)
else:
node.is_leaf_node(node_classes)
+ print("\t\t->The node is already pure, it is necessarily a leaf!")
continue
+ print(tr)
return tr
# Initiate the nodelist with tuples of slice and class labels