1
2
3
4
5
6 """
7
8 """
9
10 from Numeric import *
11 import RandomArray
12 from ML.DecTree import SigTree
13 from ML import InfoTheory
14 from ML.FeatureSelect import CMIM
15 from DataStructs.VectCollection import VectCollection
16 import copy
17 import random
19 """ Generates a random subset of a group of indices
20
21 **Arguments**
22
23 - nToInclude: the size of the desired set
24
25 - nBits: the maximum index to be included in the set
26
27 **Returns**
28
29 a list of indices
30
31 """
32
33
34 res = random.sample(xrange(nBits),nToInclude)
35 return res
36
def BuildSigTree(examples, nPossibleRes, ensemble=None, random=0,
                 metric=InfoTheory.InfoType.BIASENTROPY,
                 biasList=[1],
                 depth=0, maxDepth=-1,
                 useCMIM=0, allowCollections=False,
                 verbose=0, **kwargs):
    """ Recursively builds a decision tree that splits on fingerprint bits

    **Arguments**

      - examples: the examples to be classified.  Each example
        should be a sequence at least three entries long, with
        entry 0 being a label, entry 1 a BitVector and entry -1
        an activity value

      - nPossibleRes: the number of result codes possible

      - ensemble: (optional) if this argument is provided, it
        should be a sequence which is used to limit the bits
        which are actually considered as potential descriptors.
        The default is None (use all bits).

      - random: (optional) If this argument is nonzero, it
        specifies the number of bits to be randomly selected
        for consideration at this node (i.e. this toggles the
        growth of Random Trees).
        The default is 0 (no random descriptor selection)

      - metric: (optional) This is an _InfoTheory.InfoType_ and
        sets the metric used to rank the bits.
        The default is _InfoTheory.InfoType.BIASENTROPY_

      - biasList: (optional) If provided, this provides a bias
        list for the bit ranker.
        See the _InfoTheory.InfoBitRanker_ docs for an explanation
        of bias.
        The default value is [1], which biases towards actives.

      - maxDepth: (optional) the maximum depth to which the tree
        will be grown
        The default is -1 (no depth limit).

      - useCMIM: (optional) if this is >0, the CMIM algorithm
        (conditional mutual information maximization) will be
        used to select the descriptors used to build the trees.
        The value of the variable should be set to the number
        of descriptors to be used.  This option and the
        ensemble option are mutually exclusive (CMIM will not be
        used if the ensemble is set), but it happily coexists
        with the random argument (to only consider random subsets
        of the top N CMIM bits)
        The default is 0 (do not use CMIM)

      - allowCollections: (optional) if this is true and the
        fingerprints are _VectCollection_ instances, each example
        that has the splitting bit set is handed to the "on" child
        with a shallow copy of its collection on which
        _DetachVectsNotMatchingBit()_ has been called for that bit.
        The default is False.

      - verbose: (optional) if nonzero, progress information is
        printed while the tree is grown.

      - depth: (optional) the current depth in the tree
        This is used in the recursion and should not be set
        by the client.

      - kwargs: accepted for interface compatibility; not used here.

    **Returns**

      a SigTree.SigTreeNode with the root of the decision tree

    **Notes**

      - inside this function the name _random_ is the argument, not the
        standard-library module.

      - the default _biasList_ is a shared list object; it is only ever
        read here, never modified.

    """
    if verbose:
        print(' '.join([' ' * depth, 'Build']))
    tree = SigTree.SigTreeNode(None, 'node', level=depth)
    # placeholder payload; the nodes built here never get real data
    tree.SetData(-666)

    # tally how many examples carry each result code
    resCodes = [int(x[-1]) for x in examples]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1

    nzCounts = nonzero(counts)
    if verbose:
        print(' '.join([' ' * depth, '\tcounts:', str(counts)]))

    if len(nzCounts) == 1:
        # every example has the same result code: this is a pure leaf
        res = nzCounts[0]
        tree.SetLabel(res)
        tree.SetName(str(res))
        tree.SetTerminal(1)
        return tree

    if maxDepth >= 0 and depth > maxDepth:
        # depth limit exceeded: stop here and predict the most common code
        v = argmax(counts)
        tree.SetLabel(v)
        tree.SetName('%d?' % v)
        tree.SetTerminal(1)
        return tree

    # ---- otherwise find the best bit to split on
    fp = examples[0][1]
    nBits = fp.GetNumBits()
    ranker = InfoTheory.InfoBitRanker(nBits, nPossibleRes, metric)
    if biasList:
        ranker.SetBiasList(biasList)
    if useCMIM > 0 and not ensemble:
        ensemble = CMIM.SelectFeatures(examples, useCMIM, bvCol=1)

    # Work out which bits the ranker may consider at this node.
    if ensemble:
        candidateBits = list(ensemble)
    else:
        candidateBits = None
    if random:
        if candidateBits is not None:
            if len(candidateBits) > random:
                picks = _GenerateRandomEnsemble(random, len(candidateBits))
                availBits = [candidateBits[i] for i in picks]
            else:
                # Fewer candidates than requested: use all of them.  These
                # must be the ensemble's bit ids, not the positions
                # 0..len(ensemble)-1.
                availBits = candidateBits
        else:
            # never ask for more bits than the fingerprint has
            availBits = _GenerateRandomEnsemble(min(random, nBits), nBits)
    else:
        # Without random selection the ensemble (or the CMIM selection)
        # still has to restrict the ranking, as documented above.
        availBits = candidateBits
    if availBits:
        ranker.SetMaskBits(availBits)

    useCollections = isinstance(examples[0][1], VectCollection)
    for example in examples:
        if not useCollections:
            ranker.AccumulateVotes(example[1], example[-1])
        else:
            example[1].Reset()
            ranker.AccumulateVotes(example[1].orVect, example[-1])

    try:
        bitInfo = ranker.GetTopN(1)[0]
        best = int(bitInfo[0])
        gain = bitInfo[1]
    except Exception:
        # Best effort: report the failure and fall back to a majority leaf.
        import traceback
        traceback.print_exc()
        print('get top n failed')
        gain = -1.0
    if gain <= 0.0:
        # no bit separates the classes any further
        v = argmax(counts)
        tree.SetLabel(v)
        tree.SetName('?%d?' % v)
        tree.SetTerminal(1)
        return tree

    if verbose:
        print(' '.join([' ' * depth, '\tbest:', str(bitInfo)]))

    tree.SetName('Bit-%d' % (best))
    tree.SetLabel(best)
    tree.SetTerminal(0)

    # ---- partition the examples on the chosen bit
    onExamples = []
    offExamples = []
    for example in examples:
        if example[1][best]:
            if allowCollections and useCollections:
                sig = copy.copy(example[1])
                sig.DetachVectsNotMatchingBit(best)
                example = [example[0], sig] + list(example[2:])
            onExamples.append(example)
        else:
            offExamples.append(example)

    # ---- grow the children: the "off" branch first, then the "on" branch
    # NOTE(review): allowCollections (like useCMIM) is not handed down, so
    # collections are only trimmed at the node that received the flag --
    # confirm whether that is intended.
    for subset in (offExamples, onExamples):
        if len(subset) == 0:
            v = argmax(counts)
            tree.AddChild('%d??' % v, label=v, data=0.0, isTerminal=1)
        else:
            child = BuildSigTree(subset, nPossibleRes, random=random,
                                 ensemble=ensemble,
                                 metric=metric, biasList=biasList,
                                 depth=depth + 1, maxDepth=maxDepth,
                                 verbose=verbose)
            tree.AddChildNode(child)
    return tree
221
222
def SigTreeBuilder(examples, attrs, nPossibleVals, initialVar=None, ensemble=None,
                   randomDescriptors=0,
                   **kwargs):
    """ Adapter that calls _BuildSigTree()_ with a tree-builder style signature

    **Arguments**

      - examples: the examples to be classified (see _BuildSigTree()_)

      - attrs: not used here

      - nPossibleVals: a sequence whose last entry is the number of
        possible result codes

      - initialVar: not used here

      - ensemble: not used here (it is not forwarded to _BuildSigTree()_)

      - randomDescriptors: passed to _BuildSigTree()_ as its _random_
        argument

      - kwargs: forwarded unchanged to _BuildSigTree()_

    **Returns**

      whatever _BuildSigTree()_ returns

    """
    return BuildSigTree(examples, nPossibleVals[-1], random=randomDescriptors, **kwargs)
228
def testTrees1():
    """ Smoke test: grows trees from five 4-bit fingerprints

    A plain tree, a random tree that samples all 4 bits, and a random tree
    restricted to an ensemble holding all 4 bits are asserted to be equal.
    A CMIM-based tree is then built and printed.

    """
    import DataStructs
    pts = [['p1', [3], 0],
           ['p2', [1, 3], 1],
           ['p3', [2], 1],
           ['p4', [0], 1],
           ['p5', [], 0],
           ]
    # replace each list of on-bits with the corresponding bit vector
    for pt in pts:
        bv = DataStructs.ExplicitBitVect(4)
        if pt[1]:
            bv.SetBitsFromList(pt[1])
        pt[1] = bv
    t = BuildSigTree(pts, 2)
    t.Print()

    # The random bit selection draws from the standard-library random
    # module, so that is the generator that must be seeded to make the
    # run reproducible; the RandomArray seeding is kept as it was.
    RandomArray.seed(23, 42)
    random.seed(23)
    t2 = BuildSigTree(pts, 2, random=4)
    t2.Print()

    assert t == t2

    RandomArray.seed(23, 42)
    random.seed(23)

    t3 = BuildSigTree(pts, 2, random=4, ensemble=list(range(4)))
    assert t3 == t2

    t3 = BuildSigTree(pts, 2, useCMIM=3)
    print(t3)
258
if __name__ == "__main__":
    # executed as a script: run the built-in smoke test
    testTrees1()
261