1
2
3
4
5
"""Builders for QuantTree decision trees whose split variables may be quantized.

"""
9
10 from Numeric import *
11 import RandomArray
12 from ML.DecTree import QuantTree, ID3
13 from ML.InfoTheory import entropy
14 from ML.Data import Quantize
15
def FindBest(resCodes, examples, nBoundsPerVar, nPossibleRes,
             nPossibleVals, attrs, **kwargs):
    """Find the candidate variable that gives the largest gain.

    **Arguments**

      - resCodes: the integer result code of each example, in the same
        order as *examples*

      - examples: a list of examples; entry ``var`` of an example holds the
        value of variable ``var``

      - nBoundsPerVar: per variable, the number of quantization bounds to
        search for (> 0), 0 to score the variable on its discrete values
        with ID3 information gain, or a negative value to exclude it

      - nPossibleRes: the number of possible result codes

      - nPossibleVals: the number of possible values of every variable
        (used for discrete variables)

      - attrs: the indices of the candidate variables; not modified

      - randomDescriptors: (optional keyword) if > 0 and smaller than
        ``len(attrs)``, only that many randomly chosen variables from
        *attrs* are considered

    **Returns**

      a 3-tuple ``(best, bestGain, bestBounds)``: the index of the best
      variable (-1 if nothing scored above -1e6, e.g. when *examples* is
      empty), its gain (-1e6 when best is -1) and its quantization bounds
      ([] for discrete variables)
    """
    bestGain = -1e6
    best = -1
    bestBounds = []

    if not len(examples):
        return best, bestGain, bestBounds

    nToTake = kwargs.get('randomDescriptors', 0)
    if nToTake > 0:
        nAttrs = len(attrs)
        # Only subsample when it actually removes candidates; rebinding
        # inside this branch avoids touching an unbound temporary otherwise.
        if nToTake < nAttrs:
            ids = RandomArray.permutation(nAttrs)
            attrs = [attrs[x] for x in ids[:nToTake]]

    for var in attrs:
        nBounds = nBoundsPerVar[var]
        if nBounds > 0:
            # Quantized variable: search for the best set of bounds.
            try:
                vTable = [x[var] for x in examples]
            except IndexError:
                print('index error retrieving variable: %d' % var)
                raise
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(
                vTable, nBounds, resCodes, nPossibleRes)
        elif nBounds == 0:
            # Discrete variable: plain ID3 information gain.
            vTable = ID3.GenVarTable(examples, nPossibleVals, [var])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            # Excluded variable: -1e6 can never beat the initial bestGain.
            gainHere = -1e6
            qBounds = []
        if gainHere > bestGain:
            bestGain = gainHere
            bestBounds = qBounds
            best = var

    if best == -1:
        # Debugging aid: no usable variable for a non-empty example set.
        print('best unaltered')
        print('\tattrs: %s' % (attrs,))
        print('\tnBounds: %s' % ([nBoundsPerVar[a] for a in attrs],))
        print('\texamples:')
        for example in examples:
            print('\t\t%s' % (example,))

    return best, bestGain, bestBounds
66
67
def _splitExamples(examples, var, bounds, nPossibleVals):
    """Partition *examples* on variable *var*, one group per child node.

    With quantization *bounds*, group i holds the examples not already placed
    in an earlier group whose value is < bounds[i], and a final group collects
    the rest (len(bounds)+1 groups). Without bounds, group v holds the
    examples whose value == v, for v in range(nPossibleVals[var]).
    Example order is preserved within every group.
    """
    if len(bounds) > 0:
        groups = []
        remaining = examples
        for bound in bounds:
            below = []
            rest = []
            for ex in remaining:
                if ex[var] < bound:
                    below.append(ex)
                else:
                    rest.append(ex)
            groups.append(below)
            remaining = rest
        groups.append(remaining)
        return groups
    return [[ex for ex in examples if ex[var] == val]
            for val in range(nPossibleVals[var])]


def BuildQuantTree(examples, target, attrs, nPossibleVals, nBoundsPerVar,
                   depth=0, maxDepth=-1, **kwargs):
    """Recursively grow a QuantTree from a set of examples.

    **Arguments**

      - examples: a list of lists (nInstances x nVariables+1) of variable
        values + instance values; the last entry of each example is its
        integer result code

      - target: an int; not used by this function (recursive calls pass
        the parent's split variable)

      - attrs: a list of ints indicating which variables can be used in the
        tree; not modified

      - nPossibleVals: a list containing the number of possible values of
        every variable; its last entry is the number of result codes

      - nBoundsPerVar: the number of bounds to include for each variable
        (0 means the variable is split on its discrete values)

      - depth: (optional) the current depth in the tree

      - maxDepth: (optional) the maximum depth to which the tree
        will be grown; negative means no limit

      - recycleVars: (optional keyword) if true, a variable may be split
        on again further down the tree

      - other keywords (e.g. randomDescriptors) are passed on to FindBest

    **Returns**

      a QuantTree.QuantTreeNode with the decision tree

    **NOTE:** This code cannot bootstrap (start from nothing...)
      use _QuantTreeBoot_ (below) for that.
    """
    tree = QuantTree.QuantTreeNode(None, 'node')
    tree.SetData(-666)
    nPossibleRes = nPossibleVals[-1]

    # Tally the result codes (last entry of each example).
    resCodes = [int(x[-1]) for x in examples]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1
    nzCounts = [code for code, count in enumerate(counts) if count]
    # Most common result code (lowest code on ties), used for guessed leaves.
    majority = counts.index(max(counts)) if counts else 0

    if len(nzCounts) == 1:
        # Pure node: every remaining example has the same result.
        res = nzCounts[0]
        tree.SetLabel(res)
        tree.SetName(str(res))
        tree.SetTerminal(1)
        return tree

    best = -1
    if len(attrs) > 0 and not (maxDepth >= 0 and depth > maxDepth):
        best, _, bestBounds = FindBest(resCodes, examples, nBoundsPerVar,
                                       nPossibleRes, nPossibleVals, attrs,
                                       **kwargs)

    if best == -1:
        # Out of variables, past maxDepth, or FindBest found nothing usable:
        # fall back to the most prevalent result instead of splitting.
        tree.SetLabel(majority)
        tree.SetName('%d?' % majority)
        tree.SetTerminal(1)
        return tree

    nextAttrs = attrs[:]
    if not kwargs.get('recycleVars', 0):
        nextAttrs.remove(best)

    tree.SetName('Var: %d' % (best))
    tree.SetLabel(best)
    tree.SetQuantBounds(bestBounds)
    tree.SetTerminal(0)

    # One child per bound range (or per discrete value), in order.
    for group in _splitExamples(examples, best, bestBounds, nPossibleVals):
        if group:
            tree.AddChildNode(BuildQuantTree(group, best,
                                             nextAttrs, nPossibleVals,
                                             nBoundsPerVar,
                                             depth=depth + 1, maxDepth=maxDepth,
                                             **kwargs))
        else:
            # No examples reach this branch (this does happen): guess.
            tree.AddChild('%d' % majority, label=majority, data=0.0,
                          isTerminal=1)
    return tree
195
def QuantTreeBoot(examples, attrs, nPossibleVals, nBoundsPerVar, initialVar=None,
                  maxDepth=-1, **kwargs):
    """Bootstrapping code for the QuantTree.

    If _initialVar_ is not set, the algorithm will automatically
    choose the first variable in the tree (the standard greedy
    approach). Otherwise, _initialVar_ will be used as the first
    split.

    **Arguments**

      - examples, nPossibleVals, nBoundsPerVar, maxDepth and the keyword
        arguments (recycleVars, randomDescriptors): as for BuildQuantTree

      - attrs: the candidate variables; not modified. Variables whose
        nBoundsPerVar entry is -1 are dropped before building.

      - initialVar: (optional) the variable to split on at the root; it
        does not have to be one of *attrs*

    **Returns**

      the root QuantTree.QuantTreeNode, whose data holds the gain of the
      root split. If no usable root variable is found (no examples, or no
      candidate variables), the root is a terminal node labelled with the
      most common result code.
    """
    attrs = attrs[:]
    for i in range(len(nBoundsPerVar)):
        if nBoundsPerVar[i] == -1 and i in attrs:
            attrs.remove(i)

    tree = QuantTree.QuantTreeNode(None, 'node')
    nPossibleRes = nPossibleVals[-1]
    tree._nResultCodes = nPossibleRes

    resCodes = [int(x[-1]) for x in examples]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1
    # Most common result code (lowest code on ties), used for guessed leaves.
    majority = counts.index(max(counts)) if counts else 0

    if initialVar is None:
        best, gainHere, qBounds = FindBest(resCodes, examples, nBoundsPerVar,
                                           nPossibleRes, nPossibleVals, attrs,
                                           **kwargs)
    else:
        best = initialVar
        nBounds = nBoundsPerVar[best]
        if nBounds > 0:
            vTable = [x[best] for x in examples]
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(
                vTable, nBounds, resCodes, nPossibleRes)
        elif nBounds == 0:
            vTable = ID3.GenVarTable(examples, nPossibleVals, [best])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

    tree.SetData(gainHere)
    if best == -1:
        # No usable root variable: splitting on index -1 would use the
        # result column itself, so return a majority-vote leaf instead.
        tree.SetLabel(majority)
        tree.SetName('%d?' % majority)
        tree.SetTerminal(1)
        return tree

    tree.SetName('Var: %d' % (best))
    tree.SetLabel(best)
    tree.SetTerminal(0)
    tree.SetQuantBounds(qBounds)

    nextAttrs = attrs[:]
    # initialVar may not be among attrs (it may even have been dropped above).
    if not kwargs.get('recycleVars', 0) and best in nextAttrs:
        nextAttrs.remove(best)

    # Partition the examples into one group per child, preserving order.
    if len(qBounds) > 0:
        groups = []
        remaining = examples
        for bound in qBounds:
            below = []
            rest = []
            for ex in remaining:
                if ex[best] < bound:
                    below.append(ex)
                else:
                    rest.append(ex)
            groups.append(below)
            remaining = rest
        groups.append(remaining)
    else:
        groups = [[ex for ex in examples if ex[best] == val]
                  for val in range(nPossibleVals[best])]

    for group in groups:
        if group:
            tree.AddChildNode(BuildQuantTree(group, best,
                                             nextAttrs, nPossibleVals,
                                             nBoundsPerVar,
                                             depth=1, maxDepth=maxDepth,
                                             **kwargs))
        else:
            # No examples reach this branch: guess the majority result.
            tree.AddChild('%d??' % (majority), label=majority, data=0.0,
                          isTerminal=1)
    return tree
294
295
def TestTree():
    """ testing code for named trees

    Builds a depth-limited ID3 tree from a small hand-made data set and
    prints it. Column 0 of each example is its name and the last column
    its result, so only the columns in between are candidate variables.
    """
    examples1 = [['p1', 0, 1, 0, 0],
                 ['p2', 0, 0, 0, 1],
                 ['p3', 0, 0, 1, 2],
                 ['p4', 0, 1, 1, 2],
                 ['p5', 1, 0, 0, 2],
                 ['p6', 1, 0, 1, 2],
                 ['p7', 1, 1, 0, 2],
                 ['p8', 1, 1, 1, 0]
                 ]
    attrs = list(range(1, len(examples1[0]) - 1))
    nPossibleVals = [0, 2, 2, 2, 3]
    t1 = ID3.ID3Boot(examples1, attrs, nPossibleVals, maxDepth=1)
    t1.Print()
313
314
def TestQuantTree():
    """ testing code for named trees

    Builds QuantTrees (unlimited depth, then maxDepth=1) from a data set
    whose variable 3 is quantized with one bound, printing each tree and
    pickling it to regress/QuantTree1.pkl.
    """
    examples1 = [['p1', 0, 1, 0.1, 0],
                 ['p2', 0, 0, 0.1, 1],
                 ['p3', 0, 0, 1.1, 2],
                 ['p4', 0, 1, 1.1, 2],
                 ['p5', 1, 0, 0.1, 2],
                 ['p6', 1, 0, 1.1, 2],
                 ['p7', 1, 1, 0.1, 2],
                 ['p8', 1, 1, 1.1, 0]
                 ]
    attrs = list(range(1, len(examples1[0]) - 1))
    nPossibleVals = [0, 2, 2, 0, 3]
    boundsPerVar = [0, 0, 0, 1, 0]

    print('base')
    t1 = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar)
    t1.Pickle('regress/QuantTree1.pkl')
    t1.Print()

    print('depth limit')
    t1 = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar, maxDepth=1)
    # NOTE(review): this overwrites the file written above, so
    # QuantTree1.pkl ends up holding the depth-limited tree. Confirm which
    # tree the regression consumer expects before changing the filename.
    t1.Pickle('regress/QuantTree1.pkl')
    t1.Print()
341
def TestQuantTree2():
    """ testing code for named trees

    Builds a QuantTree in which variables 1 and 3 are each quantized with
    one bound, prints it, pickles it to regress/QuantTree2.pkl, then prints
    the classification of every training example.
    """
    examples1 = [['p1', 0.1, 1, 0.1, 0],
                 ['p2', 0.1, 0, 0.1, 1],
                 ['p3', 0.1, 0, 1.1, 2],
                 ['p4', 0.1, 1, 1.1, 2],
                 ['p5', 1.1, 0, 0.1, 2],
                 ['p6', 1.1, 0, 1.1, 2],
                 ['p7', 1.1, 1, 0.1, 2],
                 ['p8', 1.1, 1, 1.1, 0]
                 ]
    attrs = list(range(1, len(examples1[0]) - 1))
    nPossibleVals = [0, 0, 2, 0, 3]
    boundsPerVar = [0, 1, 0, 1, 0]

    t1 = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar)
    t1.Print()
    t1.Pickle('regress/QuantTree2.pkl')

    for example in examples1:
        print('%s %s' % (example, t1.ClassifyExample(example)))
365
if __name__ == "__main__":
    # Run the demo routines in order.
    for demo in (TestTree, TestQuantTree):
        demo()
369
370