diff options
Diffstat (limited to 'tree.py')
| -rw-r--r-- | tree.py | 8 |
1 files changed, 4 insertions, 4 deletions
@@ -1,10 +1,10 @@ import numpy as np -import cProfile -import pstats +# import cProfile +# import pstats # import tqdm # from tqdm import trange -from pstats import SortKey +# from pstats import SortKey from sklearn import metrics # age,married,house,income,gender,class @@ -252,7 +252,7 @@ def update_mask(mask, current_mask): equal to the total number of rows in dataset x. """ copy = np.array(current_mask, copy=True) - copy[np.where(current_mask)] = mask + copy[current_mask == True] = mask return copy |
