Package ML :: Package DecTree :: Module ID3
[hide private]
[frames] | no frames]

Source Code for Module ML.DecTree.ID3

  1  # 
  2  #  Copyright (C) 2000,2003  greg Landrum and Rational Discovery LLC 
  3  # 
  4  """ ID3 Decision Trees 
  5   
  6    contains an implementation of the ID3 decision tree algorithm 
  7    as described in Tom Mitchell's book "Machine Learning" 
  8   
  9    It relies upon the _Tree.TreeNode_ data structure (or something 
 10      with the same API), here _DecTree.DecTreeNode_, to represent the trees  
 11   
 12  """ 
 13   
 14  from Numeric import * 
 15  from ML.DecTree import DecTree 
 16  from ML.InfoTheory import entropy 
 17   
def CalcTotalEntropy(examples, nPossibleVals):
    """ Calculates the total entropy of the data set (w.r.t. the results)

    **Arguments**

      - examples: a list (nInstances long) of lists of variable values + instance
        values; the last entry of each example is its result code

      - nPossibleVals: a list of the number of possible values each variable
        can adopt; its last entry is the number of possible result codes.

    **Returns**

      a float containing the informational entropy of the data set.

    """
    nRes = nPossibleVals[-1]
    resList = zeros(nRes, Int)
    for example in examples:
        # Cast the result code to int, as GenVarTable does, so that codes
        # stored as non-int numbers can still be used as array indices.
        res = int(example[-1])
        resList[res] = resList[res] + 1
    return entropy.InfoEntropy(resList)
39
def GenVarTable(examples, nPossibleVals, vars):
    """Builds one result-count table per requested variable.

    Each table counts, for every possible value of its variable, how many
    examples with that value produced each possible result of the function.

    **Arguments**

      - examples: a list (nInstances long) of lists of variable values + instance
        values; the last entry of each example is its result code

      - nPossibleVals: a list with the number of possible values of each
        variable, followed by the number of possible function results

      - vars: the indices of the variables to build tables for

    **Returns**

      a list (same order as _vars_) of Numeric arrays, each shaped
      (number of values of the variable) x (number of results)
    """
    nFuncVals = nPossibleVals[-1]
    tables = [zeros((nPossibleVals[var], nFuncVals), Int) for var in vars]
    for example in examples:
        result = int(example[-1])
        for table, var in zip(tables, vars):
            table[int(example[var]), result] += 1
    return tables
74
def _MakeMajorityLeaf(tree, counts):
    """Makes _tree_ a terminal node labelled with the most frequent result.

    Used when a node is not split any further although its examples are not
    all of one result; the trailing '?' in the name marks such a guessed leaf.
    """
    v = argmax(counts)
    tree.SetLabel(v)
    tree.SetName('%d?' % v)
    tree.SetTerminal(1)


def ID3(examples, target, attrs, nPossibleVals, depth=0, maxDepth=-1,
        **kwargs):
    """ Implements the ID3 algorithm for constructing decision trees.

    From Mitchell's book, page 56

    This is *slightly* modified from Mitchell's book because it supports
    multivalued (non-binary) results.

    **Arguments**

      - examples: a list (nInstances long) of lists of variable values + instance
        values

      - target: an int, the index of a variable; summing its value x result
        table over the values gives the number of examples with each result

      - attrs: a list of ints indicating which variables can be used in the tree

      - nPossibleVals: a list containing the number of possible values of
        every variable (the last entry is the number of possible results).

      - depth: (optional) the current depth in the tree

      - maxDepth: (optional) the maximum depth to which the tree
        will be grown; a negative value means no limit

      - recycleVars: (optional keyword) if true, a variable used for a split
        stays available to the subtrees below it

    **Returns**

      a DecTree.DecTreeNode with the decision tree

    **NOTE:** This code cannot bootstrap (start from nothing...)
          use _ID3Boot_ (below) for that.
    """
    tree = DecTree.DecTreeNode(None, 'node')

    # store the total entropy... in case that is interesting
    totEntropy = CalcTotalEntropy(examples, nPossibleVals)
    tree.SetData(totEntropy)

    # counts of each result code: column sums of the target's table
    tMat = GenVarTable(examples, nPossibleVals, [target])[0]
    counts = sum(tMat)
    nzCounts = nonzero(counts)

    if len(nzCounts) == 1:
        # bottomed out because only one result code has any examples
        # left (i.e. there's only one type of example left... GOOD!)
        res = nzCounts[0]
        tree.SetLabel(res)
        tree.SetName(str(res))
        tree.SetTerminal(1)
        return tree

    if len(attrs) == 0 or (maxDepth >= 0 and depth >= maxDepth):
        # Bottomed out: no variables left or max depth hit.  Use the
        # heuristic of picking the most prevalent result.
        _MakeMajorityLeaf(tree, counts)
        return tree

    # find the variable which gives us the largest information gain
    # (the per-variable tables are only needed here, so build them now)
    varTable = GenVarTable(examples, nPossibleVals, attrs)
    gains = [entropy.InfoGain(x) for x in varTable]
    best = attrs[argmax(gains)]

    # partition the examples by their value of the chosen variable
    subsets = []
    for val in range(nPossibleVals[best]):
        subsets.append([example for example in examples
                        if example[best] == val])

    recycle = kwargs.get('recycleVars', 0)
    if recycle and len(examples) > 0:
        for subset in subsets:
            if len(subset) == len(examples):
                # Every example landed in one branch.  With recycled
                # variables that child would get exactly these examples and
                # attributes, choose the same variable again and recurse
                # until maxDepth -- or forever when there is no depth limit.
                # Its leaves would all carry the majority result, so stop now.
                _MakeMajorityLeaf(tree, counts)
                return tree

    # remove that variable from the lists of possible variables
    nextAttrs = attrs[:]
    if not recycle:
        nextAttrs.remove(best)

    # set some info at this node
    tree.SetName('Var: %d' % best)
    tree.SetLabel(best)
    tree.SetTerminal(0)

    # build a subtree for each possible value of the new variable
    for subset in subsets:
        if len(subset) == 0:
            # No examples have this value (this can and does happen), so
            # there's no point in recursing: use the majority result here.
            v = argmax(counts)
            tree.AddChild('%d' % v, label=v, data=0.0, isTerminal=1)
        else:
            tree.AddChildNode(ID3(subset, best, nextAttrs, nPossibleVals,
                                  depth + 1, maxDepth, **kwargs))
    return tree
175
def ID3Boot(examples, attrs, nPossibleVals, initialVar=None, depth=0,
            maxDepth=-1, **kwargs):
    """ Bootstrapping code for the ID3 algorithm

    Builds the root node of the tree and hands each branch of the root
    split to _ID3_.  See ID3 for descriptions of the arguments.

    If _initialVar_ is None, the root split is the variable with the largest
    information gain (the standard greedy approach).  Otherwise _initialVar_
    is used as the first split.

    **Returns**

      a DecTree.DecTreeNode with the decision tree

    """
    totEntropy = CalcTotalEntropy(examples, nPossibleVals)
    varTable = GenVarTable(examples, nPossibleVals, attrs)

    tree = DecTree.DecTreeNode(None, 'node')
    # the root keeps the number of possible result codes
    tree._nResultCodes = nPossibleVals[-1]

    if initialVar is not None:
        best = initialVar
    else:
        gains = [entropy.InfoGain(table) for table in varTable]
        best = attrs[argmax(gains)]

    tree.SetName('Var: %d' % best)
    tree.SetData(totEntropy)
    tree.SetLabel(best)
    tree.SetTerminal(0)

    nextAttrs = attrs[:]
    if not kwargs.get('recycleVars', 0):
        nextAttrs.remove(best)

    for val in range(nPossibleVals[best]):
        # NOTE(review): unlike ID3, empty subsets are still passed on to ID3,
        # and the children are built at the same `depth` as this root rather
        # than depth+1 -- confirm both are intended.
        subset = [example for example in examples if example[best] == val]
        tree.AddChildNode(ID3(subset, best, nextAttrs, nPossibleVals,
                              depth, maxDepth, **kwargs))
    return tree
219 220
def TestMultiTree():
    """Testing code for generating trees with more than 2 possible results

    Builds a tree from a small data set with three result codes, pickles it
    to 'multi.pkl' and checks that it equals the reference tree stored in
    'regress/MultiTreeRes.pkl'.
    """
    from ML.Data import MLData
    print('Testing MultiValue Tree Construction')
    examples = [[0, 1, 0, 0],
                [0, 0, 0, 1],
                [0, 0, 1, 2],
                [0, 1, 1, 2],
                [1, 0, 0, 2],
                [1, 0, 1, 2],
                [1, 1, 0, 2],
                [1, 1, 1, 0]
                ]
    data = MLData.MLQuantDataSet(examples)
    attrs = range(0, data.GetNVars())
    t1 = ID3Boot(data.GetAllData(), attrs, data.GetNPossibleVals())
    t1.Pickle('multi.pkl')

    print('Testing Pickle Load')
    import cPickle
    # Pickles are binary data, so read the reference in binary mode, and
    # always release the file handle (the original leaked it).
    f = open('regress/MultiTreeRes.pkl', 'rb')
    try:
        t2 = cPickle.load(f)
    finally:
        f.close()
    print('Testing Correctness')
    assert t1 == t2, 'Equality Test Failed'

    print('All Tests Passed!')
250
def TestTree():
    """Testing code for trees with two possible results

    Checks the tree built from a 14-example data set against a hand-built
    expected tree, then exercises pickling, classification, deep copying
    and list membership.
    """
    from ML.Data import MLData

    print('Testing Tree Construction')
    examples = [[0, 0, 0, 0, 0],
                [0, 0, 0, 1, 0],
                [1, 0, 0, 0, 1],
                [2, 1, 0, 0, 1],
                [2, 2, 1, 0, 1],
                [2, 2, 1, 1, 0],
                [1, 2, 1, 1, 1],
                [0, 1, 0, 0, 0],
                [0, 2, 1, 0, 1],
                [2, 1, 1, 0, 1],
                [0, 1, 1, 1, 1],
                [1, 1, 0, 1, 1],
                [1, 0, 1, 0, 1],
                [2, 1, 0, 1, 0]
                ]

    data = MLData.MLQuantDataSet(examples)
    attrs = range(0, data.GetNVars())
    t1 = ID3Boot(data.GetAllData(), attrs, data.GetNPossibleVals())

    print('Testing Tree Validity')
    # Children of the expected root, in order: (name, label, leaves).
    # leaves is None for a terminal child, else a list of (name, label).
    expected = DecTree.DecTreeNode(None, 'Var: 0', 0)
    layout = [('Var: 2', 2, [('0', 0), ('1', 1)]),
              ('1', 1, None),
              ('Var: 3', 3, [('1', 1), ('0', 0)])]
    for childName, childLabel, leaves in layout:
        if leaves is None:
            child = DecTree.DecTreeNode(expected, childName, childLabel,
                                        isTerminal=1)
            expected.AddChildNode(child)
            continue
        child = DecTree.DecTreeNode(expected, childName, childLabel)
        expected.AddChildNode(child)
        for leafName, leafLabel in leaves:
            child.AddChildNode(DecTree.DecTreeNode(child, leafName, leafLabel,
                                                   isTerminal=1))

    assert expected == t1, 'Trees do not match'
    print('Testing Pickle')
    t1.Pickle('save.pkl')
    print('Classification Tests:')
    for idx in (0, 1, 6):
        assert t1.ClassifyExample(examples[idx]) == examples[idx][-1], \
            'Example %d misclassified' % idx

    print('Testing Copy')
    import copy
    copied = copy.deepcopy(t1)
    assert t1 == copied, 'copy failed'
    print('Testing Set Membership')
    assert copied in [t1], 'Set Membership failed'
    print('All tests passed!')
317
def TestNamedTree():
    """ Smoke test for building a tree from named examples

    Only checks that tree construction completes without raising.  Variable
    indices start at 1 and a 0 is prepended to the possible-value counts,
    presumably because GetNamedData puts the point name in column 0 --
    confirm against MLData.
    """
    from ML.Data import MLData
    print('Testing Named Tree Construction')
    examples = [[0, 1, 0, 0],
                [0, 0, 0, 1],
                [0, 0, 1, 2],
                [0, 1, 1, 2],
                [1, 0, 0, 2],
                [1, 0, 1, 2],
                [1, 1, 0, 2],
                [1, 1, 1, 0]
                ]
    names = ['ex%d' % i for i in range(1, 9)]
    data = MLData.MLQuantDataSet(examples, ptNames=names)
    nVars = data.GetNVars()
    attrs = range(1, nVars + 1)
    ID3Boot(data.GetNamedData(), attrs, [0] + data.GetNPossibleVals())
    print('All tests passed!')
if __name__ == "__main__":
    # run the self-tests in order; each one raises AssertionError on failure
    for _selfTest in (TestTree, TestMultiTree, TestNamedTree):
        _selfTest()