1
2
3
4
5
6
7
"""Command line tool to construct an enrichment plot from saved composite models

Usage:  EnrichPlot [optional args] -d dbname -t tablename <models>

Required Arguments:
  -d "dbName": the name of the database for screening

  -t "tablename": provide the name of the table with the data to be screened

  <models>: file name(s) of pickled composite model(s).
     If the -p argument is also provided (see below), this argument is ignored.

Optional Arguments:
  - -a "list": the list of result codes to be considered active.  This will be
        eval'ed, so be sure that it evaluates as a list or sequence of
        integers. For example, -a "[1,2]" will consider activity values 1 and 2
        to be active

  - --enrich "list": identical to the -a argument above.

  - --thresh: sets a threshold for the plot.  If the confidence falls below
        this value, picking will be terminated

  - -H: screen only the hold out set (works only if a version of
        BuildComposite more recent than 1.2.2 was used).

  - -T: screen only the training set (works only if a version of
        BuildComposite more recent than 1.2.2 was used).

  - -S: shuffle activity values before screening

  - -R: randomize activity values before screening
        (NOTE: this flag is documented here, but it is not currently accepted
        by the option parser)

  - -F *filter frac*: filters the data before training to change the
        distribution of activity values in the training set.  *filter frac*
        is the fraction of the training set that should have the target value.
        **See note in BuildComposite help about data filtering**

  - -v *filter value*: filters the data before training to change the
        distribution of activity values in the training set. *filter value*
        is the target value to use in filtering.
        **See note in BuildComposite help about data filtering**

  - -p "tableName": provides the name of a db table containing the
        models to be screened.  If you use this argument, you should also
        use the -N argument (below) to specify a note value.

  - -N "note": provides a note to be used to pull models from a db table.

  - --plotFile "filename": writes the data to an output text file (filename.dat)
        and creates a gnuplot input file (filename.gnu) to plot it

  - --showPlot: causes the gnuplot plot constructed using --plotFile to be
        displayed in gnuplot.

  - --pickleCol "col": the (1-based) index of the database column holding
        pickled data; it is passed on when the data set is read.

  - --OOB: for each point, only use the models in the composite that did not
        have that point in their training set.

  - --noSort: do not sort the predictions by confidence before accumulating
        the enrichment counts.

  - --pickBase "basename": writes the ids of the picks made with each model to
        a file named basename.N.picks (N is the 1-based model number).

  - --doROC, --rocThresh "value": accepted by the option parser, but not used
        by this script.

"""
64 from rdkit import RDConfig
65 import numpy
66 import cPickle,copy
67
68 from rdkit.ML.Data import DataUtils,SplitData,Stats
69 from rdkit.Dbase.DbConnection import DbConnect
70 from rdkit import DataStructs
71 from rdkit.ML import CompositeRun
72 import sys,os,types
73
# Version of this tool; MakePlot embeds it in the header of the gnuplot
# input files it generates.
__VERSION_STRING="2.4.0"
def message(msg, noRet=0, dest=sys.stderr):
    """ writes a message to a file-like object (_sys.stderr_ by default)
      override this in modules which import this one to redirect output

      **Arguments**

        - msg: the string to be displayed

        - noRet: (Optional) if true, the message is followed by a single
          space instead of a newline

        - dest: (Optional) the file-like object the message is written to

    """
    template = '%s ' if noRet else '%s\n'
    dest.write(template % (msg))
def error(msg, dest=sys.stderr):
    """ writes an error message to a file-like object (_sys.stderr_ by default)
      override this in modules which import this one to redirect output

      **Arguments**

        - msg: the string to be displayed

        - dest: (Optional) the file-like object the message is written to

      **Notes**

        - the message is prefixed with 'ERROR: ' and terminated with a newline

    """
    # honour the dest argument instead of unconditionally using sys.stderr
    dest.write('ERROR: %s\n' % (msg))
98
def ScreenModel(mdl, descs, data, picking=(1,), indices=None, errorEstimate=0):
    """ collects the results of screening an individual composite model that match
      a particular value

      **Arguments**

        - mdl: the composite model

        - descs: a list of descriptor names corresponding to the data set

        - data: the data set, a list of points to be screened.

        - picking: (Optional) a sequence of values that are to be collected.
          For example, if you want an enrichment plot for picking the values
          1 and 2, you'd have picking=[1,2].

        - indices: (Optional) the indices of the points in _data_ that should
          be screened.  If this is empty or not provided, every point is
          screened.

        - errorEstimate: (Optional) if true, each point is only classified
          using the models in the composite whose _trainIndices do not
          contain the index of that point.

      **Returns**

        a 2-tuple containing:

          - the number of screened points whose true result is in _picking_

          - a list of 4-tuples, one for each point whose predicted result
            is in _picking_, containing:

              - the id of the point (its first element)

              - the true result (from the data set)

              - the predicted result

              - the confidence value for the prediction

      **Notes**

        - as a side effect, any _trainIndices attribute on the models in the
          composite that is not already a dictionary is replaced with one
          (keyed by index).

    """
    mdl.SetInputOrder(descs)

    # convert training-index sequences to dictionaries so that the membership
    # tests done below (once per point per model) are cheap
    for j in range(len(mdl)):
        subModel = mdl.GetModel(j)
        if hasattr(subModel, '_trainIndices') and \
           not isinstance(subModel._trainIndices, dict):
            subModel._trainIndices = dict.fromkeys(subModel._trainIndices, 1)

    if mdl.GetQuantBounds():
        needsQuant = 1
    else:
        needsQuant = 0

    if not indices:
        indices = range(len(data))

    res = []
    nTrueActives = 0
    for i in indices:
        if errorEstimate:
            # only use the models that were not trained on this point
            use = [j for j in range(len(mdl))
                   if not mdl.GetModel(j)._trainIndices.get(i, 0)]
        else:
            use = None
        pt = data[i]
        pred, conf = mdl.ClassifyExample(pt, onlyModels=use)
        if needsQuant:
            # quantize a copy so that the data set itself is left untouched
            pt = mdl.QuantizeActivity(pt[:])
        trueRes = pt[-1]
        if trueRes in picking:
            nTrueActives += 1
        if pred in picking:
            res.append((pt[0], trueRes, pred, conf))
    return nTrueActives, res
165
# NOTE(review): the signature line was reconstructed from the call site and
# the docstring; confirm the default values of thresh and sortIt.
def AccumulateCounts(predictions, thresh=0, sortIt=1):
    """ Accumulates the data for the enrichment plot for a single model

      **Arguments**

        - predictions: a list of 4-tuples (as returned by _ScreenModel_):
          (id, true result, predicted result, confidence)

        - thresh: a threshold for the confidence level.  Anything that is
          not above this threshold will not be considered

        - sortIt: toggles sorting on confidence levels


      **Returns**

        - a list of 3-tuples, one per prediction that passes the threshold:

          - the id of the point picked here

          - num actives found so far (picks where the prediction matches
            the true result)

          - number of picks made so far

      **Notes**

        - if _sortIt_ is set, _predictions_ is sorted *in place* by
          decreasing confidence; predictions with equal confidence keep
          their relative order.

    """
    if sortIt:
        predictions.sort(key=lambda entry: entry[3], reverse=True)

    res = []
    nCorrect = 0
    nPts = 0
    for ptId, real, pred, conf in predictions:
        if conf > thresh:
            if pred == real:
                nCorrect += 1
            nPts += 1
            res.append((ptId, nCorrect, nPts))

    return res
204
def MakePlot(details, final, counts, pickVects, nModels, nTrueActs=-1):
    """ writes the enrichment data and a gnuplot input file that plots it

      **Arguments**

        - details: the run details.  Nothing is done unless it has a
          non-empty _plotFile_ attribute, which is used as the base name of
          the output files.  If it has a true _showPlot_ attribute the plot
          is also displayed using the Gnuplot python module.

        - final: for each pick index, a pair containing the summed number of
          correct picks and the summed number of picks

        - counts: for each pick index, the number of models contributing to
          the sums in _final_.  Output stops at the first zero entry.

        - pickVects: for each pick index, the per-model numbers of correct
          picks (only used when _nModels_ > 1, to get confidence intervals)

        - nModels: the number of models that were screened

        - nTrueActs: (Optional) if positive, used as the upper limit of the
          y range of the plot

      **Notes**

        - the data are written to _plotFile_.dat, one row per pick index:
          index (1-based), average number correct, average number of picks,
          number of models and, when _nModels_ > 1, the 90% confidence
          interval.

        - the gnuplot commands are written to _plotFile_.gnu

    """
    if not getattr(details, 'plotFile', None):
        return

    dataFileName = '%s.dat' % (details.plotFile)
    nRows = 0
    with open(dataFileName, 'w+') as outF:
        while nRows < len(final) and counts[nRows] != 0:
            nCorrect = final[nRows][0] / counts[nRows]
            nPicked = final[nRows][1] / counts[nRows]
            if nModels > 1:
                _, sd = Stats.MeanAndDev(pickVects[nRows])
                confInterval = Stats.GetConfidenceInterval(sd, len(pickVects[nRows]),
                                                           level=90)
                outF.write('%d %f %f %d %f\n' % (nRows + 1, nCorrect, nPicked,
                                                 counts[nRows], confInterval))
            else:
                outF.write('%d %f %f %d\n' % (nRows + 1, nCorrect, nPicked,
                                              counts[nRows]))
            nRows += 1

    plotFileName = '%s.gnu' % (details.plotFile)
    gnuHdr = """# Generated by EnrichPlot.py version: %s
set size square 0.7
set xr [0:]
set data styl points
set ylab 'Num Correct Picks'
set xlab 'Num Picks'
set grid
set nokey
set term postscript enh color solid "Helvetica" 16
set term X
""" % (__VERSION_STRING)
    with open(plotFileName, 'w+') as gnuF:
        gnuF.write(gnuHdr + '\n')
        if nTrueActs > 0:
            gnuF.write('set yr [0:%d]\n' % nTrueActs)
        gnuF.write('plot x with lines\n')
        if nModels > 1:
            # draw an error bar on roughly every 20th point; the stride must
            # be at least 1, even when fewer than 20 rows were written
            everyGap = max(1, nRows // 20)
            gnuF.write('replot "%s" using 1:2 with lines, ' % (dataFileName))
            gnuF.write('"%s" every %d using 1:2:5 with yerrorbars\n' % (dataFileName,
                                                                        everyGap))
        else:
            gnuF.write('replot "%s" with points\n' % (dataFileName))

    if getattr(details, 'showPlot', 0):
        # displaying the plot is best-effort: problems (e.g. the Gnuplot
        # module not being installed) are reported, but are not fatal
        try:
            from Gnuplot import Gnuplot
            try:
                prompt = raw_input
            except NameError:
                prompt = input
            p = Gnuplot()
            p('load "%s"' % (plotFileName))
            prompt('press return to continue...\n')
        except Exception:
            import traceback
            traceback.print_exc()
261
262
263
def Usage():
    """ displays a usage message and exits """
    # __doc__ is None when docstrings have been stripped (python -OO)
    sys.stderr.write(__doc__ or '')
    sys.exit(-1)
268
if __name__ == '__main__':
    # Script entry point: parse the command line, read the data set and the
    # models, screen the data with each model, and then either write gnuplot
    # files or print the enrichment table.
    import getopt
    import traceback

    try:
        args, extras = getopt.getopt(sys.argv[1:], 'd:t:a:N:p:cSTHF:v:',
                                     ('thresh=', 'plotFile=', 'showPlot',
                                      'pickleCol=', 'OOB', 'noSort', 'pickBase=',
                                      'doROC', 'rocThresh=', 'enrich='))
    except getopt.GetoptError:
        traceback.print_exc()
        Usage()

    details = CompositeRun.CompositeRun()
    CompositeRun.SetDefaults(details)

    details.activeTgt = [1]
    details.doTraining = 0
    details.doHoldout = 0
    details.dbTableName = ''
    details.plotFile = ''
    details.showPlot = 0
    details.pickleCol = -1
    details.errorEstimate = 0
    details.sortIt = 1
    details.pickBase = ''
    details.doROC = 0
    details.rocThresh = -1
    for arg, val in args:
        if arg == '-d':
            details.dbName = val
        elif arg == '-t':
            details.dbTableName = val
        elif arg == '-a' or arg == '--enrich':
            # NOTE(security): eval() executes arbitrary code from the command
            # line; only use this tool with trusted arguments.
            details.activeTgt = eval(val)
            if not isinstance(details.activeTgt, (tuple, list)):
                details.activeTgt = (details.activeTgt,)
        elif arg == '--thresh':
            details.threshold = float(val)
        elif arg == '-N':
            details.note = val
        elif arg == '-p':
            details.persistTblName = val
        elif arg == '-S':
            details.shuffleActivities = 1
        elif arg == '-H':
            details.doTraining = 0
            details.doHoldout = 1
        elif arg == '-T':
            details.doTraining = 1
            details.doHoldout = 0
        elif arg == '-F':
            details.filterFrac = float(val)
        elif arg == '-v':
            details.filterVal = float(val)
        elif arg == '--plotFile':
            details.plotFile = val
        elif arg == '--showPlot':
            details.showPlot = 1
        elif arg == '--pickleCol':
            # the command-line value is 1-based
            details.pickleCol = int(val) - 1
        elif arg == '--OOB':
            details.errorEstimate = 1
        elif arg == '--noSort':
            details.sortIt = 0
        elif arg == '--doROC':
            details.doROC = 1
        elif arg == '--rocThresh':
            details.rocThresh = int(val)
        elif arg == '--pickBase':
            details.pickBase = val

    if not details.dbName or not details.dbTableName:
        # Usage() exits, so the explanation has to be emitted first
        message('*******Please provide both the -d and -t arguments')
        Usage()

    message('Building Data set\n')
    dataSet = DataUtils.DBToData(details.dbName, details.dbTableName,
                                 user=RDConfig.defaultDBUser,
                                 password=RDConfig.defaultDBPassword,
                                 pickleCol=details.pickleCol,
                                 pickleClass=DataStructs.ExplicitBitVect)

    descs = dataSet.GetVarNames()
    nPts = dataSet.GetNPts()
    message('npts: %d\n' % (nPts))
    # final[i] holds the summed (num correct, num picked) values at pick i;
    # counts[i] is the number of models that contributed to final[i]
    final = numpy.zeros((nPts, 2), float)
    counts = numpy.zeros(nPts, int)
    selPts = [None] * nPts

    # NOTE(security): unpickling executes arbitrary code; only load models
    # from trusted files and databases.
    models = []
    if details.persistTblName:
        conn = DbConnect(details.dbName, details.persistTblName)
        message('-> Retrieving models from database')
        curs = conn.GetCursor()
        # NOTE(security): this query is assembled with string formatting, so
        # the table name and the note must come from a trusted source.
        curs.execute("select model from %s where note='%s'" % (details.persistTblName,
                                                                details.note))
        message('-> Reconstructing models')
        try:
            blob = curs.fetchone()
        except Exception:
            blob = None
        while blob:
            message(' Building model %d' % len(models))
            blob = blob[0]
            try:
                models.append(cPickle.loads(str(blob)))
            except Exception:
                traceback.print_exc()
                print('Model failed')
            else:
                message(' <-Done')
            try:
                blob = curs.fetchone()
            except Exception:
                blob = None
        curs = None
    else:
        for modelName in extras:
            try:
                with open(modelName, 'rb') as inF:
                    model = cPickle.load(inF)
            except Exception:
                print('problems with model %s:' % modelName)
                traceback.print_exc()
            else:
                models.append(model)

    nModels = len(models)
    if not nModels:
        # without a model there is nothing to screen or to report
        message('ERROR: no models could be loaded')
        sys.exit(1)

    pickVects = {}
    # number of picks each model needed to find half of the true actives
    halfwayPts = [1e8] * nModels
    for whichModel, model in enumerate(models):
        # NOTE(review): tmpD is the same object as dataSet, so if
        # RandomizeActivities works in place, a shuffle done for one model is
        # still there for the next one - confirm that this is intended.
        tmpD = dataSet
        try:
            seed = model._randomSeed
        except AttributeError:
            pass
        else:
            DataUtils.InitRandomNumbers(seed)
        if details.shuffleActivities:
            DataUtils.RandomizeActivities(tmpD, shuffle=1)
        if hasattr(model, '_splitFrac') and (details.doHoldout or details.doTraining):
            trainIdx, testIdx = SplitData.SplitIndices(tmpD.GetNPts(), model._splitFrac,
                                                       silent=1)
            if details.filterFrac != 0.0:
                trainFilt, temp = DataUtils.FilterData(tmpD, details.filterVal,
                                                       details.filterFrac, -1,
                                                       indicesToUse=trainIdx,
                                                       indicesOnly=1)
                testIdx += temp
                trainIdx = trainFilt
            if details.doTraining:
                # screen the training set instead of the hold out set
                testIdx, trainIdx = trainIdx, testIdx
        else:
            testIdx = range(tmpD.GetNPts())

        message('screening %d examples' % (len(testIdx)))
        nTrueActives, screenRes = ScreenModel(model, descs, tmpD,
                                              picking=details.activeTgt,
                                              indices=testIdx,
                                              errorEstimate=details.errorEstimate)
        message('accumulating')
        runningCounts = AccumulateCounts(screenRes,
                                         sortIt=details.sortIt,
                                         thresh=details.threshold)
        if details.pickBase:
            pickFile = open('%s.%d.picks' % (details.pickBase, whichModel + 1), 'w+')
        else:
            pickFile = None

        try:
            for i, entry in enumerate(runningCounts):
                selPts[i] = entry[0]
                final[i][0] += entry[1]
                final[i][1] += entry[2]
                pickVects.setdefault(i, []).append(entry[1])
                counts[i] += 1
                if pickFile:
                    pickFile.write('%s\n' % (entry[0]))
                # floor division keeps the Python 2 integer semantics
                if entry[1] >= nTrueActives // 2 and entry[2] < halfwayPts[whichModel]:
                    halfwayPts[whichModel] = entry[2]
        finally:
            if pickFile:
                pickFile.close()
        message('Halfway point: %d\n' % halfwayPts[whichModel])

    if details.plotFile:
        MakePlot(details, final, counts, pickVects, nModels, nTrueActs=nTrueActives)
    else:
        if nModels > 1:
            print('#Index\tAvg_num_correct\tConf90Pct\tAvg_num_picked\tNum_picks\tlast_selection')
        else:
            print('#Index\tAvg_num_correct\tAvg_num_picked\tNum_picks\tlast_selection')

        i = 0
        while i < nPts and counts[i] != 0:
            if nModels > 1:
                mean, sd = Stats.MeanAndDev(pickVects[i])
                confInterval = Stats.GetConfidenceInterval(sd, len(pickVects[i]), level=90)
                print('%d\t%f\t%f\t%f\t%d\t%s' % (i + 1, final[i][0] / counts[i],
                                                   confInterval,
                                                   final[i][1] / counts[i],
                                                   counts[i], str(selPts[i])))
            else:
                print('%d\t%f\t%f\t%d\t%s' % (i + 1, final[i][0] / counts[i],
                                               final[i][1] / counts[i],
                                               counts[i], str(selPts[i])))
            i += 1

    mean, sd = Stats.MeanAndDev(halfwayPts)
    print('Halfway point: %.2f(%.2f)' % (mean, sd))