1
2
3
4
5
6
7
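""" Functionality for building QuantTree decision trees

  Variables with a positive bound count are split at quantization bounds
  chosen to maximize information gain; variables with a bound count of zero
  are split on each of their possible values.

"""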
11
12 import numpy
13 import random
14 from rdkit.ML.DecTree import QuantTree, ID3
15 from rdkit.ML.InfoTheory import entropy
16 from rdkit.ML.Data import Quantize
17
def FindBest(resCodes,examples,nBoundsPerVar,nPossibleRes,
             nPossibleVals,attrs,exIndices=None,**kwargs):
    """ Picks the variable (and its quantization bounds) with the largest
    information gain

    **Arguments**

      - resCodes: the result codes of the examples being considered; these
        are handed to the quantizer together with the values pulled out via
        _exIndices_, so they are expected to be in the same order

      - examples: the full list of examples, each indexable by variable id

      - nBoundsPerVar: for each variable, the number of quantization bounds
        to look for.  >0: the variable is quantized, 0: the variable is
        used as-is (discrete), <0: the variable is never selected

      - nPossibleRes: the number of possible result codes

      - nPossibleVals: the number of possible values of each variable

      - attrs: the ids of the variables that may be selected

      - exIndices: (optional) indices into _examples_ of the examples to
        use.  Defaults to all examples.

      - randomDescriptors: (optional keyword) if >0 and smaller than
        len(attrs), only that many randomly chosen members of _attrs_ are
        examined

    **Returns**

      a 3-tuple: (id of the best variable, its gain, its quantization bounds)

      (-1, -1e6, []) is returned if there are no examples or if no variable
      produced a gain larger than -1e6.

    **Notes**

      - when two variables have exactly the same gain, the one with fewer
        quantization bounds wins

    """
    bestGain = -1e6
    best = -1
    bestBounds = []

    if exIndices is None:
        exIndices = range(len(examples))

    if not len(exIndices):
        return best,bestGain,bestBounds

    nToTake = kwargs.get('randomDescriptors',0)
    if 0 < nToTake < len(attrs):
        # shuffle positions instead of attrs itself so the caller's list is
        # left alone; list() is needed because a py3 range cannot be shuffled
        ids = list(range(len(attrs)))
        random.shuffle(ids)
        attrs = [attrs[x] for x in ids[:nToTake]]

    for var in attrs:
        nBounds = nBoundsPerVar[var]
        if nBounds > 0:
            # quantized variable: let the quantizer pick the bounds
            try:
                vTable = [examples[x][var] for x in exIndices]
            except IndexError:
                print('index error retrieving variable: %d'%var)
                raise
            qBounds,gainHere = Quantize.FindVarMultQuantBounds(vTable,nBounds,
                                                               resCodes,nPossibleRes)
        elif nBounds == 0:
            # discrete variable: gain comes straight from its contingency table
            vTable = ID3.GenVarTable((examples[x] for x in exIndices),
                                     nPossibleVals,[var])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            # negative bound count: this variable can never win
            gainHere = -1e6
            qBounds = []

        if gainHere > bestGain:
            bestGain = gainHere
            bestBounds = qBounds
            best = var
        elif bestGain == gainHere:
            # tie: prefer the simpler split (fewer bounds)
            if len(qBounds) < len(bestBounds):
                best = var
                bestBounds = qBounds

    if best == -1:
        # nothing was selected; dump enough information to diagnose why
        print('best unaltered')
        print('\tattrs: %s'%(attrs,))
        print('\tnBounds: %s'%([nBoundsPerVar[x] for x in attrs],))
        print('\texamples:')
        for example in (examples[x] for x in exIndices):
            print('\t\t%s'%(example,))

    return best,bestGain,bestBounds
89
90
def BuildQuantTree(examples,target,attrs,nPossibleVals,nBoundsPerVar,
                   depth=0,maxDepth=-1,exIndices=None,**kwargs):
    """ Recursively grows a quantized decision tree

    **Arguments**

      - examples: a list of lists (nInstances x nVariables+1) of variable
        values + instance values

      - target: an int.  Not used by this function; the recursive calls
        pass the id of the variable that was just split on.

      - attrs: a list of ints indicating which variables can be used in the tree

      - nPossibleVals: a list containing the number of possible values of
        every variable.  The last entry is the number of possible results.

      - nBoundsPerVar: the number of bounds to include for each variable

      - depth: (optional) the current depth in the tree

      - maxDepth: (optional) the maximum depth to which the tree
        will be grown.  Negative values disable the limit.

      - exIndices: (optional) indices into _examples_ of the examples that
        reach this node.  Defaults to all examples.

      - recycleVars: (optional keyword) if true, a variable stays available
        to the subtrees below the node that splits on it

      - any other keyword arguments are passed along to _FindBest_ and to
        the recursive calls

    **Returns**

      a QuantTree.QuantTreeNode with the decision tree

    **NOTE:** This code cannot bootstrap (start from nothing...)
      use _QuantTreeBoot_ (below) for that.
    """
    tree = QuantTree.QuantTreeNode(None,'node')
    tree.SetData(-666)
    nPossibleRes = nPossibleVals[-1]

    if exIndices is None:
        exIndices = range(len(examples))

    # tally the result codes of the examples at this node
    resCodes = [int(examples[idx][-1]) for idx in exIndices]
    counts = [0]*nPossibleRes
    for res in resCodes:
        counts[res] += 1
    nzCounts = numpy.nonzero(counts)[0]

    def _makeMajorityLeaf():
        # turn this node into a terminal node voting for the commonest result
        v = numpy.argmax(counts)
        tree.SetLabel(v)
        tree.SetName('%d?'%v)
        tree.SetTerminal(1)

    if len(nzCounts) == 1:
        # every example has the same result: a pure terminal node
        res = nzCounts[0]
        tree.SetLabel(res)
        tree.SetName(str(res))
        tree.SetTerminal(1)
        return tree

    if len(attrs) == 0 or (maxDepth>=0 and depth>maxDepth):
        # out of variables or too deep: take the majority vote
        _makeMajorityLeaf()
        return tree

    # find the variable which gives us the largest information gain
    best,bestGain,bestBounds = FindBest(resCodes,examples,nBoundsPerVar,
                                        nPossibleRes,nPossibleVals,attrs,
                                        exIndices=exIndices,
                                        **kwargs)
    if best == -1:
        # FindBest could not select any variable.  Continuing would either
        # fail in nextAttrs.remove(-1) or split on the result column
        # (index -1), so fall back to a majority vote here as well.
        _makeMajorityLeaf()
        return tree

    # remove that variable from the list of possible variables
    nextAttrs = list(attrs)
    if not kwargs.get('recycleVars',0):
        nextAttrs.remove(best)

    # set some info at this node
    tree.SetName('Var: %d'%(best))
    tree.SetLabel(best)
    tree.SetQuantBounds(bestBounds)
    tree.SetTerminal(0)

    def _addBranch(subset):
        # empty branches get a majority-vote leaf, the others get a subtree
        if len(subset) == 0:
            v = numpy.argmax(counts)
            tree.AddChild('%d'%v,label=v,data=0.0,isTerminal=1)
        else:
            tree.AddChildNode(BuildQuantTree(examples,best,
                                             nextAttrs,nPossibleVals,
                                             nBoundsPerVar,
                                             depth=depth+1,maxDepth=maxDepth,
                                             exIndices=subset,
                                             **kwargs))

    if len(bestBounds) > 0:
        # quantized variable: one branch below each bound (taken in order),
        # plus a final branch for whatever is left
        remaining = list(exIndices)
        for bound in bestBounds:
            below = []
            rest = []
            for idx in remaining:
                if examples[idx][best] < bound:
                    below.append(idx)
                else:
                    rest.append(idx)
            remaining = rest
            _addBranch(below)
        _addBranch(remaining)
    else:
        # discrete variable: one branch per possible value
        for val in range(nPossibleVals[best]):
            _addBranch([idx for idx in exIndices if examples[idx][best] == val])
    return tree
225
def QuantTreeBoot(examples,attrs,nPossibleVals,nBoundsPerVar,initialVar=None,
                  maxDepth=-1,**kwargs):
    """ Bootstrapping code for the QuantTree

    If _initialVar_ is not set, the algorithm will automatically
    choose the first variable in the tree (the standard greedy
    approach).  Otherwise, _initialVar_ will be used as the first
    split.

    **Arguments**

      - examples: a list of lists of variable values + instance values; the
        last entry of each example is its result code

      - attrs: the ids of the variables that can be used in the tree.
        Variables whose entry in _nBoundsPerVar_ is -1 are dropped.  The
        caller's sequence is not modified.

      - nPossibleVals: the number of possible values of every variable; the
        last entry is the number of possible results

      - nBoundsPerVar: the number of bounds to include for each variable

      - initialVar: (optional) the id of the variable to split on first

      - maxDepth: (optional) passed on to _BuildQuantTree_

      - recycleVars: (optional keyword) if true, the variable used for the
        first split stays available to the subtrees

      - any other keyword arguments are passed along to _FindBest_ and
        _BuildQuantTree_

    **Returns**

      the root QuantTree.QuantTreeNode of the tree

    **Raises**

      ValueError if _initialVar_ is not set and no variable could be
      selected for the first split

    """
    # work on a real list: a py3 range has no remove()
    attrs = list(attrs)
    for i,nBounds in enumerate(nBoundsPerVar):
        if nBounds == -1 and i in attrs:
            attrs.remove(i)

    tree = QuantTree.QuantTreeNode(None,'node')
    nPossibleRes = nPossibleVals[-1]
    tree._nResultCodes = nPossibleRes

    resCodes = [int(x[-1]) for x in examples]
    counts = [0]*nPossibleRes
    for res in resCodes:
        counts[res] += 1

    if initialVar is None:
        best,gainHere,qBounds = FindBest(resCodes,examples,nBoundsPerVar,
                                         nPossibleRes,nPossibleVals,attrs,
                                         **kwargs)
        if best == -1:
            # continuing would fail in nextAttrs.remove(-1) or, with
            # recycleVars set, split on the result column (index -1)
            raise ValueError('no variable could be selected for the first split')
    else:
        best = initialVar
        if nBoundsPerVar[best] > 0:
            # a list, not map(): on py3 map() gives a one-shot iterator
            vTable = [x[best] for x in examples]
            qBounds,gainHere = Quantize.FindVarMultQuantBounds(vTable,nBoundsPerVar[best],
                                                               resCodes,nPossibleRes)
        elif nBoundsPerVar[best] == 0:
            vTable = ID3.GenVarTable(examples,nPossibleVals,[best])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

    tree.SetName('Var: %d'%(best))
    tree.SetData(gainHere)
    tree.SetLabel(best)
    tree.SetTerminal(0)
    tree.SetQuantBounds(qBounds)

    nextAttrs = attrs[:]
    # a forced initialVar may not be in attrs (never listed, or dropped above
    # because its bound count is -1), so only remove it when present
    if not kwargs.get('recycleVars',0) and best in nextAttrs:
        nextAttrs.remove(best)

    def _addBranch(subset):
        # non-empty branches get a subtree, empty ones a majority-vote leaf
        if len(subset):
            tree.AddChildNode(BuildQuantTree(subset,best,
                                             nextAttrs,nPossibleVals,
                                             nBoundsPerVar,
                                             depth=1,maxDepth=maxDepth,
                                             **kwargs))
        else:
            v = numpy.argmax(counts)
            tree.AddChild('%d??'%(v),label=v,data=0.0,isTerminal=1)

    if len(qBounds) > 0:
        # quantized variable: one branch below each bound (taken in order),
        # plus a final branch for whatever is left
        remaining = list(examples)
        for bound in qBounds:
            below = []
            rest = []
            for ex in remaining:
                if ex[best] < bound:
                    below.append(ex)
                else:
                    rest.append(ex)
            remaining = rest
            _addBranch(below)
        _addBranch(remaining)
    else:
        # discrete variable: one branch per possible value
        for val in range(nPossibleVals[best]):
            _addBranch([ex for ex in examples if ex[best] == val])
    return tree
324
325
def TestTree():
    """ testing code for named trees

    Builds a depth-limited ID3 tree from a small set of named examples and
    prints it.

    """
    # each example: [name, var1, var2, var3, result]
    examples1 = [['p1',0,1,0,0],
                 ['p2',0,0,0,1],
                 ['p3',0,0,1,2],
                 ['p4',0,1,1,2],
                 ['p5',1,0,0,2],
                 ['p6',1,0,1,2],
                 ['p7',1,1,0,2],
                 ['p8',1,1,1,0]
                 ]
    # skip the name (first column) and the result (last column)
    attrs = list(range(1,len(examples1[0])-1))
    nPossibleVals = [0,2,2,2,3]
    t1 = ID3.ID3Boot(examples1,attrs,nPossibleVals,maxDepth=1)
    t1.Print()
343
344
def TestQuantTree():
    """ testing code for named trees

    Builds quantized trees (with and without a depth limit) from a small set
    of named examples in which one variable is continuous, then pickles and
    prints them.

    """
    # each example: [name, var1, var2, var3 (continuous), result]
    examples1 = [['p1',0,1,0.1,0],
                 ['p2',0,0,0.1,1],
                 ['p3',0,0,1.1,2],
                 ['p4',0,1,1.1,2],
                 ['p5',1,0,0.1,2],
                 ['p6',1,0,1.1,2],
                 ['p7',1,1,0.1,2],
                 ['p8',1,1,1.1,0]
                 ]
    # skip the name (first column) and the result (last column)
    attrs = list(range(1,len(examples1[0])-1))
    nPossibleVals = [0,2,2,0,3]
    boundsPerVar = [0,0,0,1,0]

    print('base')
    t1 = QuantTreeBoot(examples1,attrs,nPossibleVals,boundsPerVar)
    t1.Pickle('test_data/QuantTree1.pkl')
    t1.Print()

    print('depth limit')
    t1 = QuantTreeBoot(examples1,attrs,nPossibleVals,boundsPerVar,maxDepth=1)
    # NOTE(review): same file name as the 'base' run above, so the first
    # pickle is overwritten - confirm that this is intended
    t1.Pickle('test_data/QuantTree1.pkl')
    t1.Print()
371
# NOTE(review): the def line of this routine is missing from the listing; the
# name below is inferred from the pickle file name - confirm against the
# original source
def TestQuantTree2():
    """ testing code for named trees

    Builds a quantized tree from a small set of named examples in which two
    variables are continuous, prints and pickles it, then classifies each
    example with it.

    """
    # each example: [name, var1 (continuous), var2, var3 (continuous), result]
    examples1 = [['p1',0.1,1,0.1,0],
                 ['p2',0.1,0,0.1,1],
                 ['p3',0.1,0,1.1,2],
                 ['p4',0.1,1,1.1,2],
                 ['p5',1.1,0,0.1,2],
                 ['p6',1.1,0,1.1,2],
                 ['p7',1.1,1,0.1,2],
                 ['p8',1.1,1,1.1,0]
                 ]
    # skip the name (first column) and the result (last column)
    attrs = list(range(1,len(examples1[0])-1))
    nPossibleVals = [0,0,2,0,3]
    boundsPerVar = [0,1,0,1,0]

    t1 = QuantTreeBoot(examples1,attrs,nPossibleVals,boundsPerVar)
    t1.Print()
    t1.Pickle('test_data/QuantTree2.pkl')

    for example in examples1:
        print('%s %s'%(example,t1.ClassifyExample(example)))
395
# run the demo routines only when this file is executed as a script
if __name__ == '__main__':
    TestTree()
    TestQuantTree()
399
400