import numpy as np
import matplotlib.pyplot as plt
# This is a partial implementation of a classifier that works like a decision tree, except that each cut (split axis and threshold) is selected at random rather than optimized.
class RandomTree:
    """Decision-tree-like classifier whose cuts are chosen at random.

    Each node stores the observations that reached it, the distinct class
    labels among them and their counts.  An internal node additionally
    stores the feature index (``axis``) and threshold (``t``) of its cut
    and two child nodes; a leaf has ``axis``, ``t``, ``left`` and ``right``
    all set to ``None``.

    Attributes:
        X: 2-D array of the observations at this node (rows = observations).
        y: 1-D array of the labels at this node.
        n_obs: Number of observations at this node.
        depth: Depth of this node (the root is built with ``depth=0``).
        classes: Sorted distinct labels present at this node.
        class_counts: Count of each entry of ``classes``.
        prediction: Index into ``classes`` of the most frequent label
            (the label itself is ``classes[prediction]``).
        axis, t: Feature index and threshold of the cut, or ``None``.
        left, right: Child nodes for ``x[axis] <= t`` / ``x[axis] > t``,
            or ``None`` at a leaf.
    """

    def __init__(self, X, y, max_depth=2, depth=0, min_leaf_size=1):
        """Build the (sub)tree rooted at this node.

        Args:
            X: 2-D array-like of features, one row per observation.
            y: 1-D array-like of class labels, one per row of ``X``.
            max_depth: Depth at which nodes are forced to be leaves.
            depth: Depth of the node being built; callers normally leave
                this at 0 and the recursion increments it.
            min_leaf_size: Minimum number of observations each side of a
                cut must receive for the cut to be accepted.

        Raises:
            ValueError: If ``y`` is empty.
        """
        # Work on the converted arrays everywhere so plain lists are accepted.
        self.X = np.array(X)
        self.y = np.array(y)
        self.n_obs = len(self.y)
        if self.n_obs == 0:
            raise ValueError("RandomTree requires at least one observation")
        self.depth = depth

        self.classes, self.class_counts = np.unique(self.y, return_counts=True)
        self.prediction = np.argmax(self.class_counts)
        # self.gini  (impurity is not computed yet)

        # Every node starts out as a leaf; the fields below are only
        # overwritten once a valid cut has been found.
        self.axis = None
        self.t = None
        self.left = None
        self.right = None

        # ">=" rather than "==" so a node created beyond max_depth still stops.
        if depth >= max_depth:
            return

        # Randomly split points (rather than finding the best cut).
        axis = np.random.choice(self.X.shape[1])
        column = self.X[:, axis]
        t = np.random.uniform(np.min(column), np.max(column))

        # Selection array for instances at or below the cut.
        sel = column <= t

        # A cut is only usable if neither side is empty, even when the
        # caller passes min_leaf_size=0 (a constant column yields an empty
        # right side, and an empty child cannot be constructed).
        required = max(min_leaf_size, 1)
        if np.sum(sel) < required or np.sum(~sel) < required:
            return

        self.axis = axis
        self.t = t
        self.left = RandomTree(
            self.X[sel, :], self.y[sel], max_depth, depth + 1, min_leaf_size
        )
        self.right = RandomTree(
            self.X[~sel, :], self.y[~sel], max_depth, depth + 1, min_leaf_size
        )

    def print_tree(self):
        """Print this node and, recursively, its descendants (one per line).

        Each line is indented by the node's depth and shows the node size
        and class counts, followed by the cut (internal node) or the
        predicted class (leaf).
        """
        msg = ' ' * self.depth + '* '
        msg += 'Size: ' + str(self.n_obs) + ' '
        # tolist() yields plain ints so the output does not depend on how
        # the installed NumPy version renders scalar reprs.
        msg += str(self.class_counts.tolist())
        # msg += ', Gini: ' + str(round(self.gini, 2))
        if self.t is not None:
            msg += ', Axis:' + str(self.axis)
            msg += ', Cut: ' + str(round(self.t, 2))
        else:
            msg += ', Predicted Class: ' + str(self.classes[self.prediction])
        print(msg)

        if self.left is not None:
            self.left.print_tree()
            self.right.print_tree()

    def predict_single(self, obs):
        """Return the predicted class label for a single observation.

        The observation is routed down the tree: at each internal node it
        goes left when ``obs[axis] <= t`` and right otherwise.  The result
        is the majority class of the leaf it ends up in.

        Args:
            obs: 1-D array-like of feature values.

        Returns:
            The majority class label (an element of ``classes``) of the
            leaf reached by ``obs``.
        """
        obs = np.array(obs)
        node = self
        while node.left is not None:
            node = node.left if obs[node.axis] <= node.t else node.right
        return node.classes[node.prediction]
# Demo: build and print random trees on uniformly distributed features with
# randomly assigned labels from {0, 1, 2}.
np.random.seed(1)  # fixed seed so the printed trees are reproducible
n = 1000
X = np.random.uniform(low=0, high=10, size=(n, 5))
y = np.random.choice([0, 1, 2], size=n)

# First a shallow tree with default leaf size, then a deeper one that
# requires at least 5 observations on each side of every cut.
for tree_options in ({'max_depth': 2}, {'max_depth': 5, 'min_leaf_size': 5}):
    tree = RandomTree(X, y, **tree_options)
    tree.print_tree()