summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--assignment1.py12
1 files changed, 11 insertions, 1 deletions
diff --git a/assignment1.py b/assignment1.py
index 67e3c87..8bdacce 100644
--- a/assignment1.py
+++ b/assignment1.py
@@ -30,6 +30,7 @@ class Tree():
"""
self.d_structure = d_structure
+ # Not really used at the moment
def __enter__(self):
return self
@@ -39,6 +40,13 @@ class Tree():
def __repr__(self):
return str(self.d_structure)
+ def clone_tree(self):
+ """
+ @todo: Docstring for clone_tree
+ """
+ # Clone for predicting a case
+ return Tree(self.d_structure)
+
def drop_left(self) -> dict:
"""
@todo: Docstring for drop_left
@@ -113,7 +121,9 @@ def tree_pred(x=None, tr=None, **defaults) -> np.array:
# {"root":(split, {"left":(...), "right":(...)})
#
# Where the (...) is recursion of the pattern tuple(tuple, dict)
- with Tree(case, tr) as tr:
+
+ # Don't know if copying is the best
+ with tr.clone_tree() as tr:
# Unpack the root node, which is just a key value pair with the form:
# "root": (split, dict(subtree))
split, tr.d_structure = tr.d_structure["root"]