Package ML :: Package DecTree :: Module BuildQuantTree
[hide private]
[frames | no frames]

Source Code for Module ML.DecTree.BuildQuantTree

  1  # $Id: BuildQuantTree.py 2 2006-05-06 22:54:39Z glandrum $ 
  2  # 
  3  #  Copyright (C) 2001-2004  greg Landrum and Rational Discovery LLC 
  4  #  All Rights Reserved 
  5  # 
  6  """  
  7   
  8  """ 
  9   
 10  from Numeric import * 
 11  import RandomArray 
 12  from ML.DecTree import QuantTree, ID3 
 13  from ML.InfoTheory import entropy 
 14  from ML.Data import Quantize 
 15   
def FindBest(resCodes,examples,nBoundsPerVar,nPossibleRes,
             nPossibleVals,attrs,**kwargs):
    """ Picks, among the candidate variables, the one with the largest gain.

    **Arguments**

      - resCodes: the result code of each example

      - examples: the examples; each one is indexed by variable id

      - nBoundsPerVar: per-variable bound counts.  A positive count means
        the variable is quantized with that many bounds, zero means its
        values are used as they are, and a negative count means the
        variable can never be picked.

      - nPossibleRes: the number of possible result codes

      - nPossibleVals: the number of possible values of each variable

      - attrs: the ids of the variables to consider

      - randomDescriptors: (optional keyword) when this is positive and
        smaller than len(attrs), only that many randomly selected entries
        of _attrs_ are considered

    **Returns**

      a 3-tuple: (id of the best variable, its gain, its quantization
      bounds).  When there are no examples, or when no variable has a gain
      above -1e6, the id is -1; in the second case some diagnostics are
      printed as well.
    """
    topVar = -1
    topGain = -1e6
    topBounds = []

    if len(examples) == 0:
        return topVar,topGain,topBounds

    nRandom = kwargs.get('randomDescriptors',0)
    if nRandom > 0 and nRandom < len(attrs):
        # restrict the search to a random subset of the candidates
        order = RandomArray.permutation(len(attrs))
        attrs = [attrs[pos] for pos in order[:nRandom]]

    for candidate in attrs:
        nBounds = nBoundsPerVar[candidate]
        if nBounds == 0:
            # use the variable's values directly
            table = ID3.GenVarTable(examples,nPossibleVals,[candidate])[0]
            gain = entropy.InfoGain(table)
            bounds = []
        elif nBounds > 0:
            # the variable needs to be quantized first
            try:
                column = [ex[candidate] for ex in examples]
            except IndexError:
                print('index error retrieving variable: %d'%candidate)
                raise
            bounds,gain = Quantize.FindVarMultQuantBounds(column,nBounds,
                                                          resCodes,nPossibleRes)
        else:
            # sentinel gain: never beats the starting value of topGain
            gain = -1e6
            bounds = []

        if gain > topGain:
            topVar = candidate
            topGain = gain
            topBounds = bounds

    if topVar == -1:
        # nothing was selected, dump what we were working with
        print('best unaltered')
        print('\tattrs: %s'%(attrs,))
        print('\tnBounds: %s'%(take(nBoundsPerVar,attrs),))
        print('\texamples:')
        for ex in examples:
            print('\t\t%s'%(ex,))

    return topVar,topGain,topBounds
66 67
def _SplitExamples(examples,var,bounds,nPossibleVals):
    """ Splits _examples_ into the subsets that feed the children of a node.

    When _bounds_ is not empty the result has len(bounds)+1 subsets: every
    example goes into the subset of the first bound its value of _var_ is
    below, and whatever is below none of them goes into the last subset.
    Otherwise there is one subset per value in range(nPossibleVals[var]),
    holding the examples whose value of _var_ equals it.

    The relative order of the examples is preserved inside each subset.
    """
    if len(bounds) > 0:
        subsets = []
        remaining = list(examples)
        for bound in bounds:
            below = []
            others = []
            for ex in remaining:
                if ex[var] < bound:
                    below.append(ex)
                else:
                    others.append(ex)
            subsets.append(below)
            remaining = others
        subsets.append(remaining)
        return subsets
    return [[ex for ex in examples if ex[var] == val]
            for val in range(nPossibleVals[var])]


def BuildQuantTree(examples,target,attrs,nPossibleVals,nBoundsPerVar,
                   depth=0,maxDepth=-1,**kwargs):
    """
    **Arguments**

      - examples: a list of lists (nInstances x nVariables+1) of variable
        values + instance values

      - target: an int (not used by this function itself; the recursive
        calls pass the variable the parent node split on)

      - attrs: a list of ints indicating which variables can be used in the tree

      - nPossibleVals: a list containing the number of possible values of
        every variable.

      - nBoundsPerVar: the number of bounds to include for each variable

      - depth: (optional) the current depth in the tree

      - maxDepth: (optional) the maximum depth to which the tree
        will be grown; a node becomes a leaf when maxDepth >= 0 and
        depth > maxDepth

      - recycleVars: (optional keyword) if set, the variable used at a node
        stays available to the subtrees below it

      - any other keyword arguments are handed on to _FindBest_

    **Returns**

      a QuantTree.QuantTreeNode with the decision tree

    **NOTE:** This code cannot bootstrap (start from nothing...)
          use _QuantTreeBoot_ (below) for that.
    """
    tree = QuantTree.QuantTreeNode(None,'node')
    tree.SetData(-666)
    nPossibleRes = nPossibleVals[-1]

    # counts of each result code:
    resCodes = [int(x[-1]) for x in examples]
    counts = [0]*nPossibleRes
    for res in resCodes:
        counts[res] += 1
    nzCounts = nonzero(counts)

    if len(nzCounts) == 1:
        # bottomed out because there is only one result code left
        #  with any counts (i.e. there's only one type of example
        #  left... this is GOOD!).
        res = nzCounts[0]
        tree.SetLabel(res)
        tree.SetName(str(res))
        tree.SetTerminal(1)
        return tree

    best = -1
    bestBounds = []
    bottomedOut = len(attrs) == 0 or (maxDepth >= 0 and depth > maxDepth)
    if not bottomedOut:
        # find the variable which gives us the largest information gain
        best,_gain,bestBounds = FindBest(resCodes,examples,nBoundsPerVar,
                                         nPossibleRes,nPossibleVals,attrs,
                                         **kwargs)

    if best == -1:
        # Bottomed out: no variables left, max depth hit, or FindBest could
        #  not select any variable (it signals that with -1).
        # We don't really know what to do here, so
        #  use the heuristic of picking the most prevalent
        #  result
        v = argmax(counts)
        tree.SetLabel(v)
        tree.SetName('%d?'%v)
        tree.SetTerminal(1)
        return tree

    # remove that variable from the lists of possible variables
    nextAttrs = attrs[:]
    if not kwargs.get('recycleVars',0):
        nextAttrs.remove(best)

    # set some info at this node
    tree.SetName('Var: %d'%(best))
    tree.SetLabel(best)
    tree.SetQuantBounds(bestBounds)
    tree.SetTerminal(0)

    # build a subtree for each subset the new variable splits the
    #  examples into
    for nextExamples in _SplitExamples(examples,best,bestBounds,nPossibleVals):
        if len(nextExamples) == 0:
            # this particular value of the variable has no examples,
            #  so there's not much sense in recursing.
            #  This can (and does) happen.
            v = argmax(counts)
            tree.AddChild('%d'%v,label=v,data=0.0,isTerminal=1)
        else:
            # recurse
            tree.AddChildNode(BuildQuantTree(nextExamples,best,
                                             nextAttrs,nPossibleVals,
                                             nBoundsPerVar,
                                             depth=depth+1,maxDepth=maxDepth,
                                             **kwargs))
    return tree
195
def _BootSplitExamples(examples,var,bounds,nPossibleVals):
    """ Splits _examples_ into the subsets that feed the children of the root.

    When _bounds_ is not empty the result has len(bounds)+1 subsets: every
    example goes into the subset of the first bound its value of _var_ is
    below, and whatever is below none of them goes into the last subset.
    Otherwise there is one subset per value in range(nPossibleVals[var]),
    holding the examples whose value of _var_ equals it.

    The relative order of the examples is preserved inside each subset.
    """
    if len(bounds) > 0:
        subsets = []
        remaining = list(examples)
        for bound in bounds:
            below = []
            others = []
            for ex in remaining:
                if ex[var] < bound:
                    below.append(ex)
                else:
                    others.append(ex)
            subsets.append(below)
            remaining = others
        subsets.append(remaining)
        return subsets
    return [[ex for ex in examples if ex[var] == val]
            for val in range(nPossibleVals[var])]


def QuantTreeBoot(examples,attrs,nPossibleVals,nBoundsPerVar,initialVar=None,
                  maxDepth=-1,**kwargs):
    """ Bootstrapping code for the QuantTree

    If _initialVar_ is not set, the algorithm will automatically
    choose the first variable in the tree (the standard greedy
    approach).  Otherwise, _initialVar_ will be used as the first
    split.

    **Arguments**

      - examples: a list of lists of variable values; the last entry of
        each example is its result code

      - attrs: a list of ints indicating which variables can be used in
        the tree (the list itself is not modified); variables whose entry
        in _nBoundsPerVar_ is -1 are dropped

      - nPossibleVals: a list containing the number of possible values of
        every variable; the last entry is the number of result codes

      - nBoundsPerVar: the number of bounds to include for each variable

      - initialVar: (optional) the variable to use for the first split

      - maxDepth: (optional) handed on to _BuildQuantTree_

      - recycleVars: (optional keyword) if set, the variable used for the
        first split stays available to the subtrees

      - any other keyword arguments are handed on to _FindBest_ and
        _BuildQuantTree_

    **Returns**

      a QuantTree.QuantTreeNode with the decision tree

    **Raises**

      ValueError if _initialVar_ is not set and _FindBest_ cannot select a
      variable for the first split

    """
    attrs = list(attrs)
    for i,nBounds in enumerate(nBoundsPerVar):
        if nBounds == -1 and i in attrs:
            attrs.remove(i)

    tree = QuantTree.QuantTreeNode(None,'node')
    nPossibleRes = nPossibleVals[-1]
    tree._nResultCodes = nPossibleRes

    # counts of each result code:
    resCodes = [int(x[-1]) for x in examples]
    counts = [0]*nPossibleRes
    for res in resCodes:
        counts[res] += 1

    if initialVar is None:
        best,gainHere,qBounds = FindBest(resCodes,examples,nBoundsPerVar,
                                         nPossibleRes,nPossibleVals,attrs,
                                         **kwargs)
        if best == -1:
            # -1 is FindBest's "nothing selected" marker; going on would
            #  index the examples with it, i.e. split on the result code
            raise ValueError('QuantTreeBoot: no variable could be selected '
                             'for the initial split')
    else:
        best = initialVar
        nBoundsHere = nBoundsPerVar[best]
        if nBoundsHere > 0:
            vTable = [ex[best] for ex in examples]
            qBounds,gainHere = Quantize.FindVarMultQuantBounds(vTable,nBoundsHere,
                                                               resCodes,nPossibleRes)
        elif nBoundsHere == 0:
            vTable = ID3.GenVarTable(examples,nPossibleVals,[best])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

    tree.SetName('Var: %d'%(best))
    tree.SetData(gainHere)
    tree.SetLabel(best)
    tree.SetTerminal(0)
    tree.SetQuantBounds(qBounds)

    nextAttrs = attrs[:]
    # an explicitly requested initialVar does not have to be in attrs (it
    #  was filtered out above if its bound count is -1), so only remove
    #  the variable when it is actually there
    if not kwargs.get('recycleVars',0) and best in nextAttrs:
        nextAttrs.remove(best)

    # build a subtree for each subset the first variable splits the
    #  examples into
    for nextExamples in _BootSplitExamples(examples,best,qBounds,nPossibleVals):
        if len(nextExamples) != 0:
            tree.AddChildNode(BuildQuantTree(nextExamples,best,
                                             nextAttrs,nPossibleVals,
                                             nBoundsPerVar,
                                             depth=1,maxDepth=maxDepth,
                                             **kwargs))
        else:
            # no examples ended up here: use the most prevalent result
            v = argmax(counts)
            tree.AddChild('%d??'%(v),label=v,data=0.0,isTerminal=1)
    return tree
294 295
def TestTree():
    """ testing code for named trees

    Grows an ID3 tree (maxDepth=1) from a small hard-coded set of named
    examples and prints it.

    """
    rows = [
        ['p1',0,1,0,0],
        ['p2',0,0,0,1],
        ['p3',0,0,1,2],
        ['p4',0,1,1,2],
        ['p5',1,0,0,2],
        ['p6',1,0,1,2],
        ['p7',1,1,0,2],
        ['p8',1,1,1,0],
    ]
    # skip the leading name column and the trailing result column
    nColumns = len(rows[0])
    candidateVars = list(range(1,nColumns-1))
    valueCounts = [0,2,2,2,3]
    id3Tree = ID3.ID3Boot(rows,candidateVars,valueCounts,maxDepth=1)
    id3Tree.Print()
313 314
def TestQuantTree():
    """ testing code for named trees

    Grows two quantized trees from the same hard-coded examples, one
    without a depth limit and one with maxDepth=1; each tree is pickled
    and printed.

    """
    rows = [
        ['p1',0,1,0.1,0],
        ['p2',0,0,0.1,1],
        ['p3',0,0,1.1,2],
        ['p4',0,1,1.1,2],
        ['p5',1,0,0.1,2],
        ['p6',1,0,1.1,2],
        ['p7',1,1,0.1,2],
        ['p8',1,1,1.1,0],
    ]
    # skip the leading name column and the trailing result column
    candidateVars = list(range(1,len(rows[0])-1))
    valueCounts = [0,2,2,0,3]
    boundCounts = [0,0,0,1,0]

    # NOTE(review): both runs pickle to the same path, so the depth-limited
    #  tree replaces the base one on disk -- confirm that this is intended
    runs = (('base',{}),
            ('depth limit',{'maxDepth':1}))
    for banner,extraArgs in runs:
        print(banner)
        qTree = QuantTreeBoot(rows,candidateVars,valueCounts,boundCounts,
                              **extraArgs)
        qTree.Pickle('regress/QuantTree1.pkl')
        qTree.Print()
341
def TestQuantTree2():
    """ testing code for named trees

    Grows a quantized tree from hard-coded examples in which two of the
    variables get one quantization bound each, prints and pickles the
    tree, then prints every example next to the tree's classification
    of it.

    """
    rows = [
        ['p1',0.1,1,0.1,0],
        ['p2',0.1,0,0.1,1],
        ['p3',0.1,0,1.1,2],
        ['p4',0.1,1,1.1,2],
        ['p5',1.1,0,0.1,2],
        ['p6',1.1,0,1.1,2],
        ['p7',1.1,1,0.1,2],
        ['p8',1.1,1,1.1,0],
    ]
    # skip the leading name column and the trailing result column
    candidateVars = list(range(1,len(rows[0])-1))
    valueCounts = [0,0,2,0,3]
    boundCounts = [0,1,0,1,0]

    qTree = QuantTreeBoot(rows,candidateVars,valueCounts,boundCounts)
    qTree.Print()
    qTree.Pickle('regress/QuantTree2.pkl')

    for row in rows:
        print('%s %s'%(row,qTree.ClassifyExample(row)))
if __name__ == "__main__":
    # run the demos when the module is executed as a script;
    #  TestQuantTree2 is currently left out of the run
    for _demo in (TestTree,TestQuantTree):
        _demo()