Package rdkit :: Package ML :: Package Composite :: Module Composite
[hide private]
[frames | no frames]

Source Code for Module rdkit.ML.Composite.Composite

  1  # $Id: Composite.py 997 2009-02-25 06:12:43Z glandrum $ 
  2  # 
  3  #  Copyright (C) 2000-2008  greg Landrum and Rational Discovery LLC 
  4  #   All Rights Reserved 
  5  # 
  6  """ code for dealing with composite models 
  7   
  8  For a model to be useable here, it should support the following API: 
  9   
 10    - _ClassifyExample(example)_, returns a classification 
 11   
 12  Other compatibility notes: 
 13   
 14   1) To use _Composite.Grow_ there must be some kind of builder 
 15      functionality which returns a 2-tuple containing (model,percent accuracy). 
 16   
 17   2) The models should be pickleable 
 18   
 19   3) Ideally the models should support the __cmp__ method so that the 
 20      membership tests used to make sure models are unique work correctly. 
 21   
 22   
 23   
 24  """ 
 25  from rdkit.ML.Data import DataUtils 
 26  import cPickle 
 27  import math 
 28  import numpy 
 29   
class Composite(object):
    """a composite model


    **Notes**

    - adding a model which is already present just results in its count
      field being incremented and the errors being averaged.

    - typical usage:

      1) grow the composite with AddModel until happy with it

      2) call AverageErrors to calculate the average error values

      3) call SortModels to put things in order by either error or count

    - Composites can support individual models requiring either quantized or
      nonquantized data.  This is done by keeping a set of quantization bounds
      (_QuantBounds_) in the composite and quantizing data passed in when
      required.  Quantization bounds can be set and interrogated using the
      _Get/SetQuantBounds()_ methods.  When models are added to the composite,
      it can be indicated whether or not they require quantization.

    - Composites are also capable of extracting relevant variables from longer
      lists.  This functionality was added to ease interactions with Excel and
      COM applications.  This is accessible using _SetDescriptorNames()_ to
      register the descriptors about which the composite cares and
      _SetInputOrder()_ to tell the composite what the ordering of input
      vectors will be.  **Note** there is a limitation on this: each model
      needs to take the same set of descriptors as inputs.  This could be
      changed.

    - _modelList_, _errList_, _countList_ and _quantizationRequirements_ are
      parallel lists: entry _i_ of each one describes model _i_.

    """

    def __init__(self):
        # parallel per-model lists (see the class notes)
        self.modelList = []
        self.errList = []
        self.countList = []
        self.quantizationRequirements = []
        # votes collected by the most recent call to ClassifyExample()
        self.modelVotes = []
        self.quantBounds = None
        self.nPossibleVals = None
        self._descNames = []
        self._mapOrder = None
        self.activityQuant = []

    def SetModelFilterData(self, modelFilterFrac=0.0, modelFilterVal=0.0):
        """ stores the filtering parameters used by _Grow()_

        **Arguments**

          - modelFilterFrac: passed to _DataUtils.FilterData()_ for each new
            model; a value of 0 disables filtering in _Grow()_

          - modelFilterVal: passed to _DataUtils.FilterData()_ alongside
            _modelFilterFrac_

        """
        self._modelFilterFrac = modelFilterFrac
        self._modelFilterVal = modelFilterVal

    def SetDescriptorNames(self, names):
        """ registers the names of the descriptors this composite uses

        **Arguments**

          - names: a list of descriptor names (strings).

        **NOTE**

          the _names_ list is not copied, so if you modify it later, the
          composite itself will also be modified.

        """
        self._descNames = names

    def GetDescriptorNames(self):
        """ returns the names of the descriptors this composite uses

        """
        return self._descNames

    def SetQuantBounds(self, qBounds, nPossible=None):
        """ sets the quantization bounds that the composite will use

        **Arguments**

          - qBounds: a list of quantization bounds, each quantbound is a
            list of boundaries

          - nPossible: a list of integers indicating how many possible values
            each descriptor can take on.

        **NOTE**

          - if the two lists are of different lengths, this will assert out

          - neither list is copied, so if you modify it later, the composite
            itself will also be modified.

        """
        if nPossible is not None:
            assert len(qBounds) == len(nPossible), 'qBounds/nPossible mismatch'
        self.quantBounds = qBounds
        self.nPossibleVals = nPossible

    def GetQuantBounds(self):
        """ returns the quantization bounds

        **Returns**

          a 2-tuple consisting of:

            1) the list of quantization bounds

            2) the nPossibleVals list

        """
        return self.quantBounds, self.nPossibleVals

    def GetActivityQuantBounds(self):
        """ returns the activity quantization bounds

        The attribute is created on demand so that instances which lack it
        (it is looked up with hasattr) still work.

        """
        if not hasattr(self, 'activityQuant'):
            self.activityQuant = []
        return self.activityQuant

    def SetActivityQuantBounds(self, bounds):
        """ sets the activity quantization bounds (the list is not copied)

        """
        self.activityQuant = bounds

    @staticmethod
    def _FindBin(value, bounds):
        """ returns the index of the first entry of _bounds_ that _value_ is
        strictly below, or len(bounds) if there is no such entry

        """
        for box, bound in enumerate(bounds):
            if value < bound:
                return box
        return len(bounds)

    def QuantizeActivity(self, example, activityQuant=None, actCol=-1):
        """ quantizes the activity value of an example

        **Arguments**

          - example: a data point (list, tuple or numpy array)

          - activityQuant: the list of activity boundaries.  If this is not
            provided, the composite's own activity bounds are used.

          - actCol: the index of the activity column

        **Returns**

          a new list with the activity replaced by its bin index if there are
          activity bounds, otherwise _example_ itself (unmodified)

        """
        if activityQuant is None:
            activityQuant = self.GetActivityQuantBounds()
        if activityQuant:
            # list() always gives an independent, mutable copy; a plain slice
            # is a view for numpy arrays and immutable for tuples.
            example = list(example)
            example[actCol] = self._FindBin(example[actCol], activityQuant)
        return example

    def QuantizeExample(self, example, quantBounds=None):
        """ quantizes an example

        **Arguments**

          - example: a data point (list, tuple or numpy array)

          - quantBounds: a list of quantization bounds, each quantbound is a
            list of boundaries.  If this argument is not provided, the
            composite will use its own quantBounds

        **Returns**

          the quantized example as a list

        **Notes**

          - If _example_ is different in length from _quantBounds_, this will
            assert out.

          - Columns with an empty bounds list are converted with int(), except
            for column 0, which is passed through unchanged.

          - This is primarily intended for internal use

        """
        if quantBounds is None:
            quantBounds = self.quantBounds
        assert len(example) == len(quantBounds), 'example/quantBounds mismatch'
        quantExample = [None] * len(example)
        for i, bounds in enumerate(quantBounds):
            p = example[i]
            if len(bounds):
                p = self._FindBin(p, bounds)
            elif i != 0:
                p = int(p)
            quantExample[i] = p
        return quantExample

    def MakeHistogram(self):
        """ creates a histogram of error/count pairs

        Consecutive models whose errors differ by no more than 0.001 from the
        error that opened the current bin are merged into that bin.
        NOTE(review): this presumably expects the models to have been sorted
        by error (_SortModels()_) beforehand.

        **Returns**

          the histogram as a list of (error, count) 2-tuples; an empty list
          if the composite has no models

        """
        histo = []
        if not self.modelList:
            return histo
        eps = 0.001
        lastErr = self.errList[0]
        countHere = self.countList[0]
        for i in range(1, len(self.modelList)):
            if self.errList[i] - lastErr > eps:
                histo.append((lastErr, countHere))
                lastErr = self.errList[i]
                countHere = self.countList[i]
            else:
                countHere = countHere + self.countList[i]
        # flush the bin that was still being accumulated when the loop ended
        histo.append((lastErr, countHere))
        return histo

    def CollectVotes(self, example, quantExample, appendExample=0,
                     onlyModels=None):
        """ collects votes across every member of the composite for the given example

        **Arguments**

          - example: the example to be voted upon

          - quantExample: the quantized form of the example

          - appendExample: toggles saving the example on the models

          - onlyModels: if provided, this should be a sequence of model
            indices. Only the specified models will be used in the
            prediction.

        **Returns**

          a list with a vote from each member; models which were not
          consulted have a vote of -1

        """
        if not onlyModels:
            onlyModels = range(len(self))

        votes = [-1] * len(self)
        for i in onlyModels:
            if self.quantizationRequirements[i]:
                data = quantExample
            else:
                data = example
            votes[i] = int(round(self.modelList[i].ClassifyExample(
                data, appendExamples=appendExample)))
        return votes

    def ClassifyExample(self, example, threshold=0, appendExample=0,
                        onlyModels=None):
        """ classifies the given example using the entire composite

        **Arguments**

          - example: the data to be classified

          - threshold: if this is a number greater than zero, then a
            classification will only be returned if the confidence is
            above _threshold_.  Anything lower is returned as -1.

          - appendExample: toggles saving the example on the models

          - onlyModels: if provided, this should be a sequence of model
            indices. Only the specified models will be used in the
            prediction.

        **Returns**

          a (result,confidence) tuple; (-1, 0.0) if no votes were cast

        **FIX:**
          statistics sucks... I'm not seeing an obvious way to get
          the confidence intervals.  For that matter, I'm not seeing
          an unobvious way.

          For now, this is just treated as a voting problem with the
          confidence measure being the percent of models which voted for the
          winning result.

        """
        if self._mapOrder is not None:
            example = self._RemapInput(example)
        if self.GetActivityQuantBounds():
            example = self.QuantizeActivity(example)
        if self.quantBounds is not None and 1 in self.quantizationRequirements:
            quantExample = self.QuantizeExample(example, self.quantBounds)
        else:
            quantExample = []

        if not onlyModels:
            onlyModels = range(len(self))
        self.modelVotes = self.CollectVotes(example, quantExample,
                                            appendExample=appendExample,
                                            onlyModels=onlyModels)

        # each model's vote is weighted by the number of times it was added
        votes = [0] * self.nPossibleVals[-1]
        for i in onlyModels:
            res = self.modelVotes[i]
            votes[res] = votes[res] + self.countList[i]

        totVotes = sum(votes)
        if not totVotes:
            # nothing voted: report "no classification" instead of dividing
            # by zero
            return -1, 0.0
        res = numpy.argmax(votes)
        conf = float(votes[res]) / float(totVotes)
        if conf > threshold:
            return res, conf
        return -1, conf

    def GetVoteDetails(self):
        """ returns the votes from the last classification

        This will be an empty list if nothing has yet been classified

        """
        return self.modelVotes

    def _RemapInput(self, inputVect):
        """ remaps the input so that it matches the expected internal ordering

        **Arguments**

          - inputVect: the input to be reordered

        **Returns**

          - a list with the reordered (and possibly shorter) data

        **Note**

          - you must call _SetDescriptorNames()_ and _SetInputOrder()_ for
            this to work

          - this is primarily intended for internal use

        """
        order = self._mapOrder
        if order is None:
            return inputVect

        remappedInput = [None] * len(order)
        for i in range(len(order) - 1):
            remappedInput[i] = inputVect[order[i]]
        # an index of -1 in the final slot is the marker SetInputOrder() uses
        # for "no matching column"; a placeholder of 0 is used in that case
        if order[-1] == -1:
            remappedInput[-1] = 0
        else:
            remappedInput[-1] = inputVect[order[-1]]
        return remappedInput

    def GetInputOrder(self):
        """ returns the input order (used in remapping inputs)

        """
        return self._mapOrder

    def SetInputOrder(self, colNames):
        """ sets the input order

        **Arguments**

          - colNames: a list of the names of the data columns that will be
            passed in

        **Note**

          - you must call _SetDescriptorNames()_ first for this to work

          - names are compared case-insensitively

          - if one of the local descriptor names (other than the first and
            the last) does not appear in _colNames_, this will raise a
            _ValueError_ exception.

        """
        descs = [x.upper() for x in self.GetDescriptorNames()]
        self._mapOrder = [None] * len(descs)
        colNames = [x.upper() for x in colNames]

        # FIX: I believe that we're safe assuming that field 0
        #  is always the label, and therefore safe to ignore errors,
        #  but this may not be the case
        try:
            self._mapOrder[0] = colNames.index(descs[0])
        except ValueError:
            self._mapOrder[0] = 0

        for i in range(1, len(descs) - 1):
            try:
                self._mapOrder[i] = colNames.index(descs[i])
            except ValueError:
                raise ValueError('cannot find descriptor name: %s in set %s' %
                                 (repr(descs[i]), repr(colNames)))
        try:
            self._mapOrder[-1] = colNames.index(descs[-1])
        except ValueError:
            # ok, there's no obvious match for the final column (activity);
            # flag it with -1 so that _RemapInput() fills in a placeholder
            self._mapOrder[-1] = -1

    def Grow(self, examples, attrs, nPossibleVals, buildDriver, pruner=None,
             nTries=10, pruneIt=0,
             needsQuantization=1, progressCallback=None,
             **buildArgs):
        """ Grows the composite

        **Arguments**

          - examples: a list of examples to be used in training

          - attrs: a list of the variables to be used in training

          - nPossibleVals: this is used to provide a list of the number
            of possible values for each variable.  It is used if the
            local quantBounds have not been set (for example for when you
            are working with data which is already quantized).

          - buildDriver: the function to call to build the new models

          - pruner: a function used to "prune" (reduce the complexity of)
            the resulting model.

          - nTries: the number of new models to add

          - pruneIt: toggles whether or not pruning is done

          - needsQuantization: used to indicate whether or not this type of
            model requires quantized data

          - progressCallback: if provided, this is called with the cycle
            index after each model has been added

          - **buildArgs: all other keyword args are passed to _buildDriver_
            (_silent_ and _calcTotalError_ are always forced to 1 for the
            driver; a caller-supplied _silent_ only controls the progress
            output of this method)

        **Note**

          - new models are *added* to the existing ones

          - neither _examples_ nor _nPossibleVals_ is modified

        """
        silent = buildArgs.get('silent', 0)
        buildArgs['silent'] = 1
        buildArgs['calcTotalError'] = 1

        if self._mapOrder is not None:
            examples = [self._RemapInput(x) for x in examples]
        activityBounds = self.GetActivityQuantBounds()
        if activityBounds:
            # build new lists rather than writing into the caller's arguments
            examples = [self.QuantizeActivity(x) for x in examples]
            nPossibleVals = list(nPossibleVals)
            nPossibleVals[-1] = len(activityBounds) + 1
        if self.nPossibleVals is None:
            self.nPossibleVals = nPossibleVals[:]
        if needsQuantization:
            nPossibleVals = self.nPossibleVals
            trainExamples = [self.QuantizeExample(x, self.quantBounds)
                             for x in examples]
        else:
            trainExamples = examples

        for i in range(nTries):
            # the filter attributes only exist once SetModelFilterData() has
            # been called, hence the getattr() with a default
            useFilter = getattr(self, '_modelFilterFrac', 0) != 0
            trainIdx = None
            if useFilter:
                trainIdx, temp = DataUtils.FilterData(trainExamples,
                                                      self._modelFilterVal,
                                                      self._modelFilterFrac, -1,
                                                      indicesOnly=1)
                trainSet = [trainExamples[x] for x in trainIdx]
            else:
                trainSet = trainExamples

            model, frac = buildDriver(trainSet, attrs, nPossibleVals,
                                      **buildArgs)
            if pruneIt:
                model, frac = pruner(model, model.GetTrainingExamples(),
                                     model.GetTestExamples(),
                                     minimizeTestErrorOnly=0)
            if useFilter and hasattr(model, '_trainIndices'):
                # correct the model's training indices so that they refer to
                # positions in the unfiltered training set
                model._trainIndices = [trainIdx[x] for x in model._trainIndices]

            self.AddModel(model, frac, needsQuantization)
            # report roughly ten times over the course of the run
            if not silent and (nTries < 10 or i % (nTries // 10) == 0):
                print('Cycle: % 4d' % (i))
            if progressCallback is not None:
                progressCallback(i)

    def ClearModelExamples(self):
        """ asks every model to clear its stored examples

        Models which raise _AttributeError_ (e.g. because they have no
        _ClearExamples()_ method) are skipped.

        """
        for i in range(len(self)):
            m = self.GetModel(i)
            try:
                m.ClearExamples()
            except AttributeError:
                pass

    def Pickle(self, fileName='foo.pkl', saveExamples=0):
        """ Writes this composite off to a file so that it can be easily loaded later

        **Arguments**

          - fileName: the name of the file to be written

          - saveExamples: if this is zero, the individual models will have
            their stored examples cleared.

        """
        if not saveExamples:
            self.ClearModelExamples()

        pFile = open(fileName, 'wb+')
        try:
            cPickle.dump(self, pFile, 1)
        finally:
            # make sure the handle is released even if pickling fails
            pFile.close()

    def AddModel(self, model, error, needsQuantization=1):
        """ Adds a model to the composite

        **Arguments**

          - model: the model to be added

          - error: the model's error

          - needsQuantization: a toggle to indicate whether or not this model
            requires quantized inputs

        **NOTE**

          - this can be used as an alternative to _Grow()_ if you already have
            some models constructed

          - the errList is run as an accumulator,
            you probably want to call _AverageErrors_ after finishing the forest

        """
        try:
            idx = self.modelList.index(model)
        except ValueError:
            # not present yet: start a new entry in each of the parallel lists
            self.modelList.append(model)
            self.errList.append(error)
            self.countList.append(1)
            self.quantizationRequirements.append(needsQuantization)
        else:
            self.errList[idx] = self.errList[idx] + error
            self.countList[idx] = self.countList[idx] + 1

    def AverageErrors(self):
        """ convert local summed error to average error

        **NOTE**

          each call divides the stored errors by the counts again, so this
          should only be called once after the composite has been built.

        """
        self.errList = [float(err) / count
                        for err, count in zip(self.errList, self.countList)]

    def SortModels(self, sortOnError=1):
        """ sorts the list of models

        **Arguments**

          sortOnError: toggles sorting on the models' errors rather than
            their counts

        """
        if sortOnError:
            order = numpy.argsort(self.errList)
        else:
            order = numpy.argsort(self.countList)

        # these elaborate contortions are required because, at the time this
        # code was written, Numeric arrays didn't unpickle so well...
        # Every parallel list has to be permuted, otherwise the models end up
        # paired with another model's data.
        self.modelList = [self.modelList[x] for x in order]
        self.countList = [self.countList[x] for x in order]
        self.errList = [self.errList[x] for x in order]
        self.quantizationRequirements = [self.quantizationRequirements[x]
                                         for x in order]

    def GetModel(self, i):
        """ returns a particular model

        """
        return self.modelList[i]

    def SetModel(self, i, val):
        """ replaces a particular model

        **Note**

          This is included for the sake of completeness, but you need to be
          *very* careful when you use it.

        """
        self.modelList[i] = val

    def GetCount(self, i):
        """ returns the count of the _i_th model

        """
        return self.countList[i]

    def SetCount(self, i, val):
        """ sets the count of the _i_th model

        """
        self.countList[i] = val

    def GetError(self, i):
        """ returns the error of the _i_th model

        """
        return self.errList[i]

    def SetError(self, i, val):
        """ sets the error of the _i_th model

        """
        self.errList[i] = val

    def GetDataTuple(self, i):
        """ returns all relevant data about a particular model

        **Arguments**

          i: an integer indicating which model should be returned

        **Returns**

          a 3-tuple consisting of:

            1) the model

            2) its count

            3) its error

        """
        return (self.modelList[i], self.countList[i], self.errList[i])

    def SetDataTuple(self, i, tup):
        """ sets all relevant data for a particular tree in the forest

        **Arguments**

          - i: an integer indicating which model should be replaced

          - tup: a 3-tuple consisting of:

            1) the model

            2) its count

            3) its error

        **Note**

          This is included for the sake of completeness, but you need to be
          *very* careful when you use it.

        """
        self.modelList[i], self.countList[i], self.errList[i] = tup

    def GetAllData(self):
        """ Returns everything we know

        **Returns**

          a 3-tuple consisting of:

            1) our list of models

            2) our list of model counts

            3) our list of model errors

        """
        return (self.modelList, self.countList, self.errList)

    def __len__(self):
        """ allows len(composite) to work

        """
        return len(self.modelList)

    def __getitem__(self, which):
        """ allows composite[i] to work, returns the data tuple

        """
        return self.GetDataTuple(which)

    def __str__(self):
        """ returns a string representation of the composite

        """
        pieces = ['Composite\n']
        for i in range(len(self.modelList)):
            pieces.append(
                ' Model % 4d: % 5d occurances %%% 5.2f average error\n' %
                (i, self.countList[i], 100. * self.errList[i]))
        return ''.join(pieces)
if __name__ == '__main__':
    # Small demonstration of the Composite API.  It is deliberately disabled
    # (the `if 0:` guard never runs); flip the guard to try it by hand.
    if 0:
        from rdkit.ML.DecTree import DecTree
        c = Composite()
        n = DecTree.DecTreeNode(None, 'foo')
        # adding the same model twice bumps its count instead of duplicating it
        c.AddModel(n, 0.5)
        c.AddModel(n, 0.5)
        c.AverageErrors()
        c.SortModels()
        print(c)

        qB = [[], [.5, 1, 1.5]]
        exs = [['foo', 0], ['foo', .4], ['foo', .6], ['foo', 1.1], ['foo', 2.0]]
        # %-formatting keeps the output identical whether print is a
        # statement or a function
        print('quantBounds: %s' % (qB,))
        for ex in exs:
            q = c.QuantizeExample(ex, qB)
            print('%s %s' % (ex, q))
    else:
        pass