|
Package rdkit ::
Package ML ::
Module AnalyzeComposite
|
|
1
2
3
4
5
6
7 """ command line utility to report on the contributions of descriptors to
8 tree-based composite models
9
10 Usage: AnalyzeComposite [optional args] <models>
11
12 <models>: file name(s) of pickled composite model(s)
13 (this is the name of the db table if using a database)
14
15 Optional Arguments:
16
17 -n number: the number of levels of each model to consider
18
19 -d dbname: the database from which to read the models
20
21 -N Note: the note string to search for to pull models from the database
22
23 -X: Send the results to Excel. Note: will alter the current
24 worksheet (by adding data to the end) and only works on
25 systems with Excel installed. It *is* safe to call this
26 multiple times with a single worksheet.
27
28 -v: be verbose whilst screening
29 """
30
31 import numpy
32 import sys,cPickle
33 from rdkit.ML.DecTree import TreeUtils,Tree
34 from rdkit.ML.Data import Stats
35 from rdkit.Dbase.DbConnection import DbConnect
36 from rdkit.ML import ScreenComposite
37 try:
38 from rdkit.Excel.ExcelWrapper import ExcelWrapper as Excel
39 except ImportError:
40 Excel = None
41
42 __VERSION_STRING="2.2.0"
43
def ProcessIt(composites,nToConsider=3,verbose=0,reportToExcel=0):
    """ Tabulate how often each descriptor shows up at each of the top tree levels.

    For every composite, each model that is a Tree.TreeNode is scanned with
    TreeUtils.CollectLabelLevels (limited to nToConsider levels); every
    descriptor id it reports gets 1/len(composite) added at the reported
    level.  The per-composite tables are then averaged over all composites.

    Arguments:
      composites: sequence of composite models (each one needs
        GetDescriptorNames(), len() and GetModel(i)).
      nToConsider: number of tree levels to tabulate.
      verbose: > 0 also prints a table for each composite; >= 0 prints the
        averaged table; a negative value prints nothing.
      reportToExcel: if true and the Excel wrapper could be imported, the
        averaged table is also written after the last used row of the
        current worksheet.

    Returns:
      A list with one row per descriptor:
      [name, value at level 0, ..., value at level nToConsider-1, row total].
      The list is empty when composites is empty or when the first
      composite reports two or fewer descriptor names.
    """
    # Nothing to analyze (previously an IndexError on composites[0]).
    if not composites:
        return []
    # Only the first composite's descriptor names are inspected here.
    if len(composites[0].GetDescriptorNames()) <= 2:
        return []

    nComposites = len(composites)
    globalRes = {}
    descNames = {}
    for nDone, composite in enumerate(composites, 1):
        if verbose > 0:
            print('#------------------------------------')
            print('Doing:  %d' % nDone)
        nModels = len(composite)
        res = {}
        for i in range(nModels):
            model = composite.GetModel(i)
            if not isinstance(model, Tree.TreeNode):
                continue
            levels = TreeUtils.CollectLabelLevels(model, {}, 0, nToConsider)
            TreeUtils.CollectDescriptorNames(model, descNames, 0, nToConsider)
            for descId, level in levels.items():
                if descId not in res:
                    # builtin float: the numpy.float alias no longer exists
                    # in current NumPy releases
                    res[descId] = numpy.zeros(nToConsider, float)
                res[descId][level] += 1. / nModels

        # fold this composite's table into the running average
        for descId, fracs in res.items():
            if descId not in globalRes:
                globalRes[descId] = numpy.zeros(nToConsider, float)
            globalRes[descId] += fracs / nComposites

        if verbose > 0:
            for descId, fracs in res.items():
                strRes = ', '.join(['%4.2f' % x for x in fracs])
                print('%s,%s,%5.4f' % (descNames[descId], strRes, sum(fracs)))
            print('')

    if verbose >= 0:
        print('# Average Descriptor Positions')

    xl = None
    if reportToExcel and Excel is not None:
        xl = Excel()
        # start below whatever is already in column 1 and record the
        # command line that produced the rows that follow
        xlRow = xl.FindLastRow(1, 1) + 1
        xl[xlRow, 1] = ' '.join(sys.argv)
        xlRow += 1

    retVal = []
    for descId, fracs in globalRes.items():
        name = descNames[descId]
        total = sum(fracs)
        if verbose >= 0:
            strRes = ', '.join(['%4.2f' % x for x in fracs])
            print('%s,%s,%5.4f' % (name, strRes, total))
        if xl:
            # layout: name | one column per level | row total
            xl[xlRow, 1] = name
            for xlCol, v in enumerate(fracs, 2):
                xl[xlRow, xlCol] = v
            xl[xlRow, len(fracs) + 2] = total
            xlRow += 1
        retVal.append([name] + list(fracs) + [total])
    if verbose >= 0:
        print('')
    return retVal
118
119
# NOTE(review): the 'def' line for this body is not present in this listing.
# The script section at the bottom calls it as ErrorStats(conn,where,enrich=enrich),
# and the body reads conn, where and enrich.
#
# Purpose: fetch the error / confidence / result-matrix columns listed in
# 'fields' through conn.GetData and reduce them to a dict of summary
# statistics.  Returns None if the query raises or if no rows come back.
#
# Row layout follows the order of 'fields':
#   [0] overall_error  [1] holdout_error
#   [2] overall_result_matrix  [3] holdout_result_matrix
#   [4] overall_correct_conf   [5] overall_incorrect_conf
#   [6] holdout_correct_conf   [7] holdout_incorrect_conf
121 fields = 'overall_error,holdout_error,overall_result_matrix,holdout_result_matrix,overall_correct_conf,overall_incorrect_conf,holdout_correct_conf,holdout_incorrect_conf'
122 try:
123 data = conn.GetData(fields=fields,where=where)
# NOTE(review): bare except -- any failure (including KeyboardInterrupt) is
# printed and turned into a None return.
124 except:
125 import traceback
126 traceback.print_exc()
127 return None
128 nPts = len(data)
129 if not nPts:
130 sys.stderr.write('no runs found\n')
131 return None
# NOTE(review): the numpy.float alias was removed in NumPy 1.24; builtin
# float is the drop-in replacement.
132 overall = numpy.zeros(nPts,numpy.float)
133 overallEnrich = numpy.zeros(nPts,numpy.float)
134 oCorConf = 0.0
135 oInCorConf = 0.0
136 holdout = numpy.zeros(nPts,numpy.float)
137 holdoutEnrich = numpy.zeros(nPts,numpy.float)
138 hCorConf = 0.0
139 hInCorConf = 0.0
140 overallMatrix = None
141 holdoutMatrix = None
# Accumulate per-row errors, confidences, enrichments and result matrices.
142 for i in range(nPts):
143 if data[i][0] is not None:
144 overall[i] = data[i][0]
145 oCorConf += data[i][4]
146 oInCorConf += data[i][5]
147 if data[i][1] is not None:
148 holdout[i] = data[i][1]
149 haveHoldout=1
150 else:
151 haveHoldout=0
# NOTE(review): eval() of text read from the database -- only safe if the db
# contents are trusted.  data[i][2] is also eval'd here before the
# 'is not None' test on it further down.
152 tmpOverall = 1.*eval(data[i][2])
153 if enrich >=0:
154 overallEnrich[i] = ScreenComposite.CalcEnrichment(tmpOverall,tgt=enrich)
155 if haveHoldout:
156 tmpHoldout = 1.*eval(data[i][3])
157 if enrich >=0:
158 holdoutEnrich[i] = ScreenComposite.CalcEnrichment(tmpHoldout,tgt=enrich)
# The first matrix seen becomes the accumulator; later ones are added in place.
159 if overallMatrix is None:
160 if data[i][2] is not None:
161 overallMatrix = tmpOverall
162 if haveHoldout and data[i][3] is not None:
163 holdoutMatrix = tmpHoldout
164 else:
165 overallMatrix += tmpOverall
166 if haveHoldout:
167 holdoutMatrix += tmpHoldout
168 if haveHoldout:
169 hCorConf += data[i][6]
170 hInCorConf += data[i][7]
171
# Overall statistics.  Averages divide by nPts; entries whose error column
# was NULL were left at their initial 0.
172 avgOverall = sum(overall)/nPts
173 oCorConf /= nPts
174 oInCorConf /= nPts
175 overallMatrix /= nPts
# the smallest error is the best run; its index is kept before 'overall' is
# shifted by the mean below
176 oSort = numpy.argsort(overall)
177 oMin = overall[oSort[0]]
178 overall -= avgOverall
# NOTE(review): 'sqrt' is not among the names imported at the top of this
# listing (only the numpy module itself is) -- presumably numpy.sqrt is
# intended; confirm.  Also, nPts == 1 makes the (nPts-1) divisor zero.
179 devOverall = sqrt(sum(overall**2)/(nPts-1))
# errors and confidences are reported scaled by 100
180 res = {}
181 res['oAvg'] = 100*avgOverall
182 res['oDev'] = 100*devOverall
183 res['oCorrectConf'] = 100*oCorConf
184 res['oIncorrectConf'] = 100*oInCorConf
185 res['oResultMat']=overallMatrix
186 res['oBestIdx']=oSort[0]
187 res['oBestErr']=100*oMin
188
189 if enrich>=0:
190 mean,dev = Stats.MeanAndDev(overallEnrich)
191 res['oAvgEnrich'] = mean
192 res['oDevEnrich'] = dev
193
# Holdout statistics mirror the overall ones.  haveHoldout here holds
# whatever value was assigned for the last row processed in the loop above.
194 if haveHoldout:
195 avgHoldout = sum(holdout)/nPts
196 hCorConf /= nPts
197 hInCorConf /= nPts
198 holdoutMatrix /= nPts
199 hSort = numpy.argsort(holdout)
200 hMin = holdout[hSort[0]]
201 holdout -= avgHoldout
# NOTE(review): same unqualified 'sqrt' and (nPts-1) divisor as above.
202 devHoldout = sqrt(sum(holdout**2)/(nPts-1))
203 res['hAvg'] = 100*avgHoldout
204 res['hDev'] = 100*devHoldout
205 res['hCorrectConf'] = 100*hCorConf
206 res['hIncorrectConf'] = 100*hInCorConf
207 res['hResultMat']=holdoutMatrix
208 res['hBestIdx']=hSort[0]
209 res['hBestErr']=100*hMin
210 if enrich>=0:
211 mean,dev = Stats.MeanAndDev(holdoutEnrich)
212 res['hAvgEnrich'] = mean
213 res['hDevEnrich'] = dev
214 return res
215
# NOTE(review): the 'def' line for this body is not present in this listing.
# The script section at the bottom calls it as ShowStats(res); the body reads
# statD and enrich, so enrich presumably has a default value -- confirm.
#
# Purpose: print the error statistics and result matrices held in statD
# (keys 'oAvg', 'oDev', 'oResultMat', ... and the optional 'h*' holdout keys).
#
# Work on a (shallow) copy so the caller's dict keeps its original index
# values; the best-run indices are bumped by one before printing,
# presumably to show them 1-based.
217 statD = statD.copy()
218 statD['oBestIdx'] = statD['oBestIdx']+1
219 txt="""
220 # Error Statistics:
221 \tOverall: %(oAvg)6.3f%% (%(oDev)6.3f) %(oCorrectConf)4.1f/%(oIncorrectConf)4.1f
222 \t\tBest: %(oBestIdx)d %(oBestErr)6.3f%%"""%(statD)
# Holdout figures are only appended when present.
# NOTE(review): dict.has_key() exists only in Python 2.
223 if statD.has_key('hAvg'):
224 statD['hBestIdx'] = statD['hBestIdx']+1
225 txt += """
226 \tHoldout: %(hAvg)6.3f%% (%(hDev)6.3f) %(hCorrectConf)4.1f/%(hIncorrectConf)4.1f
227 \t\tBest: %(hBestIdx)d %(hBestErr)6.3f%%
228 """%(statD)
229 print txt
230 print
231 print '# Results matrices:'
232 print '\tOverall:'
# NOTE(review): 'transpose' is not among the names imported at the top of
# this listing -- presumably numpy.transpose is intended; confirm.
233 tmp = transpose(statD['oResultMat'])
# NOTE(review): unless 'sum' is rebound somewhere outside this listing, these
# are the builtin sum: sum(tmp) adds the rows together (giving per-column
# totals) and sum(tmp,1) uses 1 as the *start value*, not as an axis.  If
# per-row totals are intended this likely needs numpy.sum(tmp,1); confirm.
234 colCounts = sum(tmp)
235 rowCounts = sum(tmp,1)
# One printed line per matrix row, followed by the diagonal entry as a
# percentage of that row's count.
236 for i in range(len(tmp)):
# a zero count is replaced by 1 so the division below cannot fail
237 if rowCounts[i]==0: rowCounts[i]=1
238 row = tmp[i]
239 print '\t\t',
240 for j in range(len(row)):
241 print '% 6.2f'%row[j],
242 print '\t| % 4.2f'%(100.*tmp[i,i]/rowCounts[i])
# separator line, then the diagonal entries as a percentage of each column count
243 print '\t\t',
244 for i in range(len(tmp)):
245 print '------',
246 print
247 print '\t\t',
248 for i in range(len(tmp)):
249 if colCounts[i]==0: colCounts[i]=1
250 print '% 6.2f'%(100.*tmp[i,i]/colCounts[i]),
251 print
252 if enrich>-1 and statD.has_key('oAvgEnrich'):
253 print '\t\tEnrich(%d): %.3f (%.3f)'%(enrich,statD['oAvgEnrich'],statD['oDevEnrich'])
254
255
# Same report for the holdout matrix, when there is one.  The review notes
# on transpose() and sum() above apply here as well.
256 if statD.has_key('hResultMat'):
257 print '\tHoldout:'
258 tmp = transpose(statD['hResultMat'])
259 colCounts = sum(tmp)
260 rowCounts = sum(tmp,1)
261 for i in range(len(tmp)):
262 if rowCounts[i]==0: rowCounts[i]=1
263 row = tmp[i]
264 print '\t\t',
265 for j in range(len(row)):
266 print '% 6.2f'%row[j],
267 print '\t| % 4.2f'%(100.*tmp[i,i]/rowCounts[i])
268 print '\t\t',
269 for i in range(len(tmp)):
270 print '------',
271 print
272 print '\t\t',
273 for i in range(len(tmp)):
274 if colCounts[i]==0: colCounts[i]=1
275 print '% 6.2f'%(100.*tmp[i,i]/colCounts[i]),
276 print
277 if enrich>-1 and statD.has_key('hAvgEnrich'):
278 print '\t\tEnrich(%d): %.3f (%.3f)'%(enrich,statD['hAvgEnrich'],statD['hDevEnrich'])
279
280
281 return
282
283
287
if __name__ == "__main__":
    # Command-line driver: parse the options, load the composite models
    # (from pickle files, or from a database table when -d is given), report
    # the descriptor levels and, for database runs, the stored error stats.
    import getopt
    try:
        args, extras = getopt.getopt(sys.argv[1:], 'n:d:N:vX',
                                     ('skip', 'enrich='))
    except getopt.GetoptError:
        # Only option-parsing problems land here; anything else propagates.
        Usage()
        # Usage() is defined outside this section; exit explicitly in case
        # it returns, since 'args' would otherwise be undefined below.
        sys.exit(-1)

    count = 3
    db = None
    note = ''
    verbose = 0
    skip = 0
    enrich = 1
    reportToExcel = 0
    for arg, val in args:
        if arg == '-n':
            count = int(val) + 1
        elif arg == '-d':
            db = val
        elif arg == '-N':
            note = val
        elif arg == '-v':
            verbose = 1
        elif arg == '-X':
            if Excel is not None:
                reportToExcel = 1
            else:
                ScreenComposite.message('NOTE: Excel support not enabled, -X option ignored.')
        elif arg == '--skip':
            skip = 1
        elif arg == '--enrich':
            enrich = int(val)

    # NOTE: unpickling can execute arbitrary code -- only load model files
    # and database rows that come from a trusted source.
    composites = []
    if db is None:
        for arg in extras:
            # 'with' closes the file even if unpickling fails
            with open(arg, 'rb') as inF:
                composites.append(cPickle.load(inF))
    else:
        if not extras:
            # the table name is required in database mode
            Usage()
            sys.exit(-1)
        tbl = extras[0]
        conn = DbConnect(db, tbl)
        if note:
            # double any single quotes so the note cannot terminate the
            # SQL string literal it is placed in
            where = "where note='%s'" % (note.replace("'", "''"))
        else:
            where = ''
        if not skip:
            pkls = conn.GetData(fields='model', where=where)
            composites = [cPickle.loads(str(pkl[0])) for pkl in pkls]

    if composites:
        ProcessIt(composites, count, verbose=verbose, reportToExcel=reportToExcel)
    elif not skip:
        print('ERROR: no composite models found')
        sys.exit(-1)

    if db:
        res = ErrorStats(conn, where, enrich=enrich)
        if res:
            # NOTE(review): the --enrich value is not forwarded to ShowStats,
            # so it falls back on its own default; confirm whether that is
            # intended.
            ShowStats(res)
353