Package ML :: Package DecTree :: Module Forest
[hide private]
[frames | no frames]

Source Code for Module ML.DecTree.Forest

  1  # 
  2  #  Copyright (C) 2000  greg Landrum 
  3  # 
  4  """ code for dealing with forests (collections) of decision trees 
  5   
  6  **NOTE** This code should be obsolete now that ML.Composite.Composite is up and running. 
  7   
  8  """ 
  9  import cPickle 
 10  from Numeric import * 
 11  from ML.DecTree import CrossValidate,PruneTree 
 12   
class Forest(object):
    """ a forest of unique decision trees.

      adding an existing tree just results in its count field being
      incremented and its error being accumulated.

      The forest is stored as three parallel lists: the trees themselves
      (treeList), the summed or averaged error of each tree (errList) and the
      number of times each tree has been added (countList).

      typical usage:

        1) grow the forest with AddTree until happy with it

        2) call AverageErrors to calculate the average error values

        3) call SortTrees to put things in order by either error or count

    """

    def MakeHistogram(self):
        """ creates a histogram of error/count pairs

          Consecutive trees whose error lies within 0.001 of the error that
          opened the current bin are merged into that bin and their counts are
          summed.  The trees are walked in their current order, so the bins
          are only meaningful once the forest has been sorted by error
          (see SortTrees).

          **Returns**

            a list of (error, count) 2-tuples, one per bin.  An empty forest
            gives an empty list.

        """
        histo = []
        if not self.treeList:
            return histo

        eps = 0.001
        lastErr = self.errList[0]
        countHere = self.countList[0]
        for i in range(1, len(self.treeList)):
            if self.errList[i] - lastErr > eps:
                # the error moved far enough: close the bin and open a new one
                histo.append((lastErr, countHere))
                lastErr = self.errList[i]
                countHere = self.countList[i]
            else:
                countHere = countHere + self.countList[i]
        # the bin still being filled when the loop ends has to be emitted as
        # well, otherwise the last group of trees never reaches the histogram
        histo.append((lastErr, countHere))
        return histo

    def CollectVotes(self, example):
        """ collects votes across every member of the forest for the given example

          **Arguments**

            - example: the example to be classified, passed unchanged to the
              ClassifyExample method of every tree

          **Returns**

            a list of the results, one per tree and in tree order

        """
        return [tree.ClassifyExample(example) for tree in self.treeList]

    def ClassifyExample(self, example):
        """ classifies the given example using the entire forest

          This is treated as a voting problem: every tree votes for a result
          and its vote is weighted by the tree's count.  The confidence
          measure is the fraction of the weighted votes which went to the
          winning result.  No real confidence interval is calculated.

          The individual votes are stored and can be retrieved afterwards
          with GetVoteDetails.

          **Arguments**

            - example: the example to be classified

          **Returns**

            a 2-tuple consisting of the winning result and the confidence in it

          **NOTE:** the forest has to contain at least one tree; with no
          votes to count there is nothing to normalise and the call fails.

        """
        self.treeVotes = self.CollectVotes(example)

        # The votes are used as indices into the tally, so it has to be long
        # enough for the largest vote cast.  The length of the list handed to
        # Grow is kept as a lower bound (that was the original sizing), but it
        # counts variables rather than result values and it does not exist at
        # all for forests assembled with AddTree alone.
        nSlots = len(getattr(self, '_nPossible', ()))
        if self.treeVotes:
            nSlots = max(nSlots, int(max(self.treeVotes)) + 1)
        votes = [0] * nSlots
        for res, count in zip(self.treeVotes, self.countList):
            votes[res] = votes[res] + count

        totVotes = sum(votes)
        res = argmax(votes)
        return res, float(votes[res]) / float(totVotes)

    def GetVoteDetails(self):
        """ Returns the details of the last vote the forest conducted

          this will be an empty list if no voting has yet been done

        """
        return self.treeVotes

    def Grow(self, examples, attrs, nPossibleVals, nTries=10, pruneIt=0,
             lessGreedy=0):
        """ Grows the forest by adding trees

          **Arguments**

            - examples: the examples to be used for training

            - attrs: a list of the attributes to be used in training

            - nPossibleVals: a list with the number of possible values each variable
              (as well as the result) can take on

            - nTries: the number of new trees to add

            - pruneIt: a toggle for whether or not the tree should be pruned

            - lessGreedy: toggles the use of a less greedy construction algorithm where
              each possible tree root is used.  The best tree from each step is actually
              added to the forest.

          **NOTE:** progress is printed roughly ten times over the run, and
          the errors before and after pruning are printed for every pruned tree.

        """
        self._nPossible = nPossibleVals
        # nTries/10 is zero for fewer than ten tries, which made the modulo
        # below blow up; clamp the reporting interval to at least one cycle
        reportEvery = max(nTries // 10, 1)
        for i in range(nTries):
            tree, frac = CrossValidate.CrossValidationDriver(examples, attrs, nPossibleVals,
                                                             silent=1, calcTotalError=1,
                                                             lessGreedy=lessGreedy)
            if pruneIt:
                tree, frac2 = PruneTree.PruneTree(tree, tree.GetTrainingExamples(),
                                                  tree.GetTestExamples(),
                                                  minimizeTestErrorOnly=0)
                # same text as the old space-separated print statement
                print(' '.join(['prune: ', str(frac), str(frac2)]))
                frac = frac2
            self.AddTree(tree, frac)
            if i % reportEvery == 0:
                print('Cycle: % 4d' % (i))

    def Pickle(self, fileName='foo.pkl'):
        """ Writes this forest off to a file so that it can be easily loaded later

          **Arguments**

            fileName is the name of the file to be written

        """
        pFile = open(fileName, 'wb+')
        try:
            cPickle.dump(self, pFile, 1)
        finally:
            # close the file even when pickling fails part-way through
            pFile.close()

    def AddTree(self, tree, error):
        """ Adds a tree to the forest

          If an identical tree is already present, its count is incremented

          **Arguments**

            - tree: the new tree

            - error: its error value

          **NOTE:** the errList is run as an accumulator,
            you probably want to call AverageErrors after finishing the forest

        """
        try:
            idx = self.treeList.index(tree)
        except ValueError:
            # not seen before: start a new entry
            self.treeList.append(tree)
            self.errList.append(error)
            self.countList.append(1)
        else:
            self.errList[idx] = self.errList[idx] + error
            self.countList[idx] = self.countList[idx] + 1

    def AverageErrors(self):
        """ convert summed error to average error

          This does the conversion in place.  Every entry of errList is
          divided by the matching count, so this should be called once, after
          the last tree has been added: calling it again divides again.

        """
        # float() keeps integer error sums from being truncated by the division
        self.errList = [float(err) / count
                        for err, count in zip(self.errList, self.countList)]

    def SortTrees(self, sortOnError=1):
        """ sorts the list of trees

          **Arguments**

            sortOnError: toggles sorting on the trees' errors rather than their counts

          The three parallel lists are reordered together, in ascending order
          of the chosen key.

        """
        if sortOnError:
            order = argsort(self.errList)
        else:
            order = argsort(self.countList)

        # Reorder with plain list indexing instead of pushing the lists through
        # an array: the trees, counts and errors keep their native python
        # types (which also keeps the forest easy to pickle).
        order = [int(idx) for idx in order]
        self.treeList = [self.treeList[idx] for idx in order]
        self.countList = [self.countList[idx] for idx in order]
        self.errList = [self.errList[idx] for idx in order]

    def GetTree(self, i):
        """ returns tree number i """
        return self.treeList[i]

    def SetTree(self, i, val):
        """ replaces tree number i with val """
        self.treeList[i] = val

    def GetCount(self, i):
        """ returns the count of tree number i """
        return self.countList[i]

    def SetCount(self, i, val):
        """ sets the count of tree number i to val """
        self.countList[i] = val

    def GetError(self, i):
        """ returns the error of tree number i """
        return self.errList[i]

    def SetError(self, i, val):
        """ sets the error of tree number i to val """
        self.errList[i] = val

    def GetDataTuple(self, i):
        """ returns all relevant data about a particular tree in the forest

          **Arguments**

            i: an integer indicating which tree should be returned

          **Returns**

            a 3-tuple consisting of:

              1) the tree

              2) its count

              3) its error

        """
        return (self.treeList[i], self.countList[i], self.errList[i])

    def SetDataTuple(self, i, tup):
        """ sets all relevant data for a particular tree in the forest

          **Arguments**

            - i: an integer indicating which tree should be set

            - tup: a 3-tuple consisting of:

              1) the tree

              2) its count

              3) its error

        """
        self.treeList[i], self.countList[i], self.errList[i] = tup

    def GetAllData(self):
        """ Returns everything we know

          **Returns**

            a 3-tuple consisting of:

              1) our list of trees

              2) our list of tree counts

              3) our list of tree errors

          **NOTE:** these are the forest's own lists, not copies.

        """
        return (self.treeList, self.countList, self.errList)

    def __len__(self):
        """ allows len(forest) to work

        """
        return len(self.treeList)

    def __getitem__(self, which):
        """ allows forest[i] to work.  return the data tuple

        """
        return self.GetDataTuple(which)

    def __str__(self):
        """ allows the forest to show itself as a string

          One header line followed by one line per tree giving its index, its
          count and its error as a percentage.

        """
        lines = ['Forest\n']
        for i in range(len(self.treeList)):
            lines.append(' Tree % 4d: % 5d occurances %%% 5.2f average error\n' %
                         (i, self.countList[i], 100. * self.errList[i]))
        return ''.join(lines)

    def __init__(self):
        """ creates an empty forest """
        self.treeList = []
        self.errList = []
        self.countList = []
        # votes cast by the individual trees in the last ClassifyExample call
        self.treeVotes = []
if __name__ == '__main__':
    # Smoke test, run only when the module is executed as a script: add the
    # same node twice, average the accumulated error, sort, and show the
    # resulting forest.
    from ML.DecTree import DecTree

    demoForest = Forest()
    demoNode = DecTree.DecTreeNode(None, 'foo')
    for _ in range(2):
        demoForest.AddTree(demoNode, 0.5)
    demoForest.AverageErrors()
    demoForest.SortTrees()
    print(demoForest)