diff options
| author | Mike Vink <mike1994vink@gmail.com> | 2020-09-18 02:26:01 +0200 |
|---|---|---|
| committer | Mike Vink <mike1994vink@gmail.com> | 2020-09-18 02:26:01 +0200 |
| commit | 21284cbb4b9086678092e0a51630d3b589f91084 (patch) | |
| tree | 4bb738b51f2da2769aa02053c82d6bf2e6526e3e | |
| parent | ca2e889d514684c861fbf8e1f21acdce4ceef730 (diff) | |
Major: improvements on readability, factoring
| -rw-r--r-- | assignment1.py | 126 |
1 files changed, 89 insertions, 37 deletions
diff --git a/assignment1.py b/assignment1.py index c1a4370..2a670c2 100644 --- a/assignment1.py +++ b/assignment1.py @@ -44,8 +44,6 @@ class Node: def add_split(self, left, right): """ - Method that is called in the main loop of tree_grow. - Lets the node object point to two other objects that can be either Leaf or Node. """ @@ -54,7 +52,8 @@ class Node: def is_leaf_node(self, node_classes): """ - @todo: Docstring for is_leaf_node + is_leaf_node is used to change the col attribute to zero to indicate a + leaf node """ self.col = None self.split_value_or_rows = np.argmax( @@ -85,22 +84,31 @@ class Tree: """ self.tree = root_node_obj - - # def __repr__(self): - # nodelist = [self.tree] - # tree_str = '' - # while nodelist: - # current_node = nodelist.pop() - # # print(current_node.value) - # try: - # childs = [current_node.right, current_node.left] - # nodelist += childs - # except AttributeError: - # pass - # col, c = current_node.value - # tree_str += f"{col=}, {c=}" - # return tree_str - + def __repr__(self): + tree_string = '' + node = self.tree + depth = 0 + nodelist = [node] + while nodelist: + node = nodelist.pop() + depth += 1 + if node.col is not None: + left, right = node.left, node.right + nodelist += [left, right] + else: + continue + tree_string += '\n' + depth * ' ' + tree_string += (depth + 4) * ' ' + '/' + ' ' * 2 + '\\' + tree_string += '\n' + ' ' * 2 * depth + for direc in left, right: + if not direc.split_value_or_rows%10: + tree_string += ' ' * 4 + else: + tree_string += ' ' * 3 + tree_string += str(int(direc.split_value_or_rows)) + + tree_string = depth * ' ' + str(int(self.tree.split_value_or_rows)) + tree_string + return tree_string def impurity(array) -> int: """ @@ -171,7 +179,7 @@ def bestsplit(x, y, min_leaf) -> None: # was natuurlijk. # Nodig voor de delta i formule - impurity_parent, n_rows_parent = impurity(y), len(y) + impurity_parent, n_classes_parent = impurity(y), len(y) # Hieren stoppen we (delta_i, split_value, rows_left, rows_right) best_list = [] @@ -195,19 +203,24 @@ def bestsplit(x, y, min_leaf) -> None: # delta_i formule delta_i = impurity_parent - ( impurity(classes_left) * len(classes_left) + - impurity(classes_right) * len(classes_right)) / n_rows_parent + impurity(classes_right) * len(classes_right)) / n_classes_parent # stop huidige splits in de lijst om best split te berekenen best_list.append((delta_i, mask_left, mask_right, split_value)) # Haal de huidige split_point uit split_points split_points = split_points[:-1] # Bereken de best split voor deze x col - return max(best_list, key=lambda x: x[0]) + if best_list: + return max(best_list, key=lambda x: x[0]) + else: + return False + def exhaustive_split_search(rows, classes, min_leaf): """ @todo: Docstring for exhaustive_split_search """ + print("\t\t->entering exhaustive split search") # We hebben enumerate nodig, want we willen weten op welke col # (age,married,house,income,gender) we een split doen exhaustive_best_list = [] @@ -216,15 +229,19 @@ def exhaustive_split_search(rows, classes, min_leaf): # calculate the best split for the col that satisfies the min_leaf # constraint col_best_split = bestsplit(col, classes, min_leaf) - # add for which row we calculated the best split - col_best_split += (i,) - exhaustive_best_list.append(col_best_split) + # Check if there was a split fullfilling the min leaf rule + if col_best_split: + # add for which row we calculated the best split + col_best_split += (i,) + exhaustive_best_list.append(col_best_split) + print("\t\t->returning from exhaustive split search") return exhaustive_best_list def add_children(node, x, best_split): """ @todo: Docstring for add_children """ + print("\t\t\t->entering add children") # The mask that was used to get the rows for the current node from x, we # need this to update the rows for the children current_mask = node.split_value_or_rows @@ -236,19 +253,47 @@ def add_children(node, x, best_split): # Give the current node the split_value and col it needs for predictions node.split_value_or_rows, node.col = node_split_value, node_col - mask_left, mask_right = update_mask(x, mask_left, mask_right, current_mask) - return [Node(split_value_or_rows=mask_left), Node(split_value_or_rows=mask_right)] + # Updating the row masks to give it to children, keeping numpy dimension consistent + mask_left, mask_right = update_mask(mask_left, current_mask), update_mask(mask_right, current_mask) + + # Adding the link between parent and children + node.left, node.right = Node(split_value_or_rows=mask_left), Node(split_value_or_rows=mask_right) + print("\t\t\t->children added to node and node list\n") + return [node.left, node.right] -def update_mask(x, mask_left, mask_right, current_mask): +def update_mask(mask, current_mask): """ @todo: Docstring for update_mask """ - print(f"{current_mask=}") - print(f"{mask_left=}") - print(f"{mask_right=}") - current_row_no = np.where(current_mask) - print(f"{current_rows=}") - return mask_left, mask_right + print("\t\t\t\t->entering update mask to calculate which rows belong to child") + # current_mask = np.array([ True, True, True, True, True, True, False, False, False, True]) + # print(f"Parent mask: {current_mask=}") + # print(f"The result of bestsplit: {mask=}") + # mask_left=np.array([False, True, False, False, False, True, True]) + # mask_right=np.array([ True, False, True, True, True, False, False]) + copy = np.array(current_mask, copy=True) + copy[np.where(current_mask)] = mask + # print(f"Child mask:{current_mask=}") + + # print(f"{current_mask=}") + # print(f"{mask_right=}") + # current_mask[np.where(current_mask)] = mask_left + # print(f"{current_mask=}") + # mask_right = np.where(current_mask, False, True) + # print(f"{mask_right=}") + + # current_row_indexs = np.where(current_mask, mask_left, mask_right) + # print(f"{current_mask[current_row_indexs]=}") + # current_mask[current_row_indexs] = mask_left + + + # print(f"{current_row_indexs=}") + # print(f"{mask_left=}") + # print(f"{len(mask_left)=}") + # print(f"{mask_right=}") + # print(f"{len(mask_right)=}") + print("\t\t\t\t->updated row mask for child node") + return copy # @@ -283,13 +328,13 @@ def tree_grow(x=None, # etc. totdat alle splits gemaakt zijn en de lijst leeg is. nodelist = [root] - while None not in nodelist: - print(nodelist) + while nodelist: + print("->Taking new node from the node list") # Pop de current node uit de nodelist node = nodelist.pop() - print(nodelist) # Gebruik de boolean mask van de node om de rows in de node uit x te halen node_rows = x[node.split_value_or_rows] + # print(node_rows) # Gebruik de boolean mask van de node om de classes in de node uit y te halen node_classes = y[node.split_value_or_rows] @@ -298,10 +343,13 @@ def tree_grow(x=None, # Test of de node een leaf node is met n_min if len(node_rows) < n_min: node.is_leaf_node(node_classes) + print("\t->Node has less rows than n_min, it is a leaf node, continueing to next node") continue + print("\t->Node has more rows than n_min, it is not a leaf node") # Als de node niet puur is, probeer dan te splitten if impurity(node_classes) > 0: + print("\t->Node is not pure yet starting exhaustive split search") # We gaan exhaustively voor de rows in de node over de cols # (age,married,house,income,gender) om de bestsplit te # bepalen @@ -316,7 +364,9 @@ def tree_grow(x=None, # an de min leaf constraint if not exhaustive_best_list: node.is_leaf_node(node_classes) + print("\t\t->No split that fullfils min_leaf was found continueing to next node") continue + print("\t->Exhaustive search found a split fulfilling the min_leaf rule!") # Hier halen we de beste split, en rows voor de child/split nodes # uit de exhaustive best list best_split = max(exhaustive_best_list, key=lambda z: z[0]) @@ -325,7 +375,9 @@ def tree_grow(x=None, # node.add_split(left_child_node, right_child_node) else: node.is_leaf_node(node_classes) + print("\t\t->The node is already pure, it is necessarily a leaf!") continue + print(tr) return tr # Initiate the nodelist with tuples of slice and class labels |
