diff options
| author | Mike Vink <mike1994vink@gmail.com> | 2020-09-15 12:38:59 +0200 |
|---|---|---|
| committer | Mike Vink <mike1994vink@gmail.com> | 2020-09-15 12:38:59 +0200 |
| commit | 220afa88dc69c1eb59a1ba3c75a6936f40bc156a (patch) | |
| tree | bc53dd60c357be8c8651a9be8054e2d72a7c7970 | |
| parent | 6cb7bea9a1ceba7441c3aaa0e0b28a8bd8730425 (diff) | |
add: Working tree_pred
| -rw-r--r-- | assignment1.py | 73 |
1 files changed, 31 insertions, 42 deletions
diff --git a/assignment1.py b/assignment1.py index 8bdacce..4134270 100644 --- a/assignment1.py +++ b/assignment1.py @@ -30,36 +30,24 @@ class Tree(): """ self.d_structure = d_structure - # Not really used at the moment - def __enter__(self): - return self - - def __exit__(self, exception_type, exception_value, traceback): - print("Tree has been traversed") - def __repr__(self): return str(self.d_structure) - def clone_tree(self): - """ - @todo: Docstring for clone_tree - """ - # Clone for predicting a case - return Tree(self.d_structure) - def drop_left(self) -> dict: """ @todo: Docstring for drop_left """ # Need to change - self.d_structure = self.d_structure["left"] + print("Dropping left:\n", self.prediction["left"]) + return self.prediction["left"] def drop_right(self) -> dict: """ @todo: Docstring for drop_right """ # Need to change - self.d_structure = self.d_structure["right"] + print("Dropping right:\n", self.prediction["right"]) + return self.prediction["right"] tree_grow_defaults = { @@ -93,7 +81,7 @@ def tree_grow(x=None, print( f"Number of features/attributes to be randomly drawn from {x=} to be considered for a split, should only be lower than {len(x[0,:])=} for random forest growing, {n_feat=}" ) - d_structure = {} + d_structure = {"root": ((0,30), {"left": (None, 1), "right": (None, 0)})} # dummy data return Tree(d_structure) @@ -122,36 +110,37 @@ def tree_pred(x=None, tr=None, **defaults) -> np.array: # # Where the (...) is recursion of the pattern tuple(tuple, dict) - # Don't know if copying is the best - with tr.clone_tree() as tr: - # Unpack the root node, which is just a key value pair with the form: - # "root": (split, dict(subtree)) - split, tr.d_structure = tr.d_structure["root"] - # Unpack the split info we need to do the first split - # c is the split value for the drop, and col is the attribute/feature - # we are dropping on + # Copy the data structure of the tree, not very efficient? + tr.prediction = tr.d_structure + + split, tr.prediction = tr.prediction["root"] + # Unpack the split info we need to do the first split + # c is the split value for the drop, and col is the attribute/feature + # we are dropping on + col, c = split + while isinstance(tr.prediction, dict): + # The drop methods return a tuple by giving the current + # d_structure a key that is either "left" or right": + # def drop_left(self): + # ... + # return tr.prediction["left"] + # + # Don't even need to check for leaf node, since the while loop + # should do that. col, c = split - while isinstance(tr.d_structure, dict): - # The drop methods return a tuple by giving the current - # d_structure a key that is either "left" or right": - # def drop_left(self): - # ... - # return tr.d_structure["left"] - # - # Don't even need to check for leaf node, since the while loop - # should do that. - if case[col] > c: - split, tr.d_structure = tr.drop_right() - elif case[col] <= c: - split, tr.d_structure = tr.drop_left() - split = col, c - # Assumes that the tr.d_structure is the leaf node that is just the - # majority(!) class label, which is just the integer 1 or 0. + if case[col] > c: + split, tr.prediction = tr.drop_right() + elif case[col] <= c: + split, tr.prediction = tr.drop_left() + # Assumes that the leaf nodes in tr.d_structure is the leaf node + # that is just the majority(!) class label, which is just the integer 1 or 0. # # This is also the reason we break out of the while loop, when we # arrive to a leaf node with the drop methods, the tr.d_structure # type is no longer a dict, but an int. - y.append(tr.d_structure) + # print(tr.prediction) + y = np.append(y, tr.prediction) + print("Predictions from the tree:\n", y) return y |
