1
2
3
4 """ code for dealing with forests (collections) of decision trees
5
6 **NOTE** This code should be obsolete now that ML.Composite.Composite is up and running.
7
8 """
9 import cPickle
10 from Numeric import *
11 from ML.DecTree import CrossValidate,PruneTree
12
"""a forest of unique decision trees.

adding a tree that is already in the forest does not store a second copy:
that tree's count field is incremented and the new error is added to its
accumulated error (the accumulated errors only become averages once
AverageErrors is called).

typical usage:

1) grow the forest with AddTree until happy with it

2) call AverageErrors to convert the accumulated errors into average error values

3) call SortTrees to put things in order by either error or count

"""
29 """ creates a histogram of error/count pairs
30
31 """
32 nExamples = len(self.treeList)
33 histo = []
34 i = 1
35 lastErr = self.errList[0]
36 countHere = self.countList[0]
37 eps = 0.001
38 while i < nExamples:
39 if self.errList[i]-lastErr > eps:
40 histo.append((lastErr,countHere))
41 lastErr = self.errList[i]
42 countHere = self.countList[i]
43 else:
44 countHere = countHere + self.countList[i]
45 i = i + 1
46
47 return histo
48
50 """ collects votes across every member of the forest for the given example
51
52 **Returns**
53
54 a list of the results
55
56 """
57 nTrees = len(self.treeList)
58 votes = [0]*nTrees
59 for i in xrange(nTrees):
60 votes[i] = self.treeList[i].ClassifyExample(example)
61 return votes
62
64 """ classifies the given example using the entire forest
65
66 **returns** a result and a measure of confidence in it.
67
68 **FIX:** statistics sucks... I'm not seeing an obvious way to get
69 the confidence intervals. For that matter, I'm not seeing
70 an unobvious way.
71
72 For now, this is just treated as a voting problem with the confidence
73 measure being the percent of trees which voted for the winning result.
74 """
75 self.treeVotes = self.CollectVotes(example)
76 votes = [0]*len(self._nPossible)
77 for i in xrange(len(self.treeList)):
78 res = self.treeVotes[i]
79 votes[res] = votes[res] + self.countList[i]
80
81 totVotes = sum(votes)
82 res = argmax(votes)
83
84 return res,float(votes[res])/float(totVotes)
85
87 """ Returns the details of the last vote the forest conducted
88
89 this will be an empty list if no voting has yet been done
90
91 """
92 return self.treeVotes
93
def Grow(self,examples,attrs,nPossibleVals,nTries=10,pruneIt=0,
         lessGreedy=0):
    """ Grows the forest by adding trees

    **Arguments**

      - examples: the examples to be used for training

      - attrs: a list of the attributes to be used in training

      - nPossibleVals: a list with the number of possible values each variable
        (as well as the result) can take on

      - nTries: the number of new trees to add

      - pruneIt: a toggle for whether or not each new tree should be pruned

      - lessGreedy: toggles the use of a less greedy construction algorithm where
        each possible tree root is used.  The best tree from each step is actually
        added to the forest.

    **Notes**

      - nPossibleVals is stored on the forest as self._nPossible so that it is
        available later (it is read when the forest votes on an example).

      - each new tree is built with CrossValidate.CrossValidationDriver and
        registered with self.AddTree, so duplicate trees only bump a count.

      - progress is printed roughly every tenth of nTries trees; the reporting
        interval is clamped to at least 1 so that nTries < 10 works.

    """
    self._nPossible = nPossibleVals
    # Report about ten times over the run.  nTries//10 is 0 for nTries < 10,
    # and `i % 0` would raise ZeroDivisionError, so never go below 1.
    reportEvery = max(1, nTries // 10)
    for i in xrange(nTries):
        tree,frac = CrossValidate.CrossValidationDriver(examples,attrs,nPossibleVals,
                                                        silent=1,calcTotalError=1,
                                                        lessGreedy=lessGreedy)
        if pruneIt:
            tree,frac2 = PruneTree.PruneTree(tree,tree.GetTrainingExamples(),
                                             tree.GetTestExamples(),
                                             minimizeTestErrorOnly=0)
            # single-argument print: same output as the old
            # `print 'prune: ', frac,frac2` statement, but also valid Python 3
            print('prune:  %s %s' % (frac, frac2))
            # the pruned tree's error replaces the unpruned one
            frac = frac2
        self.AddTree(tree,frac)
        if i % reportEvery == 0:
            print('Cycle: % 4d' % i)
130
def Pickle(self,fileName='foo.pkl'):
    """ Writes this forest off to a file so that it can be easily loaded later

    **Arguments**

      - fileName: the name of the file to be written (defaults to 'foo.pkl');
        an existing file with that name is overwritten

    **Notes**

      - the whole forest object is written with cPickle using pickle protocol 1

      - the file is always closed, even if pickling raises

    """
    pFile = open(fileName,'wb+')
    # try/finally rather than `with` to stay compatible with the older
    # Python versions this module targets
    try:
        cPickle.dump(self,pFile,1)
    finally:
        pFile.close()
142
144 """ Adds a tree to the forest
145
146 If an identical tree is already present, its count is incremented
147
148 **Arguments**
149
150 - tree: the new tree
151
152 - error: its error value
153
154 **NOTE:** the errList is run as an accumulator,
155 you probably want to call AverageErrors after finishing the forest
156
157 """
158 if tree in self.treeList:
159 idx = self.treeList.index(tree)
160 self.errList[idx] = self.errList[idx]+error
161 self.countList[idx] = self.countList[idx] + 1
162 else:
163 self.treeList.append(tree)
164 self.errList.append(error)
165 self.countList.append(1)
166
168 """ convert summed error to average error
169
170 This does the conversion in place
171 """
172 self.errList = map(lambda x,y:x/y,self.errList,self.countList)
173
175 """ sorts the list of trees
176
177 **Arguments**
178
179 sortOnError: toggles sorting on the trees' errors rather than their counts
180
181 """
182 if sortOnError:
183 order = argsort(self.errList)
184 else:
185 order = argsort(self.countList)
186
187
188
189 self.treeList = list(take(self.treeList,order))
190 self.countList = list(take(self.countList,order))
191 self.errList = list(take(self.errList,order))
192
193
195 return self.treeList[i]
197 self.treeList[i] = val
198
200 return self.countList[i]
202 self.countList[i] = val
203
205 return self.errList[i]
207 self.errList[i] = val
208
210 """ returns all relevant data about a particular tree in the forest
211
212 **Arguments**
213
214 i: an integer indicating which tree should be returned
215
216 **Returns**
217
218 a 3-tuple consisting of:
219
220 1) the tree
221
222 2) its count
223
224 3) its error
225 """
226 return (self.treeList[i],self.countList[i],self.errList[i])
227
229 """ sets all relevant data for a particular tree in the forest
230
231 **Arguments**
232
233 - i: an integer indicating which tree should be returned
234
235 - tup: a 3-tuple consisting of:
236
237 1) the tree
238
239 2) its count
240
241 3) its error
242 """
243 self.treeList[i],self.countList[i],self.errList[i] = tup
244
246 """ Returns everything we know
247
248 **Returns**
249
250 a 3-tuple consisting of:
251
252 1) our list of trees
253
254 2) our list of tree counts
255
256 3) our list of tree errors
257
258 """
259 return (self.treeList,self.countList,self.errList)
260
262 """ allows len(forest) to work
263
264 """
265 return len(self.treeList)
266
268 """ allows forest[i] to work. return the data tuple
269
270 """
271 return self.GetDataTuple(which)
272
274 """ allows the forest to show itself as a string
275
276 """
277 outStr= 'Forest\n'
278 for i in xrange(len(self.treeList)):
279 outStr = outStr + \
280 ' Tree % 4d: % 5d occurances %%% 5.2f average error\n'%(i,self.countList[i],
281 100.*self.errList[i])
282 return outStr
283
285 self.treeList=[]
286 self.errList=[]
287 self.countList=[]
288 self.treeVotes=[]
289
if __name__ == '__main__':
    # Smoke test: the same node is added twice, so AddTree should keep a
    # single entry whose count is 2 and whose accumulated error (0.5 + 0.5)
    # AverageErrors turns back into 0.5.
    from ML.DecTree import DecTree
    forest = Forest()
    node = DecTree.DecTreeNode(None,'foo')
    for _ in range(2):
        forest.AddTree(node,0.5)
    forest.AverageErrors()
    forest.SortTrees()
    print(forest)
299