Package ML :: Module EnrichPlot
[hide private]
[frames | no frames]

Source Code for Module ML.EnrichPlot

  1  # $Id: EnrichPlot.py 2 2006-05-06 22:54:39Z glandrum $ 
  2  # 
  3  #  Copyright (C) 2002-2006  greg Landrum and Rational Discovery LLC 
  4  # 
  5  #   @@ All Rights Reserved  @@ 
  6  # 
  7   
  8  """Command line tool to construct an enrichment plot from saved composite models 
  9   
 10  Usage:  EnrichPlot [optional args] -d dbname -t tablename <models> 
 11   
 12  Required Arguments: 
 13    -d "dbName": the name of the database for screening 
 14   
 15    -t "tablename": provide the name of the table with the data to be screened 
 16   
 17    <models>: file name(s) of pickled composite model(s). 
 18       If the -p argument is also provided (see below), this argument is ignored. 
 19        
 20  Optional Arguments: 
 21    - -a "list": the list of result codes to be considered active.  This will be 
 22          eval'ed, so be sure that it evaluates as a list or sequence of 
 23          integers. For example, -a "[1,2]" will consider activity values 1 and 2 
 24          to be active 
 25   
 26    - --enrich "list": identical to the -a argument above.       
 27   
 28    - --thresh: sets a threshold for the plot.  If the confidence falls below 
 29            this value, picking will be terminated 
 30   
 31    - -H: screen only the hold out set (works only if a version of  
 32          BuildComposite more recent than 1.2.2 was used). 
 33   
 34    - -T: screen only the training set (works only if a version of  
 35          BuildComposite more recent than 1.2.2 was used). 
 36   
 37    - -S: shuffle activity values before screening 
 38   
 39    - -R: randomize activity values before screening 
 40   
 41    - -F *filter frac*: filters the data before training to change the 
 42       distribution of activity values in the training set.  *filter frac* 
 43       is the fraction of the training set that should have the target value. 
 44       **See note in BuildComposite help about data filtering** 
 45   
 46    - -v *filter value*: filters the data before training to change the 
 47       distribution of activity values in the training set. *filter value* 
 48       is the target value to use in filtering. 
 49       **See note in BuildComposite help about data filtering** 
 50   
 51    - -p "tableName": provides the name of a db table containing the 
 52        models to be screened.  If you use this argument, you should also 
 53        use the -N argument (below) to specify a note value. 
 54         
 55    - -N "note": provides a note to be used to pull models from a db table. 
 56   
 57    - --plotFile "filename": writes the data to an output text file (filename.dat) 
 58      and creates a gnuplot input file (filename.gnu) to plot it 
 59   
 60    - --showPlot: causes the gnuplot plot constructed using --plotFile to be 
 61      displayed in gnuplot. 
 62   
 63  """ 
 64  import RDConfig 
 65  from Numeric import * 
 66  import cPickle,copy 
 67  #from Dbase.DbConnection import DbConnect 
 68  from ML.Data import DataUtils,SplitData,Stats 
 69  from Dbase.DbConnection import DbConnect 
 70  import DataStructs 
 71  from ML import CompositeRun 
 72  import sys,os,types 
 73   
 74  __VERSION_STRING="2.3.3" 
def message(msg,noRet=0,dest=sys.stderr):
    """ writes a message to an output stream (_sys.stderr_ by default)

      override this in modules which import this one to redirect output

    **Arguments**

      - msg: the string to be displayed

      - noRet: (Optional) if nonzero, the message is followed by a single
        space instead of a newline

      - dest: (Optional) the file-like object the message is written to

    """
    # pick the template first, then do a single write
    template = '%s\n'
    if noRet:
        template = '%s '
    dest.write(template%(msg))
def error(msg,dest=sys.stderr):
    """ emits an error message, prefixed with "ERROR: " and terminated with
      a newline

      override this in modules which import this one to redirect output

    **Arguments**

      - msg: the string to be displayed

      - dest: (Optional) the file-like object the message is written to
        (_sys.stderr_ by default)

    """
    # the destination argument used to be ignored in favor of sys.stderr
    dest.write('ERROR: %s\n'%(msg))
98
def ScreenModel(mdl,descs,data,picking=(1,),indices=(),errorEstimate=0):
    """ screens a composite model against a data set and collects the
      predictions whose value is one of those being picked

    **Arguments**

      - mdl: the composite model

      - descs: a list of descriptor names corresponding to the data set

      - data: the data set, a list of points to be screened.

      - picking: (Optional) a sequence of values that are to be collected.
        For example, if you want an enrichment plot for picking the values
        1 and 2, you'd have picking=[1,2].

      - indices: (Optional) the indices of the points in _data_ that should
        be screened.  If this is empty, every point is screened.

      - errorEstimate: (Optional) if nonzero, each point is classified using
        only those models of the composite that either have no
        _trainIndices attribute or do not list the point's index in it.

    **Returns**

      a 2-tuple containing:

        1) the number of screened points whose true result is in _picking_

        2) a list of 4-tuples, one for each screened point whose predicted
           result is in _picking_, containing:

           - the id of the point (its first element)

           - the true result (the last element of the point, taken after
             activity quantization when the model has quantization bounds)

           - the predicted result

           - the confidence value for the prediction

    """
    mdl.SetInputOrder(descs)
    needsQuant = bool(mdl.GetQuantBounds())

    # an empty index list means "screen everything"
    if not indices:
        indices = range(len(data))

    nTrueActives = 0
    res = []
    for i in indices:
        if errorEstimate:
            # only use the models that could not have seen this point
            use = []
            for j in range(len(mdl)):
                subModel = mdl.GetModel(j)
                if not hasattr(subModel,'_trainIndices') or \
                   i not in subModel._trainIndices:
                    use.append(j)
        else:
            use = None
        pt = data[i]
        pred,conf = mdl.ClassifyExample(pt,onlyModels=use)
        if needsQuant:
            # quantize a copy so that the data set itself is left alone
            pt = mdl.QuantizeActivity(pt[:])
        trueRes = pt[-1]
        if trueRes in picking:
            nTrueActives += 1
        if pred in picking:
            res.append((pt[0],trueRes,pred,conf))
    return nTrueActives,res
157
def AccumulateCounts(predictions,thresh=0,sortIt=1):
    """ Accumulates the data for the enrichment plot for a single model

    **Arguments**

      - predictions: a list of 4-tuples:
        (id, true result, predicted result, confidence).
        **Note** when _sortIt_ is set this list is sorted in place.

      - thresh: a threshold for the confidence level.  Predictions whose
        confidence is not above this threshold are not considered

      - sortIt: toggles sorting on confidence levels (highest first)

    **Returns**

      - a list of 3-tuples, one per prediction that passes the threshold:

        - the id of the point picked here

        - num correct predictions found so far

        - number of picks made so far

    """
    if sortIt:
        # stable, descending by confidence: ties keep their original order,
        # exactly like the cmp-based sort this replaces
        predictions.sort(key=lambda entry: entry[3],reverse=True)

    res = []
    nCorrect = 0
    nPts = 0
    for ptId,real,pred,conf in predictions:
        if conf > thresh:
            if pred == real:
                nCorrect += 1
            nPts += 1
            res.append((ptId,nCorrect,nPts))

    return res
196
def MakePlot(details,final,counts,pickVects,nModels,nTrueActs=-1):
    """ writes the enrichment data to a text file and generates a gnuplot
      input file to plot it; optionally displays the plot

      Nothing is done unless _details.plotFile_ is set.  The data goes to
      _plotFile_.dat and the gnuplot commands to _plotFile_.gnu.

    **Arguments**

      - details: the run details.  The _plotFile_ attribute provides the
        base name for the output files; if the _showPlot_ attribute is set,
        the plot file is loaded into gnuplot and the function waits for
        the user to press return.

      - final: a sequence of pairs of accumulated values, indexed by pick
        number.  Both members of each pair are divided by the matching
        entry of _counts_ before being written.

      - counts: a sequence with the number of contributions to each entry
        of _final_.  Output stops at the first entry that is zero.

      - pickVects: a mapping from pick number to the list of values used
        to compute the confidence interval for that pick (only used if
        _nModels_ is larger than 1)

      - nModels: the number of models that contributed.  With more than one
        model a 90% confidence interval column and error bars are added.

      - nTrueActs: (Optional) if positive, used as the upper limit of the
        y range of the plot

    """
    if not hasattr(details,'plotFile') or not details.plotFile:
        return

    dataFileName = '%s.dat'%(details.plotFile)
    outF = open(dataFileName,'w+')
    try:
        i = 0
        while i < len(final) and counts[i] != 0:
            if nModels>1:
                sd = Stats.MeanAndDev(pickVects[i])[1]
                confInterval = Stats.GetConfidenceInterval(sd,len(pickVects[i]),level=90)
                outF.write('%d %f %f %d %f\n'%(i+1,final[i][0]/counts[i],
                                               final[i][1]/counts[i],counts[i],confInterval))
            else:
                outF.write('%d %f %f %d\n'%(i+1,final[i][0]/counts[i],
                                            final[i][1]/counts[i],counts[i]))
            i+=1
    finally:
        outF.close()

    plotFileName = '%s.gnu'%(details.plotFile)
    gnuHdr="""# Generated by EnrichPlot.py version: %s
set size square 0.7
set xr [0:]
set data styl points
set ylab 'Num Correct Picks'
set xlab 'Num Picks'
set grid
set nokey
set term postscript enh color solid "Helvetica" 16
set term win
"""%(__VERSION_STRING)
    gnuF = open(plotFileName,'w+')
    try:
        # the header is followed by an empty line
        gnuF.write(gnuHdr+'\n')
        if nTrueActs >0:
            gnuF.write('set yr [0:%d]\n'%nTrueActs)
        gnuF.write('plot x with lines\n')
        if nModels>1:
            # draw roughly 20 error bars; the increment given to "every"
            # must be at least 1, which i/20 is not for fewer than 20 rows
            everyGap = max(1,i//20)
            gnuF.write('replot "%s" using 1:2 with lines, '%(dataFileName))
            gnuF.write('"%s" every %d using 1:2:5 with yerrorbars\n'%(dataFileName,
                                                                     everyGap))
        else:
            gnuF.write('replot "%s" with points\n'%(dataFileName))
    finally:
        gnuF.close()

    if hasattr(details,'showPlot') and details.showPlot:
        # displaying the plot is best-effort: report problems (e.g. a missing
        # Gnuplot module) without failing, since the files are already written
        try:
            from Gnuplot import Gnuplot
            p = Gnuplot()
            p('load "%s"'%(plotFileName))
            raw_input('press return to continue...\n')
        except Exception:
            import traceback
            traceback.print_exc()
252 253 254 255
def Usage():
    """ writes the module's usage message to _sys.stderr_ and exits with status -1 """
    usageText = __doc__
    sys.stderr.write(usageText)
    raise SystemExit(-1)
# Command line driver:
#   1) parse the arguments into a CompositeRun details object
#   2) pull the data set from the database
#   3) load the composite models (from a db table or from pickle files)
#   4) screen the data with each model and accumulate the enrichment counts
#   5) either generate the plot files or print the table to stdout
if __name__=='__main__':
    import getopt
    import traceback

    # NOTE: '-c' is accepted by the option string but not handled below
    try:
        args,extras = getopt.getopt(sys.argv[1:],'d:t:a:N:p:cSTHF:v:',
                                    ('thresh=','plotFile=','showPlot',
                                     'pickleCol=','OOB','noSort','pickBase=',
                                     'doROC','rocThresh=','enrich='))
    except getopt.GetoptError:
        traceback.print_exc()
        Usage()

    details = CompositeRun.CompositeRun()
    CompositeRun.SetDefaults(details)

    details.activeTgt=[1]
    details.doTraining = 0
    details.doHoldout = 0
    details.dbTableName = ''
    details.plotFile = ''
    details.showPlot = 0
    details.pickleCol = -1
    details.errorEstimate=0
    details.sortIt=1
    details.pickBase = ''
    details.doROC=0
    details.rocThresh=-1
    for arg,val in args:
        if arg == '-d':
            details.dbName = val
        elif arg == '-t':
            details.dbTableName = val
        elif arg == '-a' or arg == '--enrich':
            # NOTE(review): eval() executes whatever is passed on the command
            # line; only use this tool with trusted arguments
            details.activeTgt = eval(val)
            if not isinstance(details.activeTgt,(tuple,list)):
                details.activeTgt = (details.activeTgt,)
        elif arg == '--thresh':
            details.threshold = float(val)
        elif arg == '-N':
            details.note = val
        elif arg == '-p':
            details.persistTblName = val
        elif arg == '-S':
            details.shuffleActivities = 1
        elif arg == '-H':
            details.doTraining = 0
            details.doHoldout = 1
        elif arg == '-T':
            details.doTraining = 1
            details.doHoldout = 0
        elif arg == '-F':
            details.filterFrac=float(val)
        elif arg == '-v':
            details.filterVal=float(val)
        elif arg == '--plotFile':
            details.plotFile = val
        elif arg == '--showPlot':
            details.showPlot=1
        elif arg == '--pickleCol':
            # the command line value is 1-based
            details.pickleCol=int(val)-1
        elif arg == '--OOB':
            details.errorEstimate=1
        elif arg == '--noSort':
            details.sortIt=0
        elif arg == '--doROC':
            details.doROC=1
        elif arg == '--rocThresh':
            details.rocThresh=int(val)
        elif arg == '--pickBase':
            details.pickBase=val

    if not details.dbName or not details.dbTableName:
        # Usage() exits, so the hint has to be emitted before calling it
        message('*******Please provide both the -d and -t arguments')
        Usage()

    message('Building Data set\n')
    dataSet = DataUtils.DBToData(details.dbName,details.dbTableName,
                                 user=RDConfig.defaultDBUser,
                                 password=RDConfig.defaultDBPassword,
                                 pickleCol=details.pickleCol,
                                 pickleClass=DataStructs.ExplicitBitVect)

    descs = dataSet.GetVarNames()
    nPts = dataSet.GetNPts()
    message('npts: %d\n'%(nPts))
    # per-pick accumulators, summed over all of the models
    final = zeros((nPts,2),Float)
    counts = zeros(nPts,Int)
    selPts = [None]*nPts

    models = []
    if details.persistTblName:
        conn = DbConnect(details.dbName,details.persistTblName)
        message('-> Retrieving models from database')
        curs = conn.GetCursor()
        # NOTE(review): this query is assembled by string formatting; the table
        # name and note come straight from the command line
        curs.execute("select model from %s where note='%s'"%(details.persistTblName,details.note))
        message('-> Reconstructing models')
        try:
            blob = curs.fetchone()
        except Exception:
            blob = None
        while blob:
            message(' Building model %d'%len(models))
            blob = blob[0]
            try:
                models.append(cPickle.loads(str(blob)))
            except Exception:
                traceback.print_exc()
                print('Model failed')
            else:
                message(' <-Done')
            try:
                blob = curs.fetchone()
            except Exception:
                blob = None
        curs = None
    else:
        for modelName in extras:
            try:
                inF = open(modelName,'rb')
                try:
                    model = cPickle.load(inF)
                finally:
                    inF.close()
            except Exception:
                print('problems with model %s:'%modelName)
                traceback.print_exc()
            else:
                models.append(model)

    nModels = len(models)
    pickVects = {}
    halfwayPts = [1e8]*len(models)
    # stays at -1 (no y range in the plot) if there are no models to screen
    nTrueActives = -1
    for whichModel,model in enumerate(models):
        tmpD = dataSet
        try:
            seed = model._randomSeed
        except AttributeError:
            pass
        else:
            DataUtils.InitRandomNumbers(seed)
        if details.shuffleActivities:
            DataUtils.RandomizeActivities(tmpD,
                                          shuffle=1)
        if hasattr(model,'_splitFrac') and (details.doHoldout or details.doTraining):
            trainIdx,testIdx = SplitData.SplitIndices(tmpD.GetNPts(),model._splitFrac,
                                                      silent=1)
            if details.filterFrac != 0.0:
                trainFilt,temp = DataUtils.FilterData(tmpD,details.filterVal,
                                                      details.filterFrac,-1,
                                                      indicesToUse=trainIdx,
                                                      indicesOnly=1)
                testIdx += temp
                trainIdx = trainFilt
            if details.doTraining:
                # screen the training set instead of the hold out set
                testIdx,trainIdx = trainIdx,testIdx
        else:
            testIdx = range(tmpD.GetNPts())

        message('screening %d examples'%(len(testIdx)))
        nTrueActives,screenRes = ScreenModel(model,descs,tmpD,picking=details.activeTgt,
                                             indices=testIdx,
                                             errorEstimate=details.errorEstimate)
        message('accumulating')
        runningCounts = AccumulateCounts(screenRes,
                                         sortIt=details.sortIt,
                                         thresh=details.threshold)
        if details.pickBase:
            pickFile = open('%s.%d.picks'%(details.pickBase,whichModel+1),'w+')
        else:
            pickFile = None

        try:
            for i,entry in enumerate(runningCounts):
                selPts[i] = entry[0]
                final[i][0] += entry[1]
                final[i][1] += entry[2]
                pickVects.setdefault(i,[]).append(entry[1])
                counts[i] += 1
                if pickFile:
                    pickFile.write('%s\n'%(entry[0]))
                # remember the smallest number of picks at which half of the
                # true actives had been found
                if entry[1] >= nTrueActives/2 and entry[2]<halfwayPts[whichModel]:
                    halfwayPts[whichModel]=entry[2]
        finally:
            if pickFile:
                pickFile.close()
        message('Halfway point: %d\n'%halfwayPts[whichModel])

    if details.plotFile:
        MakePlot(details,final,counts,pickVects,nModels,nTrueActs=nTrueActives)
    else:
        if nModels>1:
            print('#Index\tAvg_num_correct\tConf90Pct\tAvg_num_picked\tNum_picks\tlast_selection')
        else:
            print('#Index\tAvg_num_correct\tAvg_num_picked\tNum_picks\tlast_selection')

        i = 0
        while i < nPts and counts[i] != 0:
            if nModels>1:
                mean,sd = Stats.MeanAndDev(pickVects[i])
                confInterval = Stats.GetConfidenceInterval(sd,len(pickVects[i]),level=90)
                print('%d\t%f\t%f\t%f\t%d\t%s'%(i+1,final[i][0]/counts[i],confInterval,
                                                final[i][1]/counts[i],
                                                counts[i],str(selPts[i])))
            else:
                print('%d\t%f\t%f\t%d\t%s'%(i+1,final[i][0]/counts[i],
                                            final[i][1]/counts[i],
                                            counts[i],str(selPts[i])))
            i += 1

    mean,sd = Stats.MeanAndDev(halfwayPts)
    print('Halfway point: %.2f(%.2f)'%(mean,sd))