Package rdkit :: Package ML :: Package DecTree :: Module BuildQuantTree
[hide private]
[frames] | no frames]

Source Code for Module rdkit.ML.DecTree.BuildQuantTree

  1  ## Automatically adapted for numpy.oldnumeric Jun 27, 2008 by -c 
  2   
  3  # $Id: BuildQuantTree.py 997 2009-02-25 06:12:43Z glandrum $ 
  4  # 
  5  #  Copyright (C) 2001-2008  greg Landrum and Rational Discovery LLC 
  6  #  All Rights Reserved 
  7  # 
  8  """  
  9   
 10  """ 
 11   
 12  import numpy 
 13  import random 
 14  from rdkit.ML.DecTree import QuantTree, ID3 
 15  from rdkit.ML.InfoTheory import entropy 
 16  from rdkit.ML.Data import Quantize 
 17   
def FindBest(resCodes, examples, nBoundsPerVar, nPossibleRes,
             nPossibleVals, attrs, exIndices=None, **kwargs):
    """ Finds the variable in _attrs_ with the largest information gain.

      **Arguments**

        - resCodes: the result code of each example under consideration
          (passed through to Quantize.FindVarMultQuantBounds)

        - examples: a sequence of examples, indexable by variable id

        - nBoundsPerVar: for each variable, the number of quantization
          bounds to look for.  A value > 0 means the variable is quantized
          with Quantize.FindVarMultQuantBounds, 0 means it is scored with
          ID3.GenVarTable + entropy.InfoGain, and a negative value means
          the variable is never selected (its gain is fixed at -1e6).

        - nPossibleRes: the number of possible result codes

        - nPossibleVals: the number of possible values of each variable

        - attrs: the ids of the variables that may be picked

        - exIndices: (optional) indices into _examples_ of the examples to
          use; defaults to all of them

        - kwargs: only 'randomDescriptors' is read here.  If it is > 0 and
          smaller than len(attrs), only that many randomly chosen entries
          of _attrs_ are considered.

      **Returns**

        a 3-tuple: (best variable id, its gain, its quantization bounds).
        If no variable could be picked (or there are no examples) this is
        (-1, -1e6, []).

    """
    bestGain = -1e6
    best = -1
    bestBounds = []

    if exIndices is None:
        exIndices = list(range(len(examples)))

    if not len(exIndices):
        return best, bestGain, bestBounds

    nToTake = kwargs.get('randomDescriptors', 0)
    if nToTake > 0:
        nAttrs = len(attrs)
        if nToTake < nAttrs:
            # shuffle needs a mutable sequence, so materialise the ids
            ids = list(range(nAttrs))
            random.shuffle(ids)
            attrs = [attrs[x] for x in ids[:nToTake]]

    for var in attrs:
        nBounds = nBoundsPerVar[var]
        if nBounds > 0:
            try:
                vTable = [examples[x][var] for x in exIndices]
            except IndexError:
                print('index error retrieving variable: %d' % var)
                raise
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(
                vTable, nBounds, resCodes, nPossibleRes)
        elif nBounds == 0:
            vTable = ID3.GenVarTable((examples[x] for x in exIndices),
                                     nPossibleVals, [var])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

        if gainHere > bestGain:
            bestGain = gainHere
            bestBounds = qBounds
            best = var
        elif bestGain == gainHere:
            # on a tie, prefer the variable that needs fewer bounds
            if len(qBounds) < len(bestBounds):
                best = var
                bestBounds = qBounds

    if best == -1:
        # nothing was selected: dump the inputs to help diagnose why
        print('best unaltered')
        print('\tattrs: %s' % (attrs,))
        print('\tnBounds: %s' % (numpy.take(nBoundsPerVar, attrs),))
        print('\texamples:')
        for example in (examples[x] for x in exIndices):
            print('\t\t%s' % (example,))

    return best, bestGain, bestBounds
89 90
def BuildQuantTree(examples, target, attrs, nPossibleVals, nBoundsPerVar,
                   depth=0, maxDepth=-1, exIndices=None, **kwargs):
    """ Recursively builds a decision tree for the selected examples.

      **Arguments**

        - examples: a list of lists (nInstances x nVariables+1) of variable
          values + instance values

        - target: an int.  It is not used in the body of this function;
          recursive calls pass the id of the variable just split on.

        - attrs: a list of ints indicating which variables can be used in
          the tree

        - nPossibleVals: a list containing the number of possible values of
          every variable.  The last entry is the number of result codes.

        - nBoundsPerVar: the number of bounds to include for each variable

        - depth: (optional) the current depth in the tree

        - maxDepth: (optional) the maximum depth to which the tree
          will be grown.  Negative values mean "no limit".

        - exIndices: (optional) indices into _examples_ of the examples
          this node should be built from; defaults to all of them

        - kwargs: passed on to FindBest and to the recursive calls.  If
          'recycleVars' is set, the chosen variable stays available to
          the subtrees.

      **Returns**

        a QuantTree.QuantTreeNode with the decision tree

      **NOTE:** This code cannot bootstrap (start from nothing...)
            use _QuantTreeBoot_ (below) for that.
    """
    tree = QuantTree.QuantTreeNode(None, 'node')
    tree.SetData(-666)
    nPossibleRes = nPossibleVals[-1]

    if exIndices is None:
        exIndices = list(range(len(examples)))

    # counts of each result code:
    resCodes = [int(examples[idx][-1]) for idx in exIndices]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1
    nzCounts = numpy.nonzero(counts)[0]

    if len(nzCounts) == 1:
        # bottomed out because there is only one result code left
        # with any counts (i.e. there's only one type of example
        # left... this is GOOD!).
        res = nzCounts[0]
        tree.SetLabel(res)
        tree.SetName(str(res))
        tree.SetTerminal(1)
        return tree

    if len(attrs) == 0 or (maxDepth >= 0 and depth > maxDepth):
        # Bottomed out: no variables left or max depth hit
        # We don't really know what to do here, so
        # use the heuristic of picking the most prevalent
        # result
        v = numpy.argmax(counts)
        tree.SetLabel(v)
        tree.SetName('%d?' % v)
        tree.SetTerminal(1)
        return tree

    # find the variable which gives us the largest information gain
    best, bestGain, bestBounds = FindBest(resCodes, examples, nBoundsPerVar,
                                          nPossibleRes, nPossibleVals, attrs,
                                          exIndices=exIndices,
                                          **kwargs)

    # remove that variable from the lists of possible variables
    nextAttrs = attrs[:]
    if not kwargs.get('recycleVars', 0):
        nextAttrs.remove(best)

    # set some info at this node
    tree.SetName('Var: %d' % (best))
    tree.SetLabel(best)
    tree.SetQuantBounds(bestBounds)
    tree.SetTerminal(0)

    def _addChild(childIndices):
        # attach one child: a majority-vote leaf when the branch has no
        # examples (this can and does happen), otherwise a recursively
        # built subtree
        if len(childIndices) == 0:
            v = numpy.argmax(counts)
            tree.AddChild('%d' % v, label=v, data=0.0, isTerminal=1)
        else:
            tree.AddChildNode(BuildQuantTree(examples, best,
                                             nextAttrs, nPossibleVals,
                                             nBoundsPerVar,
                                             depth=depth + 1,
                                             maxDepth=maxDepth,
                                             exIndices=childIndices,
                                             **kwargs))

    if len(bestBounds) > 0:
        # quantized variable: peel off, bound by bound, the examples whose
        # value lies below the bound; whatever is left after the last
        # bound forms the final child
        remaining = list(exIndices)
        for bound in bestBounds:
            below = []
            rest = []
            for idx in remaining:
                if examples[idx][best] < bound:
                    below.append(idx)
                else:
                    rest.append(idx)
            remaining = rest
            _addChild(below)
        # add the last points remaining
        _addChild(remaining)
    else:
        # categorical variable: one child per possible value
        for val in range(nPossibleVals[best]):
            _addChild([idx for idx in exIndices
                       if examples[idx][best] == val])
    return tree
225
def QuantTreeBoot(examples, attrs, nPossibleVals, nBoundsPerVar, initialVar=None,
                  maxDepth=-1, **kwargs):
    """ Bootstrapping code for the QuantTree

      If _initialVar_ is not set, the algorithm will automatically
      choose the first variable in the tree (the standard greedy
      approach).  Otherwise, _initialVar_ will be used as the first
      split.

      **Arguments**

        - examples: a list of lists (nInstances x nVariables+1) of variable
          values + instance values

        - attrs: the ids of the variables that can be used in the tree.
          The argument itself is not modified; variables whose entry in
          _nBoundsPerVar_ is -1 are dropped from a private copy.

        - nPossibleVals: the number of possible values of every variable;
          the last entry is the number of result codes

        - nBoundsPerVar: the number of bounds to include for each variable

        - initialVar: (optional) the id of the variable to split on first

        - maxDepth: (optional) passed on to BuildQuantTree

        - kwargs: passed on to FindBest and BuildQuantTree.  If
          'recycleVars' is set, the root variable stays available to the
          subtrees.

      **Returns**

        the root QuantTree.QuantTreeNode of the tree

      **Raises**

        ValueError (from list.remove) if the variable chosen for the root
        is not in the filtered attribute list and 'recycleVars' is not set.

    """
    # work on a private, mutable copy so the caller's sequence is untouched
    attrs = list(attrs)
    for i, nBounds in enumerate(nBoundsPerVar):
        if nBounds == -1 and i in attrs:
            attrs.remove(i)

    tree = QuantTree.QuantTreeNode(None, 'node')
    nPossibleRes = nPossibleVals[-1]
    tree._nResultCodes = nPossibleRes

    resCodes = [int(x[-1]) for x in examples]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1

    if initialVar is None:
        best, gainHere, qBounds = FindBest(resCodes, examples, nBoundsPerVar,
                                           nPossibleRes, nPossibleVals, attrs,
                                           **kwargs)
    else:
        best = initialVar
        if nBoundsPerVar[best] > 0:
            vTable = [ex[best] for ex in examples]
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(
                vTable, nBoundsPerVar[best], resCodes, nPossibleRes)
        elif nBoundsPerVar[best] == 0:
            vTable = ID3.GenVarTable(examples, nPossibleVals, [best])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

    tree.SetName('Var: %d' % (best))
    tree.SetData(gainHere)
    tree.SetLabel(best)
    tree.SetTerminal(0)
    tree.SetQuantBounds(qBounds)
    nextAttrs = attrs[:]
    if not kwargs.get('recycleVars', 0):
        nextAttrs.remove(best)

    def _addChild(childExamples):
        # attach one child: a recursively built subtree when the branch has
        # examples, otherwise a leaf labelled with the most common result
        if len(childExamples):
            tree.AddChildNode(BuildQuantTree(childExamples, best,
                                             nextAttrs, nPossibleVals,
                                             nBoundsPerVar,
                                             depth=1, maxDepth=maxDepth,
                                             **kwargs))
        else:
            v = numpy.argmax(counts)
            tree.AddChild('%d??' % (v), label=v, data=0.0, isTerminal=1)

    if len(qBounds) > 0:
        # quantized variable: peel off, bound by bound, the examples whose
        # value lies below the bound; whatever is left after the last
        # bound forms the final child
        remaining = list(examples)
        for bound in qBounds:
            below = []
            rest = []
            for ex in remaining:
                if ex[best] < bound:
                    below.append(ex)
                else:
                    rest.append(ex)
            remaining = rest
            _addChild(below)
        # add the last points remaining
        _addChild(remaining)
    else:
        # categorical variable: one child per possible value
        for val in range(nPossibleVals[best]):
            _addChild([ex for ex in examples if ex[best] == val])
    return tree
324 325
def TestTree():
    """ Builds a depth-limited ID3 tree from a small hard-coded data set
      and prints it.

      Each example is [name, var1, var2, var3, result].

    """
    names = ['p1', 'p2', 'p3', 'p4', 'p5', 'p6', 'p7', 'p8']
    values = [(0, 1, 0, 0),
              (0, 0, 0, 1),
              (0, 0, 1, 2),
              (0, 1, 1, 2),
              (1, 0, 0, 2),
              (1, 0, 1, 2),
              (1, 1, 0, 2),
              (1, 1, 1, 0)]
    examples1 = [[name] + list(row) for name, row in zip(names, values)]
    # every column except the leading name and the trailing result
    nCols = len(examples1[0])
    attrs = range(1, nCols - 1)
    nPossibleVals = [0, 2, 2, 2, 3]
    t1 = ID3.ID3Boot(examples1, attrs, nPossibleVals, maxDepth=1)
    t1.Print()
343 344
def TestQuantTree():
    """ Builds, pickles and prints two QuantTrees from a small hard-coded
      data set: one unrestricted and one with maxDepth=1.

      Each example is [name, var1, var2, var3, result]; var3 is the only
      variable that gets a quantization bound.

    """
    names = ['p1', 'p2', 'p3', 'p4', 'p5', 'p6', 'p7', 'p8']
    values = [(0, 1, 0.1, 0),
              (0, 0, 0.1, 1),
              (0, 0, 1.1, 2),
              (0, 1, 1.1, 2),
              (1, 0, 0.1, 2),
              (1, 0, 1.1, 2),
              (1, 1, 0.1, 2),
              (1, 1, 1.1, 0)]
    examples1 = [[name] + list(row) for name, row in zip(names, values)]
    # every column except the leading name and the trailing result
    nCols = len(examples1[0])
    attrs = range(1, nCols - 1)
    nPossibleVals = [0, 2, 2, 0, 3]
    boundsPerVar = [0, 0, 0, 1, 0]

    print('base')
    t1 = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar)
    t1.Pickle('test_data/QuantTree1.pkl')
    t1.Print()

    print('depth limit')
    t1 = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar,
                       maxDepth=1)
    # NOTE(review): same path as above, so this overwrites the first
    # pickle -- confirm that is intended
    t1.Pickle('test_data/QuantTree1.pkl')
    t1.Print()
371
def TestQuantTree2():
    """ Builds a QuantTree with two quantized variables from a small
      hard-coded data set, prints and pickles it, then prints each example
      next to the tree's classification of it.

      Each example is [name, var1, var2, var3, result]; var1 and var3 each
      get one quantization bound.

    """
    names = ['p1', 'p2', 'p3', 'p4', 'p5', 'p6', 'p7', 'p8']
    values = [(0.1, 1, 0.1, 0),
              (0.1, 0, 0.1, 1),
              (0.1, 0, 1.1, 2),
              (0.1, 1, 1.1, 2),
              (1.1, 0, 0.1, 2),
              (1.1, 0, 1.1, 2),
              (1.1, 1, 0.1, 2),
              (1.1, 1, 1.1, 0)]
    examples1 = [[name] + list(row) for name, row in zip(names, values)]
    # every column except the leading name and the trailing result
    nCols = len(examples1[0])
    attrs = range(1, nCols - 1)
    nPossibleVals = [0, 0, 2, 0, 3]
    boundsPerVar = [0, 1, 0, 1, 0]

    t1 = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar)
    t1.Print()
    t1.Pickle('test_data/QuantTree2.pkl')

    for example in examples1:
        prediction = t1.ClassifyExample(example)
        print('%s %s' % (example, prediction))
# When run as a script, exercise the ID3 and QuantTree demos.
if __name__ == "__main__":
    TestTree()
    TestQuantTree()
    # TestQuantTree2() is not run by default
    #TestQuantTree2()