Package rdkit :: Package ML :: Module EnrichPlot
[hide private]
[frames] | no frames]

Source Code for Module rdkit.ML.EnrichPlot

  1  # $Id: EnrichPlot.py 997 2009-02-25 06:12:43Z glandrum $ 
  2  # 
  3  #  Copyright (C) 2002-2006  greg Landrum and Rational Discovery LLC 
  4  # 
  5  #   @@ All Rights Reserved  @@ 
  6  # 
  7   
  8  """Command line tool to construct an enrichment plot from saved composite models 
  9   
 10  Usage:  EnrichPlot [optional args] -d dbname -t tablename <models> 
 11   
 12  Required Arguments: 
 13    -d "dbName": the name of the database for screening 
 14   
 15    -t "tablename": provide the name of the table with the data to be screened 
 16   
 17    <models>: file name(s) of pickled composite model(s). 
 18       If the -p argument is also provided (see below), this argument is ignored. 
 19        
 20  Optional Arguments: 
 21    - -a "list": the list of result codes to be considered active.  This will be 
 22          eval'ed, so be sure that it evaluates as a list or sequence of 
 23          integers. For example, -a "[1,2]" will consider activity values 1 and 2 
 24          to be active 
 25   
 26    - --enrich "list": identical to the -a argument above.       
 27   
 28    - --thresh: sets a threshold for the plot.  If the confidence falls below 
 29            this value, picking will be terminated 
 30   
 31    - -H: screen only the hold out set (works only if a version of  
 32          BuildComposite more recent than 1.2.2 was used). 
 33   
 34    - -T: screen only the training set (works only if a version of  
 35          BuildComposite more recent than 1.2.2 was used). 
 36   
 37    - -S: shuffle activity values before screening 
 38   
 39    - -R: randomize activity values before screening 
 40   
 41    - -F *filter frac*: filters the data before training to change the 
 42       distribution of activity values in the training set.  *filter frac* 
 43       is the fraction of the training set that should have the target value. 
 44       **See note in BuildComposite help about data filtering** 
 45   
 46    - -v *filter value*: filters the data before training to change the 
 47       distribution of activity values in the training set. *filter value* 
 48       is the target value to use in filtering. 
 49       **See note in BuildComposite help about data filtering** 
 50   
 51    - -p "tableName": provides the name of a db table containing the 
 52        models to be screened.  If you use this argument, you should also 
 53        use the -N argument (below) to specify a note value. 
 54         
 55    - -N "note": provides a note to be used to pull models from a db table. 
 56   
 57    - --plotFile "filename": writes the data to an output text file (filename.dat) 
 58      and creates a gnuplot input file (filename.gnu) to plot it 
 59   
 60    - --showPlot: causes the gnuplot plot constructed using --plotFile to be 
 61      displayed in gnuplot. 
 62   
 63  """ 
 64  from rdkit import RDConfig 
 65  import numpy 
 66  import cPickle,copy 
 67  #from rdkit.Dbase.DbConnection import DbConnect 
 68  from rdkit.ML.Data import DataUtils,SplitData,Stats 
 69  from rdkit.Dbase.DbConnection import DbConnect 
 70  from rdkit import DataStructs 
 71  from rdkit.ML import CompositeRun 
 72  import sys,os,types 
 73   
 74  __VERSION_STRING="2.4.0" 
def message(msg, noRet=0, dest=sys.stderr):
    """ writes a status message to _dest_ (standard error by default)

    Modules that import this one may rebind this function to send the
    status output somewhere else.

    **Arguments**

      - msg: the object to display; it is rendered with ``'%s'``

      - noRet: if true, the message is followed by a single space rather
        than a newline, so the next message continues on the same line

      - dest: a writable file-like object; defaults to the ``sys.stderr``
        that was current when this module was imported

    """
    terminator = ' ' if noRet else '\n'
    dest.write('%s' % msg + terminator)
def error(msg, dest=None):
    """ emits an error message, prefixed with ``ERROR: ``

    Override this in modules which import this one to redirect output.

    **Arguments**

      - msg: the string to be displayed

      - dest: (Optional) a writable file-like object that receives the
        message.  When omitted (or None) the message goes to whatever
        ``sys.stderr`` is at the time of the call.

    """
    if dest is None:
        # looked up on every call (not bound at import time) so that
        # reassigning sys.stderr after import still takes effect
        dest = sys.stderr
    dest.write('ERROR: %s\n' % (msg))
98
def ScreenModel(mdl, descs, data, picking=None, indices=None, errorEstimate=0):
    """ collects the results of screening an individual composite model that match
    a particular value

    **Arguments**

      - mdl: the composite model

      - descs: a list of descriptor names corresponding to the data set

      - data: the data set, a list of points to be screened.  ``pt[0]`` is
        reported as the point's id and ``pt[-1]`` is its true result.

      - picking: (Optional) a list of values that are to be collected.
        For example, if you want an enrichment plot for picking the values
        1 and 2, you'd use picking=[1,2].  Defaults to [1].

      - indices: (Optional) the indices of the points in _data_ to screen.
        If empty or not provided, every point is screened.

      - errorEstimate: (Optional) if nonzero, each point is classified using
        only the member models whose ``_trainIndices`` do not contain that
        point's index (an out-of-bag estimate).

    **Returns**

      a 2-tuple:

        - the number of screened points whose true result is in _picking_

        - a list of 4-tuples, one for each screened point whose *predicted*
          result is in _picking_, containing:

            - the id of the point

            - the true result (from the data set; quantized with
              ``mdl.QuantizeActivity`` when the model has quant bounds)

            - the predicted result

            - the confidence value for the prediction

    **Notes**

      - member models whose ``_trainIndices`` is not already a dict are
        modified in place: the attribute is replaced by a dict whose keys
        are the original entries, so that membership tests are cheap

    """
    if picking is None:
        picking = [1]
    mdl.SetInputOrder(descs)

    for j in range(len(mdl)):
        tmp = mdl.GetModel(j)
        if hasattr(tmp, '_trainIndices') and not isinstance(tmp._trainIndices, dict):
            tmp._trainIndices = dict.fromkeys(tmp._trainIndices, 1)

    needsQuant = bool(mdl.GetQuantBounds())

    if not indices:
        indices = range(len(data))

    res = []
    nTrueActives = 0
    for i in indices:
        if errorEstimate:
            # out-of-bag: only consult the member models not trained on point i
            use = [j for j in range(len(mdl))
                   if not mdl.GetModel(j)._trainIndices.get(i, 0)]
        else:
            use = None
        pt = data[i]
        pred, conf = mdl.ClassifyExample(pt, onlyModels=use)
        if needsQuant:
            pt = mdl.QuantizeActivity(pt[:])
        trueRes = pt[-1]
        if trueRes in picking:
            nTrueActives += 1
        if pred in picking:
            res.append((pt[0], trueRes, pred, conf))
    return nTrueActives, res
165
def AccumulateCounts(predictions, thresh=0, sortIt=1):
    """ Accumulates the data for the enrichment plot for a single model

    **Arguments**

      - predictions: a list of 4-tuples (id, true result, predicted result,
        confidence), i.e. the second element of _ScreenModel_'s return
        value.  When _sortIt_ is set, this list is sorted in place.

      - thresh: a threshold for the confidence level.  Predictions whose
        confidence is not strictly greater than this are skipped.

      - sortIt: toggles sorting on confidence levels (highest confidence
        first; predictions with equal confidence keep their original order)

    **Returns**

      - a list of 3-tuples, one per accepted prediction, in pick order:

        - the id of the point picked here

        - number of correct picks (predicted == true) so far

        - number of picks made so far

    """
    if sortIt:
        # same ordering as the old cmp(y[3], x[3]) comparator, but also valid
        # on Python 3; list.sort is stable, including with reverse=True
        predictions.sort(key=lambda p: p[3], reverse=True)

    res = []
    nCorrect = 0
    nPts = 0
    for pointId, real, pred, conf in predictions:
        if conf > thresh:
            if pred == real:
                nCorrect += 1
            nPts += 1
            res.append((pointId, nCorrect, nPts))
    return res
204
def MakePlot(details, final, counts, pickVects, nModels, nTrueActs=-1):
    """ writes enrichment data to a text file and builds a gnuplot script for it

    Nothing is done unless _details_ has a non-empty ``plotFile`` attribute.

    **Arguments**

      - details: the run settings.  ``details.plotFile`` is the base name
        of the output files.  If ``details.showPlot`` is true, the script is
        also loaded into gnuplot (this needs the Gnuplot python module).

      - final: per-position sums; ``final[i][0]/counts[i]`` and
        ``final[i][1]/counts[i]`` are written as the average number of
        correct picks and the average number of picks at position ``i``

      - counts: per-position divisors; output stops at the first position
        whose count is zero

      - pickVects: maps position ``i`` to the list of values whose spread
        gives the 90% confidence-interval column (used only when
        _nModels_ > 1)

      - nModels: the number of models screened; with more than one model a
        confidence-interval column and error bars are added

      - nTrueActs: (Optional) if positive, the upper limit of the y axis

    **Side effects**

      - writes ``<plotFile>.dat`` and ``<plotFile>.gnu``, replacing any
        existing files with those names

    """
    plotBase = getattr(details, 'plotFile', None)
    if not plotBase:
        return

    dataFileName = '%s.dat' % (plotBase)
    i = 0
    with open(dataFileName, 'w+') as outF:
        while i < len(final) and counts[i] != 0:
            avgCorrect = final[i][0] / counts[i]
            avgPicked = final[i][1] / counts[i]
            if nModels > 1:
                _, sd = Stats.MeanAndDev(pickVects[i])
                confInterval = Stats.GetConfidenceInterval(sd, len(pickVects[i]), level=90)
                outF.write('%d %f %f %d %f\n' % (i + 1, avgCorrect, avgPicked, counts[i],
                                                 confInterval))
            else:
                outF.write('%d %f %f %d\n' % (i + 1, avgCorrect, avgPicked, counts[i]))
            i += 1
    nRows = i

    plotFileName = '%s.gnu' % (plotBase)
    gnuHdr = '\n'.join((
        '# Generated by EnrichPlot.py version: %s' % (__VERSION_STRING),
        'set size square 0.7',
        'set xr [0:]',
        'set data styl points',
        "set ylab 'Num Correct Picks'",
        "set xlab 'Num Picks'",
        'set grid',
        'set nokey',
        'set term postscript enh color solid "Helvetica" 16',
        'set term X',
        '',
    ))
    with open(plotFileName, 'w+') as gnuF:
        gnuF.write(gnuHdr + '\n')
        if nTrueActs > 0:
            gnuF.write('set yr [0:%d]\n' % nTrueActs)
        gnuF.write('plot x with lines\n')
        if nModels > 1:
            # gnuplot rejects "every 0", which the old i/20 produced for
            # fewer than 20 rows; plot at least every point
            everyGap = max(1, nRows // 20)
            gnuF.write('replot "%s" using 1:2 with lines, "%s" every %d using 1:2:5 with yerrorbars\n'
                       % (dataFileName, dataFileName, everyGap))
        else:
            gnuF.write('replot "%s" with points\n' % (dataFileName))

    if getattr(details, 'showPlot', 0):
        # best effort: a missing Gnuplot module or gnuplot failure is reported
        # but does not abort the run
        try:
            from Gnuplot import Gnuplot
            try:
                waitForUser = raw_input  # Python 2
            except NameError:
                waitForUser = input  # Python 3
            p = Gnuplot()
            p('load "%s"' % (plotFileName))
            waitForUser('press return to continue...\n')
        except Exception:
            import traceback
            traceback.print_exc()
260 261 262 263
def Usage():
    """ prints this module's docstring (the usage text) to standard error

    Never returns: the process is terminated with exit status -1.
    """
    stream = sys.stderr
    stream.write(__doc__)
    raise SystemExit(-1)
if __name__ == '__main__':
    # Command-line driver: parse options, load the data set and the models,
    # screen each model, accumulate per-position pick statistics, then either
    # write a plot (--plotFile) or print a table to stdout.
    import getopt
    import traceback

    try:
        args, extras = getopt.getopt(sys.argv[1:], 'd:t:a:N:p:cSTHF:v:',
                                     ('thresh=', 'plotFile=', 'showPlot',
                                      'pickleCol=', 'OOB', 'noSort', 'pickBase=',
                                      'doROC', 'rocThresh=', 'enrich='))
    except getopt.GetoptError:
        traceback.print_exc()
        Usage()

    details = CompositeRun.CompositeRun()
    CompositeRun.SetDefaults(details)

    # settings specific to this tool
    details.activeTgt = [1]
    details.doTraining = 0
    details.doHoldout = 0
    details.dbTableName = ''
    details.plotFile = ''
    details.showPlot = 0
    details.pickleCol = -1
    details.errorEstimate = 0
    details.sortIt = 1
    details.pickBase = ''
    details.doROC = 0
    details.rocThresh = -1
    for arg, val in args:
        if arg == '-d':
            details.dbName = val
        elif arg == '-t':
            details.dbTableName = val
        elif arg in ('-a', '--enrich'):
            # SECURITY: eval() executes arbitrary code taken from the command
            # line; only run this tool with trusted arguments
            details.activeTgt = eval(val)
            if not isinstance(details.activeTgt, (tuple, list)):
                details.activeTgt = (details.activeTgt,)
        elif arg == '--thresh':
            details.threshold = float(val)
        elif arg == '-N':
            details.note = val
        elif arg == '-p':
            details.persistTblName = val
        elif arg == '-S':
            details.shuffleActivities = 1
        elif arg == '-H':
            details.doTraining = 0
            details.doHoldout = 1
        elif arg == '-T':
            details.doTraining = 1
            details.doHoldout = 0
        elif arg == '-F':
            details.filterFrac = float(val)
        elif arg == '-v':
            details.filterVal = float(val)
        elif arg == '--plotFile':
            details.plotFile = val
        elif arg == '--showPlot':
            details.showPlot = 1
        elif arg == '--pickleCol':
            details.pickleCol = int(val) - 1
        elif arg == '--OOB':
            details.errorEstimate = 1
        elif arg == '--noSort':
            details.sortIt = 0
        elif arg == '--doROC':
            details.doROC = 1
        elif arg == '--rocThresh':
            details.rocThresh = int(val)
        elif arg == '--pickBase':
            details.pickBase = val

    if not details.dbName or not details.dbTableName:
        # report the problem first: Usage() exits the process
        message('*******Please provide both the -d and -t arguments')
        Usage()

    message('Building Data set\n')
    dataSet = DataUtils.DBToData(details.dbName, details.dbTableName,
                                 user=RDConfig.defaultDBUser,
                                 password=RDConfig.defaultDBPassword,
                                 pickleCol=details.pickleCol,
                                 pickleClass=DataStructs.ExplicitBitVect)

    descs = dataSet.GetVarNames()
    nPts = dataSet.GetNPts()
    message('npts: %d\n' % (nPts))
    # numpy.float / numpy.integer are removed or deprecated as dtypes; use builtins
    final = numpy.zeros((nPts, 2), float)
    counts = numpy.zeros(nPts, int)
    selPts = [None] * nPts

    models = []
    if details.persistTblName:
        conn = DbConnect(details.dbName, details.persistTblName)
        message('-> Retrieving models from database')
        curs = conn.GetCursor()
        # NOTE(review): the query is built by string interpolation, so a note
        # containing a quote breaks it (and permits SQL injection); only use
        # with trusted -p/-N values
        curs.execute("select model from %s where note='%s'" % (details.persistTblName, details.note))
        message('-> Reconstructing models')
        try:
            blob = curs.fetchone()
        except Exception:
            blob = None
        while blob:
            message(' Building model %d' % len(models))
            blob = blob[0]
            try:
                # SECURITY: unpickling can execute code; the table must be trusted
                models.append(cPickle.loads(str(blob)))
            except Exception:
                traceback.print_exc()
                print('Model failed')
            else:
                message(' <-Done')
            try:
                blob = curs.fetchone()
            except Exception:
                blob = None
        curs = None
    else:
        for modelName in extras:
            try:
                # SECURITY: unpickling can execute code; model files must be trusted
                with open(modelName, 'rb') as modelF:
                    model = cPickle.load(modelF)
            except Exception:
                print('problems with model %s:' % modelName)
                traceback.print_exc()
            else:
                models.append(model)

    if not models:
        error('no models could be loaded; nothing to screen')
        sys.exit(-1)

    nModels = len(models)
    pickVects = {}
    halfwayPts = [1e8] * nModels
    nTrueActives = -1
    for whichModel, model in enumerate(models):
        tmpD = dataSet
        try:
            seed = model._randomSeed
        except AttributeError:
            pass
        else:
            DataUtils.InitRandomNumbers(seed)
        if details.shuffleActivities:
            DataUtils.RandomizeActivities(tmpD, shuffle=1)
        if hasattr(model, '_splitFrac') and (details.doHoldout or details.doTraining):
            trainIdx, testIdx = SplitData.SplitIndices(tmpD.GetNPts(), model._splitFrac,
                                                       silent=1)
            if details.filterFrac != 0.0:
                trainFilt, temp = DataUtils.FilterData(tmpD, details.filterVal,
                                                       details.filterFrac, -1,
                                                       indicesToUse=trainIdx,
                                                       indicesOnly=1)
                testIdx += temp
                trainIdx = trainFilt
            if details.doTraining:
                testIdx, trainIdx = trainIdx, testIdx
        else:
            testIdx = list(range(tmpD.GetNPts()))

        message('screening %d examples' % (len(testIdx)))
        nTrueActives, screenRes = ScreenModel(model, descs, tmpD, picking=details.activeTgt,
                                              indices=testIdx,
                                              errorEstimate=details.errorEstimate)
        message('accumulating')
        runningCounts = AccumulateCounts(screenRes,
                                         sortIt=details.sortIt,
                                         thresh=details.threshold)
        pickFile = None
        if details.pickBase:
            pickFile = open('%s.%d.picks' % (details.pickBase, whichModel + 1), 'w+')
        try:
            for i, entry in enumerate(runningCounts):
                pointId, nCorrect, nPicked = entry
                selPts[i] = pointId
                final[i][0] += nCorrect
                final[i][1] += nPicked
                pickVects.setdefault(i, []).append(nCorrect)
                counts[i] += 1
                if pickFile is not None:
                    pickFile.write('%s\n' % (pointId))
                # floor division keeps the historical (Python 2) halfway threshold
                if nCorrect >= nTrueActives // 2 and nPicked < halfwayPts[whichModel]:
                    halfwayPts[whichModel] = nPicked
        finally:
            if pickFile is not None:
                pickFile.close()
        message('Halfway point: %d\n' % halfwayPts[whichModel])

    if details.plotFile:
        MakePlot(details, final, counts, pickVects, nModels, nTrueActs=nTrueActives)
    else:
        if nModels > 1:
            print('#Index\tAvg_num_correct\tConf90Pct\tAvg_num_picked\tNum_picks\tlast_selection')
        else:
            print('#Index\tAvg_num_correct\tAvg_num_picked\tNum_picks\tlast_selection')

        i = 0
        while i < nPts and counts[i] != 0:
            avgCorrect = final[i][0] / counts[i]
            avgPicked = final[i][1] / counts[i]
            if nModels > 1:
                _, sd = Stats.MeanAndDev(pickVects[i])
                confInterval = Stats.GetConfidenceInterval(sd, len(pickVects[i]), level=90)
                print('%d\t%f\t%f\t%f\t%d\t%s' % (i + 1, avgCorrect, confInterval, avgPicked,
                                                 counts[i], str(selPts[i])))
            else:
                print('%d\t%f\t%f\t%d\t%s' % (i + 1, avgCorrect, avgPicked,
                                             counts[i], str(selPts[i])))
            i += 1

    mean, sd = Stats.MeanAndDev(halfwayPts)
    print('Halfway point: %.2f(%.2f)' % (mean, sd))