Lesson 21 - Decision Tree Structure

The following topics are discussed in this notebook:

  • Tips for implementing a Decision Tree class.
In [1]:
import numpy as np
import matplotlib.pyplot as plt

Random Tree Class

We will provide a partial implementation of a classifier that works like a decision tree, except that the cuts are selected at random rather than by searching for the best split.

In [2]:
class RandomTree:
    
    def __init__(self, X, y, max_depth=2, depth=0, min_leaf_size=1):
        
        self.X = np.array(X)
        self.y = np.array(y)
        self.n_obs = len(y)
        self.depth = depth
        
        self.classes, self.class_counts = np.unique(self.y, return_counts=True)    
        self.prediction = np.argmax(self.class_counts)
        
        # self.gini 
        
        if depth == max_depth:
            self.axis = None
            self.t = None
            self.left = None
            self.right = None
            return
        
        # Randomly split points (Rather than finding the best cut)
        self.axis = np.random.choice(range(self.X.shape[1]))
        self.t = np.random.uniform(np.min(self.X[:,self.axis]), np.max(self.X[:,self.axis]))
        
        # Create selection array for instances below the cut
        sel = self.X[:,self.axis] <= self.t
        
        # Determine if cut can be made
        if (np.sum(sel) < min_leaf_size) or (np.sum(~sel) < min_leaf_size):
            self.axis = None
            self.t = None
            self.left = None
            self.right = None
            return
        
        self.left = RandomTree(self.X[sel,:], self.y[sel], max_depth, depth+1, min_leaf_size)
        self.right = RandomTree(self.X[~sel,:], self.y[~sel], max_depth, depth+1, min_leaf_size)
        
        
    def print_tree(self):
        
        msg = '  ' * self.depth + '* '
        msg += 'Size: ' + str(self.n_obs) + ' '
        msg += str(list(self.class_counts))
        #msg += ', Gini: ' + str(round(self.gini,2))
                
        if self.t is not None:
            msg += ', Axis:' + str(self.axis)
            msg += ', Cut: ' + str(round(self.t,2))
        else:
            msg += ', Predicted Class: ' + str(self.classes[self.prediction])
        
        print(msg)
                        
        if self.left is not None:
            self.left.print_tree()
            self.right.print_tree()
            
            
#    def predict_single(self, obs):
#        obs = np.array(obs)
#
#        if at leaf node:
#            return majority class
#
#        if observation is on left side of cut:
#            return prediction generated by left child
#
#        if observation is on right side of cut:
#            return prediction generated by right child

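The commented-out pseudocode above sketches a predict_single method. One way it might be completed is shown below. This is only a sketch that follows the steps in the comments; it is not part of the provided partial implementation, and it is attached to the class after the fact so the cell above can be left as an exercise.

In [ ]:
def predict_single(self, obs):
    obs = np.array(obs)
    
    # Leaf node: no cut was made here, so return the majority class.
    if self.left is None:
        return self.classes[self.prediction]
    
    # Observation falls on the left side of the cut.
    if obs[self.axis] <= self.t:
        return self.left.predict_single(obs)
    
    # Observation falls on the right side of the cut.
    return self.right.predict_single(obs)

# Attach the method to the class so that trees constructed later can use it.
RandomTree.predict_single = predict_single

With this in place, a call such as tree.predict_single(X[0,:]) on one of the trees built below would return one of the labels 0, 1, or 2.
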
Example

In [3]:
np.random.seed(1)

n = 1000
X = np.random.uniform(0,10,5*n).reshape(n,5)
y = np.random.choice([0,1,2],n)

tree = RandomTree(X, y, max_depth=2)
tree.print_tree()
* Size: 1000 [300, 369, 331], Axis:2, Cut: 5.71
  * Size: 584 [178, 211, 195], Axis:3, Cut: 3.44
    * Size: 195 [57, 70, 68], Predicted Class: 1
    * Size: 389 [121, 141, 127], Predicted Class: 1
  * Size: 416 [122, 158, 136], Axis:0, Cut: 3.07
    * Size: 126 [34, 53, 39], Predicted Class: 1
    * Size: 290 [88, 105, 97], Predicted Class: 1
In [4]:
tree = RandomTree(X, y, max_depth=5, min_leaf_size=5)
tree.print_tree()
* Size: 1000 [300, 369, 331], Axis:0, Cut: 9.22
  * Size: 925 [283, 337, 305], Axis:0, Cut: 3.68
    * Size: 378 [118, 144, 116], Axis:0, Cut: 3.52
      * Size: 363 [113, 140, 110], Axis:1, Cut: 4.74
        * Size: 170 [58, 60, 52], Axis:0, Cut: 2.74
          * Size: 137 [42, 54, 41], Predicted Class: 1
          * Size: 33 [16, 6, 11], Predicted Class: 0
        * Size: 193 [55, 80, 58], Axis:1, Cut: 9.4
          * Size: 172 [51, 71, 50], Predicted Class: 1
          * Size: 21 [4, 9, 8], Predicted Class: 1
      * Size: 15 [5, 4, 6], Axis:4, Cut: 7.63
        * Size: 9 [2, 3, 4], Predicted Class: 2
        * Size: 6 [3, 1, 2], Predicted Class: 0
    * Size: 547 [165, 193, 189], Axis:0, Cut: 8.95
      * Size: 521 [161, 183, 177], Axis:3, Cut: 5.47
        * Size: 288 [79, 99, 110], Axis:3, Cut: 0.6
          * Size: 24 [7, 6, 11], Predicted Class: 2
          * Size: 264 [72, 93, 99], Predicted Class: 2
        * Size: 233 [82, 84, 67], Axis:0, Cut: 6.79
          * Size: 151 [57, 52, 42], Predicted Class: 0
          * Size: 82 [25, 32, 25], Predicted Class: 1
      * Size: 26 [4, 10, 12], Axis:0, Cut: 9.08
        * Size: 17 [3, 5, 9], Axis:3, Cut: 3.77
          * Size: 8 [1, 3, 4], Predicted Class: 2
          * Size: 9 [2, 2, 5], Predicted Class: 2
        * Size: 9 [1, 5, 3], Predicted Class: 1
  * Size: 75 [17, 32, 26], Axis:2, Cut: 2.29
    * Size: 13 [1, 6, 6], Predicted Class: 1
    * Size: 62 [16, 26, 20], Axis:3, Cut: 8.67
      * Size: 51 [14, 20, 17], Axis:1, Cut: 6.13
        * Size: 34 [11, 12, 11], Axis:2, Cut: 6.4
          * Size: 18 [4, 8, 6], Predicted Class: 1
          * Size: 16 [7, 4, 5], Predicted Class: 0
        * Size: 17 [3, 8, 6], Axis:2, Cut: 3.63
          * Size: 5 [5], Predicted Class: 1
          * Size: 12 [3, 3, 6], Predicted Class: 2
      * Size: 11 [2, 6, 3], Axis:1, Cut: 3.23
        * Size: 5 [1, 3, 1], Predicted Class: 1
        * Size: 6 [1, 3, 2], Predicted Class: 1
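
The constructor also leaves a placeholder for self.gini, and print_tree contains a commented-out line that would display it. The Gini impurity of a node with class proportions p_1, ..., p_K is 1 - (p_1^2 + ... + p_K^2). A minimal sketch of the computation, shown here for the root of the tree above rather than filled into the class, is:

In [ ]:
# Gini impurity of the root node of the tree above (illustrative only).
# Inside __init__ the same computation could be stored as
#     self.gini = 1 - np.sum((self.class_counts / self.n_obs)**2)
p = tree.class_counts / tree.n_obs      # class proportions at the root
print(1 - np.sum(p**2))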