Channel: How to extract sklearn decision tree rules to pandas boolean conditions? - Stack Overflow

Answer by vlemaistre for How to extract sklearn decision tree rules to pandas boolean conditions?


First of all, let's follow the scikit-learn documentation on decision tree structure to get information about the tree that was constructed:

    n_nodes = clf.tree_.node_count
    children_left = clf.tree_.children_left
    children_right = clf.tree_.children_right
    feature = clf.tree_.feature
    threshold = clf.tree_.threshold
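To make the layout of these arrays concrete, here is a minimal sketch on a hand-written three-node tree. The values are hypothetical stand-ins for what `clf.tree_` would hold, not output from the question's data. It assumes the scikit-learn convention that a child id of -1 marks a leaf, and that `feature` and `threshold` hold -2 at leaves.

```python
# Hand-written stand-ins for the clf.tree_ arrays of a hypothetical
# three-node tree: node 0 splits on feature 0 at 127.5; nodes 1 and 2 are leaves.
children_left = [1, -1, -1]
children_right = [2, -1, -1]
feature = [0, -2, -2]            # -2 means "no split feature" (leaf)
threshold = [127.5, -2.0, -2.0]  # -2.0 means "no threshold" (leaf)


def describe(node):
    """Return a one-line description of a node, read from the parallel arrays."""
    if children_left[node] == -1:
        return f"node {node}: leaf"
    return (
        f"node {node}: feature[{feature[node]}] <= {threshold[node]} "
        f"-> node {children_left[node]}, else node {children_right[node]}"
    )


descriptions = [describe(n) for n in range(len(children_left))]
for line in descriptions:
    print(line)
# node 0: feature[0] <= 127.5 -> node 1, else node 2
# node 1: leaf
# node 2: leaf
```

Every array is indexed by node id, so one index reads off everything about a node.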

We then define two functions. The first one is recursive and finds the path from the tree's root to a specific node (all the leaves in our case). The second one writes the rule that leads to a node, given the path to that node:

    def find_path(node_numb, path, x):
        path.append(node_numb)
        if node_numb == x:
            return True
        left = False
        right = False
        if (children_left[node_numb] != -1):
            left = find_path(children_left[node_numb], path, x)
        if (children_right[node_numb] != -1):
            right = find_path(children_right[node_numb], path, x)
        if left or right:
            return True
        path.remove(node_numb)
        return False

    def get_rule(path, column_names):
        mask = ''
        for index, node in enumerate(path):
            # We check if we are not in the leaf
            if index != len(path) - 1:
                # Do we go under or over the threshold ?
                if (children_left[node] == path[index+1]):
                    mask += "(df['{}']<= {}) \t ".format(column_names[feature[node]], threshold[node])
                else:
                    mask += "(df['{}']> {}) \t ".format(column_names[feature[node]], threshold[node])
        # We insert the & at the right places
        mask = mask.replace("\t", "&", mask.count("\t") - 1)
        mask = mask.replace("\t", "")
        return mask
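The recursive search above walks the tree again for every target node. One alternative, shown here only as a sketch and not as the method this answer uses, is to build a child-to-parent map once and then walk upward from each leaf. The arrays are hand-written stand-ins for `clf.tree_.children_left` and `clf.tree_.children_right` on a hypothetical seven-node tree, and `build_parents` and `path_to` are illustrative names.

```python
# Hand-written stand-ins for clf.tree_.children_left / children_right
# (hypothetical 7-node tree; -1 marks "no child", as in scikit-learn).
children_left = [1, 2, -1, -1, 5, -1, -1]
children_right = [4, 3, -1, -1, 6, -1, -1]


def build_parents(children_left, children_right):
    """Map each non-root node id to the id of its parent."""
    parents = {}
    for node, (left, right) in enumerate(zip(children_left, children_right)):
        if left != -1:
            parents[left] = node
        if right != -1:
            parents[right] = node
    return parents


def path_to(node, parents):
    """Return the node ids from the root down to `node`."""
    path = [node]
    # Climb until we hit the root, which is the only node without a parent.
    while path[-1] in parents:
        path.append(parents[path[-1]])
    return path[::-1]


parents = build_parents(children_left, children_right)
print(path_to(3, parents))  # [0, 1, 3]
print(path_to(6, parents))  # [0, 4, 6]
```

The resulting lists have the same root-to-leaf shape that `get_rule` expects as its `path` argument.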

Finally, we use those two functions to store first the path to each leaf, and then the rule that leads to each leaf. Only the leaves reached by the samples in X_test are included:

    import numpy as np

    # Leaves
    leave_id = clf.apply(X_test)

    paths = {}
    for leaf in np.unique(leave_id):
        path_leaf = []
        find_path(0, path_leaf, leaf)
        paths[leaf] = np.unique(np.sort(path_leaf))

    rules = {}
    for key in paths:
        rules[key] = get_rule(paths[key], pima.columns)
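Because the leaf ids come from `clf.apply(X_test)`, a leaf that no test sample reaches gets no rule. The sketch below contrasts that with listing every leaf straight from the structure arrays. Both arrays are hand-written stand-ins (a hypothetical seven-node tree and made-up `apply` output), not values from the question's data.

```python
import numpy as np

# Hand-written stand-in for clf.tree_.children_left (-1 marks a leaf).
children_left = np.array([1, 2, -1, -1, 5, -1, -1])

# Made-up stand-in for clf.apply(X_test): the leaf id each sample lands in.
leave_id = np.array([3, 3, 6, 2, 6])

# Leaves actually reached by the samples (sorted, duplicates removed).
reached = np.unique(leave_id)

# Every leaf in the tree, whether or not a sample reaches it.
all_leaves = np.flatnonzero(children_left == -1)

print(reached.tolist())     # [2, 3, 6]
print(all_leaves.tolist())  # [2, 3, 5, 6]
```

Here leaf 5 exists in the tree but would be missing from `rules`, since no sample falls into it.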

With the data you gave, the output is:

    rules = {
        3: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']<= 9.100000381469727)  ",
        4: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']> 9.100000381469727)  ",
        6: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']<= 27.5)  ",
        7: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']> 27.5)  ",
        10: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']<= 145.5)  ",
        11: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']> 145.5)  ",
        13: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']<= 158.5)  ",
        14: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']> 158.5)  ",
    }

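Since the rules are strings, you can't index with them directly as in df[rules[3]]. You have to evaluate them first with the eval function, like so: df[eval(rules[3])]. Note that the strings refer to the DataFrame by the name df, so a variable with that name must be in scope when eval runs.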
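If you would rather not go through strings and eval, the same conditions can be combined directly into a boolean mask. The following is a minimal sketch under stated assumptions, not part of the original answer: the tree arrays, the column names, and the DataFrame are hand-written stand-ins, and `path_mask` is an illustrative name.

```python
import pandas as pd

# Hand-written stand-ins for the clf.tree_ arrays (hypothetical 3-node tree).
children_left = [1, -1, -1]
children_right = [2, -1, -1]
feature = [0, -2, -2]
threshold = [127.5, -2.0, -2.0]
column_names = ["insulin", "bp"]


def path_mask(df, path):
    """Boolean Series selecting the rows of df that follow `path` (root to leaf)."""
    mask = pd.Series(True, index=df.index)
    # Pair each node on the path with the next one to see which branch was taken.
    for node, child in zip(path[:-1], path[1:]):
        column = df[column_names[feature[node]]]
        if children_left[node] == child:
            mask &= (column <= threshold[node])
        else:
            mask &= (column > threshold[node])
    return mask


df = pd.DataFrame({"insulin": [100.0, 127.5, 200.0], "bp": [70.0, 80.0, 90.0]})
left_rows = df[path_mask(df, [0, 1])]
right_rows = df[path_mask(df, [0, 2])]
print(left_rows["insulin"].tolist())   # [100.0, 127.5]
print(right_rows["insulin"].tolist())  # [200.0]
```

The mask is built from the same `<=` / `>` comparisons as the rule strings, so it selects the same rows as the eval approach for a given path.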

