1
2
3
4
5
6
7
8 """Command line tool to construct an enrichment plot from saved composite models
9
10 Usage: EnrichPlot [optional args] -d dbname -t tablename <models>
11
12 Required Arguments:
13 -d "dbName": the name of the database for screening
14
15 -t "tablename": provide the name of the table with the data to be screened
16
17 <models>: file name(s) of pickled composite model(s).
18 If the -p argument is also provided (see below), this argument is ignored.
19
20 Optional Arguments:
21 - -a "list": the list of result codes to be considered active. This will be
22 eval'ed, so be sure that it evaluates as a list or sequence of
23 integers. For example, -a "[1,2]" will consider activity values 1 and 2
24 to be active
25
26 - --enrich "list": identical to the -a argument above.
27
28 - --thresh: sets a threshold for the plot. If the confidence falls below
29 this value, picking will be terminated
30
31 - -H: screen only the hold out set (works only if a version of
32 BuildComposite more recent than 1.2.2 was used).
33
34 - -T: screen only the training set (works only if a version of
35 BuildComposite more recent than 1.2.2 was used).
36
37 - -S: shuffle activity values before screening
38
39 - -R: randomize activity values before screening
40
41 - -F *filter frac*: filters the data before training to change the
42 distribution of activity values in the training set. *filter frac*
43 is the fraction of the training set that should have the target value.
44 **See note in BuildComposite help about data filtering**
45
46 - -v *filter value*: filters the data before training to change the
47 distribution of activity values in the training set. *filter value*
48 is the target value to use in filtering.
49 **See note in BuildComposite help about data filtering**
50
51 - -p "tableName": provides the name of a db table containing the
52 models to be screened. If you use this argument, you should also
53 use the -N argument (below) to specify a note value.
54
55 - -N "note": provides a note to be used to pull models from a db table.
56
57 - --plotFile "filename": writes the data to an output text file (filename.dat)
58 and creates a gnuplot input file (filename.gnu) to plot it
59
60 - --showPlot: causes the gnuplot plot constructed using --plotFile to be
61 displayed in gnuplot.
62
63 """
64 import RDConfig
65 from Numeric import *
66 import cPickle,copy
67
68 from ML.Data import DataUtils,SplitData,Stats
69 from Dbase.DbConnection import DbConnect
70 import DataStructs
71 from ML import CompositeRun
72 import sys,os,types
73
74 __VERSION_STRING="2.3.3"
def message(msg, noRet=0, dest=sys.stderr):
    """ emits a message to a file-like object (_sys.stderr_ by default)

    override this in modules which import this one to redirect output

    **Arguments**

      - msg: the object to be displayed; it is formatted with '%s'

      - noRet: (Optional) if true, the message is followed by a single
        space instead of a newline

      - dest: (Optional) the file-like object the message is written to

    """
    # Wrapping msg in a 1-tuple keeps the '%' operator from unpacking a
    # message that happens to be a tuple itself.
    if noRet:
        dest.write('%s ' % (msg,))
    else:
        dest.write('%s\n' % (msg,))
def error(msg, dest=sys.stderr):
    """ emits an error message to a file-like object (_sys.stderr_ by default)

    override this in modules which import this one to redirect output

    **Arguments**

      - msg: the object to be displayed; it is formatted with '%s' and
        prefixed with 'ERROR: '

      - dest: (Optional) the file-like object the message is written to

    """
    # Honour the dest argument (it used to be accepted but ignored) and wrap
    # msg in a 1-tuple so that tuple messages format correctly.
    dest.write('ERROR: %s\n' % (msg,))
98
def ScreenModel(mdl, descs, data, picking=(1,), indices=None, errorEstimate=0):
    """ collects the results of screening an individual composite model that match
      a particular value

    **Arguments**

      - mdl: the composite model

      - descs: a list of descriptor names corresponding to the data set

      - data: the data set, a list of points to be screened.  For each point
        the first element is used as its id and the last element as its
        true result.

      - picking: (Optional) a sequence of values that are to be collected.
        For example, if you want an enrichment plot for picking the values
        1 and 2, you'd have picking=[1,2].

      - indices: (Optional) the indices (into _data_) of the points to be
        screened.  If this is empty or not provided, every point is screened.

      - errorEstimate: (Optional) if true, each point is classified using
        only those sub-models that either have no _trainIndices attribute
        or whose _trainIndices does not contain the index of the point.

    **Returns**

      a 2-tuple:

        1) the number of screened points whose true result is in _picking_

        2) a list of 4-tuples, one for each screened point whose predicted
           result is in _picking_, containing:

           - the id of the point

           - the true result (from the data set)

           - the predicted result

           - the confidence value for the prediction

    """
    mdl.SetInputOrder(descs)
    needsQuant = bool(mdl.GetQuantBounds())

    if not indices:
        indices = range(len(data))

    # The training indices of the sub-models do not depend on the point being
    # screened, so collect them once (as sets, for fast membership tests)
    # instead of re-scanning them for every point.
    trainSets = None
    if errorEstimate:
        trainSets = []
        for j in range(len(mdl)):
            subModel = mdl.GetModel(j)
            if hasattr(subModel, '_trainIndices'):
                trainSets.append(set(subModel._trainIndices))
            else:
                trainSets.append(None)

    nTrueActives = 0
    res = []
    for i in indices:
        if trainSets is not None:
            use = [j for j, seen in enumerate(trainSets)
                   if seen is None or i not in seen]
        else:
            use = None
        pt = data[i]
        pred, conf = mdl.ClassifyExample(pt, onlyModels=use)
        if needsQuant:
            # quantize a copy so that the data set itself is left untouched
            pt = mdl.QuantizeActivity(pt[:])
        trueRes = pt[-1]
        if trueRes in picking:
            nTrueActives += 1
        if pred in picking:
            res.append((pt[0], trueRes, pred, conf))
    return nTrueActives, res
157
# NOTE(review): the `def` line of this function was absent from the reviewed
# copy of the file.  The signature was rebuilt from the docstring's argument
# list and from the keyword call in the __main__ section; please confirm the
# default values.
def AccumulateCounts(predictions, thresh=0, sortIt=1):
    """  Accumulates the data for the enrichment plot for a single model

    **Arguments**

      - predictions: a list of 4-tuples (id, true result, predicted result,
        confidence), as built by _ScreenModel_

      - thresh: a threshold for the confidence level.  Only predictions whose
        confidence is strictly greater than this threshold are considered

      - sortIt: toggles sorting on confidence levels.  The sort is done in
        place (highest confidence first), so the caller's list is reordered


    **Returns**

      - a list of 3-tuples, one for each prediction that passed the threshold:

        - the id of the point picked here

        - num correct picks (predicted == true result) made so far

        - number of picks made so far

    """
    if sortIt:
        # descending by confidence; the sort is stable, so predictions with
        # equal confidence keep their relative order
        predictions.sort(key=lambda entry: entry[3], reverse=True)
    res = []
    nCorrect = 0
    nPts = 0
    for ptId, real, pred, conf in predictions:
        if conf > thresh:
            if pred == real:
                nCorrect += 1
            nPts += 1
            res.append((ptId, nCorrect, nPts))

    return res
196
def MakePlot(details, final, counts, pickVects, nModels, nTrueActs=-1):
    """ writes the enrichment data and a gnuplot script that plots it

    Nothing is done unless _details_ has a non-empty _plotFile_ attribute.
    Otherwise two files are written:

      - <plotFile>.dat: one row per pick index, up to (but not including)
        the first index whose count is zero

      - <plotFile>.gnu: gnuplot commands that plot the data file

    **Arguments**

      - details: an object carrying the _plotFile_ and (optionally)
        _showPlot_ attributes

      - final: a sequence of pairs; for row i, final[i][0] and final[i][1]
        are divided by counts[i] before being written

      - counts: a sequence giving, for each row, the divisor used above

      - pickVects: a mapping from row index to a list of values; it is only
        used when nModels > 1, to compute a 90% confidence interval

      - nModels: the number of models that contributed to the data

      - nTrueActs: (Optional) if positive, it is used as the upper limit of
        the y range of the plot

    """
    if not hasattr(details, 'plotFile') or not details.plotFile:
        return

    dataFileName = '%s.dat' % (details.plotFile,)
    outF = open(dataFileName, 'w+')
    try:
        i = 0
        while i < len(final) and counts[i] != 0:
            if nModels > 1:
                mean, sd = Stats.MeanAndDev(pickVects[i])
                confInterval = Stats.GetConfidenceInterval(sd, len(pickVects[i]),
                                                           level=90)
                outF.write('%d %f %f %d %f\n' % (i + 1, final[i][0] / counts[i],
                                                 final[i][1] / counts[i],
                                                 counts[i], confInterval))
            else:
                outF.write('%d %f %f %d\n' % (i + 1, final[i][0] / counts[i],
                                              final[i][1] / counts[i],
                                              counts[i]))
            i += 1
    finally:
        # close the data file even if formatting a row fails
        outF.close()

    plotFileName = '%s.gnu' % (details.plotFile,)
    gnuHdr = """# Generated by EnrichPlot.py version: %s
set size square 0.7
set xr [0:]
set data styl points
set ylab 'Num Correct Picks'
set xlab 'Num Picks'
set grid
set nokey
set term postscript enh color solid "Helvetica" 16
set term win
""" % (__VERSION_STRING)
    gnuF = open(plotFileName, 'w+')
    try:
        gnuF.write(gnuHdr + '\n')
        if nTrueActs > 0:
            gnuF.write('set yr [0:%d]\n' % nTrueActs)
        gnuF.write('plot x with lines\n')
        if nModels > 1:
            # draw an error bar on roughly 20 of the i rows written above; the
            # step must be at least 1 (it was 0 for fewer than 20 rows)
            everyGap = max(1, i // 20)
            gnuF.write('replot "%s" using 1:2 with lines,' % (dataFileName,) + ' ')
            gnuF.write('"%s" every %d using 1:2:5 with yerrorbars\n' % (dataFileName,
                                                                       everyGap))
        else:
            gnuF.write('replot "%s" with points\n' % (dataFileName,))
    finally:
        gnuF.close()

    if hasattr(details, 'showPlot') and details.showPlot:
        # Displaying the plot is best effort: a missing Gnuplot module or a
        # gnuplot failure is reported but does not abort the run.
        try:
            from Gnuplot import Gnuplot
            p = Gnuplot()

            p('load "%s"' % (plotFileName,))
            raw_input('press return to continue...\n')
        except Exception:
            import traceback
            traceback.print_exc()
252
253
254
255
def Usage():
    """ displays a usage message and exits

    The module docstring is written to _sys.stderr_ and the interpreter is
    then terminated with exit status -1.

    """
    # __doc__ is None when the module has no docstring (e.g. under -OO)
    sys.stderr.write(__doc__ or '')
    sys.exit(-1)
260
if __name__ == '__main__':
    # Script entry point: parse the command line, load the data set and the
    # composite models, screen every model and then either write gnuplot
    # files (--plotFile) or print the enrichment table to stdout.
    import getopt
    import traceback

    # NOTE(review): the module help documents a -R flag, but 'R' is not in the
    # option string below, so getopt rejects it.  -c is accepted but has no
    # handler.
    try:
        args, extras = getopt.getopt(sys.argv[1:], 'd:t:a:N:p:cSTHF:v:',
                                     ('thresh=', 'plotFile=', 'showPlot',
                                      'pickleCol=', 'OOB', 'noSort', 'pickBase=',
                                      'doROC', 'rocThresh=', 'enrich='))
    except getopt.GetoptError:
        traceback.print_exc()
        Usage()

    details = CompositeRun.CompositeRun()
    CompositeRun.SetDefaults(details)

    details.activeTgt = [1]
    details.doTraining = 0
    details.doHoldout = 0
    details.dbTableName = ''
    details.plotFile = ''
    details.showPlot = 0
    details.pickleCol = -1
    details.errorEstimate = 0
    details.sortIt = 1
    details.pickBase = ''
    details.doROC = 0
    details.rocThresh = -1
    for arg, val in args:
        if arg == '-d':
            details.dbName = val
        elif arg == '-t':
            details.dbTableName = val
        elif arg == '-a' or arg == '--enrich':
            # SECURITY: eval() runs arbitrary code taken from the command
            # line; only use this option with trusted input.
            details.activeTgt = eval(val)
            if not isinstance(details.activeTgt, (tuple, list)):
                details.activeTgt = (details.activeTgt,)
        elif arg == '--thresh':
            details.threshold = float(val)
        elif arg == '-N':
            details.note = val
        elif arg == '-p':
            details.persistTblName = val
        elif arg == '-S':
            details.shuffleActivities = 1
        elif arg == '-H':
            details.doTraining = 0
            details.doHoldout = 1
        elif arg == '-T':
            details.doTraining = 1
            details.doHoldout = 0
        elif arg == '-F':
            details.filterFrac = float(val)
        elif arg == '-v':
            details.filterVal = float(val)
        elif arg == '--plotFile':
            details.plotFile = val
        elif arg == '--showPlot':
            details.showPlot = 1
        elif arg == '--pickleCol':
            # the column number is 1-based on the command line
            details.pickleCol = int(val) - 1
        elif arg == '--OOB':
            details.errorEstimate = 1
        elif arg == '--noSort':
            details.sortIt = 0
        elif arg == '--doROC':
            details.doROC = 1
        elif arg == '--rocThresh':
            details.rocThresh = int(val)
        elif arg == '--pickBase':
            details.pickBase = val

    if not details.dbName or not details.dbTableName:
        # Usage() terminates the interpreter, so the hint has to be emitted
        # first (it used to come after the call and was never shown).
        sys.stdout.write('*******Please provide both the -d and -t arguments\n')
        Usage()

    message('Building Data set\n')
    dataSet = DataUtils.DBToData(details.dbName, details.dbTableName,
                                 user=RDConfig.defaultDBUser,
                                 password=RDConfig.defaultDBPassword,
                                 pickleCol=details.pickleCol,
                                 pickleClass=DataStructs.ExplicitBitVect)

    descs = dataSet.GetVarNames()
    nPts = dataSet.GetNPts()
    message('npts: %d\n' % (nPts))
    # per pick index: final[i] accumulates (num correct, num picked) over the
    # models, counts[i] is the number of models that reached that index and
    # selPts[i] is the id picked there by the last such model
    final = zeros((nPts, 2), Float)
    counts = zeros(nPts, Int)
    selPts = [None] * nPts

    models = []
    if details.persistTblName:
        conn = DbConnect(details.dbName, details.persistTblName)
        message('-> Retrieving models from database')
        curs = conn.GetCursor()
        # SECURITY: the table name and note are pasted into the SQL text
        # unescaped; only use trusted values for -p and -N.
        curs.execute("select model from %s where note='%s'" % (details.persistTblName,
                                                               details.note))
        message('-> Reconstructing models')
        while 1:
            # a failing fetch is treated as the end of the result set
            try:
                blob = curs.fetchone()
            except Exception:
                blob = None
            if not blob:
                break
            message(' Building model %d' % len(models))
            blob = blob[0]
            # SECURITY: unpickling executes code embedded in the pickle; only
            # load models from a trusted database.
            try:
                models.append(cPickle.loads(str(blob)))
            except Exception:
                traceback.print_exc()
                sys.stdout.write('Model failed\n')
            else:
                message(' <-Done')
        curs = None
    else:
        for modelName in extras:
            # SECURITY: unpickling executes code embedded in the pickle; only
            # load model files from a trusted source.
            try:
                inF = open(modelName, 'rb')
                try:
                    model = cPickle.load(inF)
                finally:
                    inF.close()
            except Exception:
                sys.stdout.write('problems with model %s:\n' % modelName)
                traceback.print_exc()
            else:
                models.append(model)

    nModels = len(models)
    pickVects = {}
    halfwayPts = [1e8] * len(models)
    # stays at MakePlot's "unknown" default when no model could be loaded
    nTrueActives = -1
    for whichModel, model in enumerate(models):
        tmpD = dataSet
        try:
            seed = model._randomSeed
        except AttributeError:
            pass
        else:
            DataUtils.InitRandomNumbers(seed)
        if details.shuffleActivities:
            DataUtils.RandomizeActivities(tmpD,
                                          shuffle=1)
        if hasattr(model, '_splitFrac') and (details.doHoldout or details.doTraining):
            trainIdx, testIdx = SplitData.SplitIndices(tmpD.GetNPts(), model._splitFrac,
                                                       silent=1)
            if details.filterFrac != 0.0:
                trainFilt, temp = DataUtils.FilterData(tmpD, details.filterVal,
                                                       details.filterFrac, -1,
                                                       indicesToUse=trainIdx,
                                                       indicesOnly=1)
                testIdx += temp
                trainIdx = trainFilt
            if details.doTraining:
                # -T: screen the training points instead of the hold-out points
                testIdx, trainIdx = trainIdx, testIdx
        else:
            testIdx = range(tmpD.GetNPts())

        message('screening %d examples' % (len(testIdx)))
        nTrueActives, screenRes = ScreenModel(model, descs, tmpD,
                                              picking=details.activeTgt,
                                              indices=testIdx,
                                              errorEstimate=details.errorEstimate)
        message('accumulating')
        runningCounts = AccumulateCounts(screenRes,
                                         sortIt=details.sortIt,
                                         thresh=details.threshold)
        if details.pickBase:
            pickFile = open('%s.%d.picks' % (details.pickBase, whichModel + 1), 'w+')
        else:
            pickFile = None

        try:
            for i, entry in enumerate(runningCounts):
                selPts[i] = entry[0]
                final[i][0] += entry[1]
                final[i][1] += entry[2]
                v = pickVects.get(i, [])
                v.append(entry[1])
                pickVects[i] = v
                counts[i] += 1
                if pickFile:
                    pickFile.write('%s\n' % (entry[0]))
                # record the first pick count at which half of the true
                # actives (integer division) have been found
                if entry[1] >= nTrueActives // 2 and entry[2] < halfwayPts[whichModel]:
                    halfwayPts[whichModel] = entry[2]
                    message('Halfway point: %d\n' % halfwayPts[whichModel])
        finally:
            if pickFile:
                pickFile.close()

    if details.plotFile:
        MakePlot(details, final, counts, pickVects, nModels, nTrueActs=nTrueActives)
    else:
        if nModels > 1:
            sys.stdout.write('#Index\tAvg_num_correct\tConf90Pct\tAvg_num_picked\tNum_picks\tlast_selection\n')
        else:
            sys.stdout.write('#Index\tAvg_num_correct\tAvg_num_picked\tNum_picks\tlast_selection\n')

        i = 0
        while i < nPts and counts[i] != 0:
            if nModels > 1:
                mean, sd = Stats.MeanAndDev(pickVects[i])
                confInterval = Stats.GetConfidenceInterval(sd, len(pickVects[i]), level=90)
                sys.stdout.write('%d\t%f\t%f\t%f\t%d\t%s\n' % (i + 1, final[i][0] / counts[i],
                                                               confInterval,
                                                               final[i][1] / counts[i],
                                                               counts[i], str(selPts[i])))
            else:
                sys.stdout.write('%d\t%f\t%f\t%d\t%s\n' % (i + 1, final[i][0] / counts[i],
                                                           final[i][1] / counts[i],
                                                           counts[i], str(selPts[i])))
            i += 1

    mean, sd = Stats.MeanAndDev(halfwayPts)
    sys.stdout.write('Halfway point: %.2f(%.2f)\n' % (mean, sd))
471