summaryrefslogtreecommitdiff
path: root/assignment1.py
diff options
context:
space:
mode:
Diffstat (limited to 'assignment1.py')
-rw-r--r--assignment1.py70
1 files changed, 64 insertions, 6 deletions
diff --git a/assignment1.py b/assignment1.py
index 9acd43e..a65c601 100644
--- a/assignment1.py
+++ b/assignment1.py
@@ -22,16 +22,34 @@ class Tree():
"""
@todo: docstring for Tree
"""
- def __init__(self, tr_d_structure):
+ def __init__(self, d_structure):
"""@todo: Docstring for init method.
- /tr_d_structure/ @todo
+ /d_structure/ @todo
"""
- self.tr_d_structure = tr_d_structure
+ self.d_structure = d_structure
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exception_type, exception_value, traceback):
+ print("Tree has been traversed")
def __repr__(self):
- return str(self.tr_d_structure)
+ return str(self.d_structure)
+
+ def drop_left(self) -> dict:
+ """
+ @todo: Docstring for drop_left
+ """
+ self.d_structure = self.d_structure["left"]
+
+ def drop_right(self) -> dict:
+ """
+ @todo: Docstring for drop_right
+ """
+ self.d_structure = self.d_structure["right"]
tree_grow_defaults = {
@@ -65,8 +83,8 @@ def tree_grow(x=None,
print(
f"Number of features/attributes to be randomly drawn from {x=} to be considered for a split, should only be lower than {len(x[0,:])=} for random forest growing, {n_feat=}"
)
- tr_d_structure = {}
- return Tree(tr_d_structure)
+ d_structure = {}
+ return Tree(d_structure)
# Calling the function, unpacking default as argument
@@ -84,6 +102,45 @@ def tree_pred(x=None, tr=None, **defaults) -> np.array:
"""
print("\n\n#########Tree_pred output start:\n")
print(f"Drop a row in {x=} down the tree {tr.__repr__()}")
+ # x should be a set of rows that represent cases to be dropped in the tree,
+ # we iterate over the numpy rows
+ y = np.array([])
+ for case in x:
+ # Assumes that tr is a data structure with the following form
+ #
+ # {"root":(split, {"left":(...), "right":(...)})
+ #
+ # Where the (...) is recursion of the pattern tuple(tuple, dict)
+ with Tree(case, tr) as tr:
+ # Unpack the root node, which is just a key value pair with the form:
+ # "root": (split, dict(subtree))
+ split, tr.d_structure = tr.d_structure["root"]
+ # Unpack the split info we need to do the first split
+ # c is the split value for the drop, and col is the attribute/feature
+ # we are dropping on
+ col, c = split
+ while isinstance(tr.d_structure, dict):
+ # The drop methods return a tuple by giving the current
+ # d_structure a key that is either "left" or right":
+ # def drop_left(self):
+ # ...
+ # return tr.d_structure["left"]
+ #
+ # Don't even need to check for leaf node, since the while loop
+ # should do that.
+ if case[col] > c:
+ split, tr.d_structure = tr.drop_right()
+ elif case[col] <= c:
+ split, tr.d_structure = tr.drop_left()
+ split = col, c
+ # Assumes that the tr.d_structure is the leaf node that is just the
+ # majority(!) class label, which is just the integer 1 or 0.
+ #
+ # This is also the reason we break out of the while loop, when we
+ # arrive to a leaf node with the drop methods, the tr.d_structure
+ # type is no longer a dict, but an int.
+ y.append(tr.d_structure)
+ return y
tree_pred(**tree_pred_defaults)
@@ -92,6 +149,7 @@ tree_pred(**tree_pred_defaults)
#
# Put all helper functions below this comment!
+
def impurity(array) -> int:
"""
Assumes the argument array is a one dimensional vector of zeroes and ones.