summaryrefslogtreecommitdiff
path: root/tree.py
diff options
context:
space:
mode:
Diffstat (limited to 'tree.py')
-rw-r--r--tree.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/tree.py b/tree.py
index e971bd4..3a073ac 100644
--- a/tree.py
+++ b/tree.py
@@ -1,10 +1,10 @@
import numpy as np
-import cProfile
-import pstats
+# import cProfile
+# import pstats
# import tqdm
# from tqdm import trange
-from pstats import SortKey
+# from pstats import SortKey
from sklearn import metrics
# age,married,house,income,gender,class
@@ -252,7 +252,7 @@ def update_mask(mask, current_mask):
equal to the total number of rows in dataset x.
"""
copy = np.array(current_mask, copy=True)
- copy[np.where(current_mask)] = mask
+ copy[current_mask == True] = mask
return copy