diff options
| author | Mike Vink <mike1994vink@gmail.com> | 2020-09-15 11:28:59 +0200 |
|---|---|---|
| committer | Mike Vink <mike1994vink@gmail.com> | 2020-09-15 11:28:59 +0200 |
| commit | bf1356a7d1ee66d2da4f288381d9b4387a7b17a8 (patch) | |
| tree | e58863b5a3756f0e977017c0c31f115ad78c6f17 | |
| parent | ee067768dc96a78cbfd65e941ab8d706fa95127e (diff) | |
add: my idea for data structure, and tree_pred
| -rw-r--r-- | assignment1.py | 70 |
1 files changed, 64 insertions, 6 deletions
diff --git a/assignment1.py b/assignment1.py index 9acd43e..a65c601 100644 --- a/assignment1.py +++ b/assignment1.py @@ -22,16 +22,34 @@ class Tree(): """ @todo: docstring for Tree """ - def __init__(self, tr_d_structure): + def __init__(self, d_structure): """@todo: Docstring for init method. - /tr_d_structure/ @todo + /d_structure/ @todo """ - self.tr_d_structure = tr_d_structure + self.d_structure = d_structure + + def __enter__(self): + return self + + def __exit__(self, exception_type, exception_value, traceback): + print("Tree has been traversed") def __repr__(self): - return str(self.tr_d_structure) + return str(self.d_structure) + + def drop_left(self) -> dict: + """ + @todo: Docstring for drop_left + """ + self.d_structure = self.d_structure["left"] + + def drop_right(self) -> dict: + """ + @todo: Docstring for drop_right + """ + self.d_structure = self.d_structure["right"] tree_grow_defaults = { @@ -65,8 +83,8 @@ def tree_grow(x=None, print( f"Number of features/attributes to be randomly drawn from {x=} to be considered for a split, should only be lower than {len(x[0,:])=} for random forest growing, {n_feat=}" ) - tr_d_structure = {} - return Tree(tr_d_structure) + d_structure = {} + return Tree(d_structure) # Calling the function, unpacking default as argument @@ -84,6 +102,45 @@ def tree_pred(x=None, tr=None, **defaults) -> np.array: """ print("\n\n#########Tree_pred output start:\n") print(f"Drop a row in {x=} down the tree {tr.__repr__()}") + # x should be a set of rows that represent cases to be dropped in the tree, + # we iterate over the numpy rows + y = np.array([]) + for case in x: + # Assumes that tr is a data structure with the following form + # + # {"root":(split, {"left":(...), "right":(...)}) + # + # Where the (...) is recursion of the pattern tuple(tuple, dict) + with Tree(case, tr) as tr: + # Unpack the root node, which is just a key value pair with the form: + # "root": (split, dict(subtree)) + split, tr.d_structure = tr.d_structure["root"] + # Unpack the split info we need to do the first split + # c is the split value for the drop, and col is the attribute/feature + # we are dropping on + col, c = split + while isinstance(tr.d_structure, dict): + # The drop methods return a tuple by giving the current + # d_structure a key that is either "left" or right": + # def drop_left(self): + # ... + # return tr.d_structure["left"] + # + # Don't even need to check for leaf node, since the while loop + # should do that. + if case[col] > c: + split, tr.d_structure = tr.drop_right() + elif case[col] <= c: + split, tr.d_structure = tr.drop_left() + split = col, c + # Assumes that the tr.d_structure is the leaf node that is just the + # majority(!) class label, which is just the integer 1 or 0. + # + # This is also the reason we break out of the while loop, when we + # arrive to a leaf node with the drop methods, the tr.d_structure + # type is no longer a dict, but an int. + y.append(tr.d_structure) + return y tree_pred(**tree_pred_defaults) @@ -92,6 +149,7 @@ tree_pred(**tree_pred_defaults) # # Put all helper functions below this comment! + def impurity(array) -> int: """ Assumes the argument array is a one dimensional vector of zeroes and ones. |
