Package rdkit :: Package ML :: Package DecTree :: Module Forest
[hide private]
[frames | no frames]

Source Code for Module rdkit.ML.DecTree.Forest

  1  # 
  2  #  Copyright (C) 2000-2008  greg Landrum 
  3  # 
  4  """ code for dealing with forests (collections) of decision trees 
  5   
  6  **NOTE** This code should be obsolete now that ML.Composite.Composite is up and running. 
  7   
  8  """ 
  9  import cPickle 
 10  import numpy 
 11  from rdkit.ML.DecTree import CrossValidate,PruneTree 
 12   
class Forest(object):
    """a forest of unique decision trees.

    Adding a tree that is already present (as judged by ``==``) does not store
    it a second time: its count is incremented and its error is added to the
    running total for that tree.

    typical usage:

      1) grow the forest with AddTree (or Grow) until happy with it

      2) call AverageErrors to convert the summed errors to average errors

      3) call SortTrees to put things in order by either error or count

    The forest keeps three parallel lists, indexed by tree position:

      - treeList: the trees themselves

      - errList: the (summed, later averaged) error for each tree

      - countList: how many times each tree has been added

    """

    def __init__(self):
        self.treeList = []
        self.errList = []
        self.countList = []
        # the per-tree results of the most recent ClassifyExample call
        self.treeVotes = []
        # filled in by Grow; stays None for forests assembled via AddTree
        self._nPossible = None

    def MakeHistogram(self, eps=0.001):
        """ creates a histogram of error/count pairs

          Consecutive trees whose error lies within eps of the error that
          opened the current bin are merged into that bin.  A new bin is only
          started when the error *increases* by more than eps, so call
          SortTrees (sorting on error) first to get a meaningful histogram.

          **Arguments**

            - eps: the tolerance used when deciding whether or not two errors
              belong to the same bin

          **Returns**

            a list of (error, count) 2-tuples, one per bin (including the
            final one).  An empty forest gives an empty list.

        """
        if not self.treeList:
            return []

        histo = []
        lastErr = self.errList[0]
        countHere = self.countList[0]
        for err, count in zip(self.errList[1:], self.countList[1:]):
            if err - lastErr > eps:
                histo.append((lastErr, countHere))
                lastErr = err
                countHere = count
            else:
                countHere = countHere + count
        # the bin still being accumulated when the loop ends must be
        # recorded as well, otherwise the last group of trees is dropped
        histo.append((lastErr, countHere))
        return histo

    def CollectVotes(self, example):
        """ collects votes across every member of the forest for the given example

          **Arguments**

            - example: the example to be classified; it is handed unchanged to
              the ClassifyExample method of every tree

          **Returns**

            a list of the results, one per tree, in tree order

        """
        return [tree.ClassifyExample(example) for tree in self.treeList]

    def ClassifyExample(self, example):
        """ classifies the given example using the entire forest

          **Arguments**

            - example: the example to be classified

          **Returns**

            a 2-tuple: the winning result and a measure of confidence in it.

          **Raises**

            ValueError if the forest contains no trees

          **FIX:** there is no obvious way to get real confidence intervals
            here, so this is just treated as a voting problem with the
            confidence measure being the fraction of (count-weighted) votes
            which went to the winning result.

          **NOTE:** the individual votes are used as list indices, so the trees
            need to return non-negative integers.  The votes are stored and
            can be retrieved afterwards using GetVoteDetails.

        """
        if not self.treeList:
            raise ValueError('cannot classify an example with an empty forest')

        self.treeVotes = self.CollectVotes(example)

        # _nPossible is only set by Grow; forests assembled with AddTree (or
        # restored from old pickles) may not have it, so size the tally from
        # the observed votes as well.
        nPossible = getattr(self, '_nPossible', None)
        nSlots = 0
        if nPossible is not None:
            nSlots = len(nPossible)
        nSlots = max(nSlots, max(self.treeVotes) + 1)

        votes = [0] * nSlots
        # each unique tree votes with a weight equal to the number of times
        # it has been added to the forest
        for res, count in zip(self.treeVotes, self.countList):
            votes[res] = votes[res] + count

        totVotes = sum(votes)
        res = int(numpy.argmax(votes))
        return res, float(votes[res]) / float(totVotes)

    def GetVoteDetails(self):
        """ Returns the details of the last vote the forest conducted

          this will be an empty list if no voting has yet been done

        """
        return self.treeVotes

    def Grow(self, examples, attrs, nPossibleVals, nTries=10, pruneIt=0,
             lessGreedy=0):
        """ Grows the forest by adding trees

          **Arguments**

            - examples: the examples to be used for training

            - attrs: a list of the attributes to be used in training

            - nPossibleVals: a list with the number of possible values each variable
              (as well as the result) can take on

            - nTries: the number of new trees to add

            - pruneIt: a toggle for whether or not the tree should be pruned

            - lessGreedy: toggles the use of a less greedy construction algorithm where
              each possible tree root is used.  The best tree from each step is actually
              added to the forest.

          **NOTE:** progress messages are printed to stdout roughly ten times
            over the course of the run (every cycle when nTries < 10).

        """
        self._nPossible = nPossibleVals
        # never let the reporting interval drop to zero: nTries < 10 would
        # otherwise cause a ZeroDivisionError in the modulo below
        reportEvery = max(nTries // 10, 1)
        for i in range(nTries):
            tree, frac = CrossValidate.CrossValidationDriver(examples, attrs, nPossibleVals,
                                                             silent=1, calcTotalError=1,
                                                             lessGreedy=lessGreedy)
            if pruneIt:
                tree, frac2 = PruneTree.PruneTree(tree, tree.GetTrainingExamples(),
                                                  tree.GetTestExamples(),
                                                  minimizeTestErrorOnly=0)
                # single-argument form: valid (and identical) on python 2 and 3
                print('prune:  %s %s' % (frac, frac2))
                frac = frac2
            self.AddTree(tree, frac)
            if i % reportEvery == 0:
                print('Cycle: % 4d' % (i))

    def Pickle(self, fileName='foo.pkl'):
        """ Writes this forest off to a file so that it can be easily loaded later

          **Arguments**

            fileName is the name of the file to be written

        """
        pFile = open(fileName, 'wb+')
        try:
            cPickle.dump(self, pFile, 1)
        finally:
            # make sure the handle is released even if pickling fails
            pFile.close()

    def AddTree(self, tree, error):
        """ Adds a tree to the forest

        If an identical tree is already present, its count is incremented

        **Arguments**

          - tree: the new tree

          - error: its error value

        **NOTE:** the errList is run as an accumulator,
            you probably want to call AverageErrors after finishing the forest

        """
        try:
            idx = self.treeList.index(tree)
        except ValueError:
            # not seen before: start a new entry
            self.treeList.append(tree)
            self.errList.append(error)
            self.countList.append(1)
        else:
            self.errList[idx] = self.errList[idx] + error
            self.countList[idx] = self.countList[idx] + 1

    def AverageErrors(self):
        """ convert summed error to average error

          This does the conversion in place.

          **NOTE:** this is not idempotent; calling it a second time divides
            the already-averaged errors by the counts again.

        """
        # float(count) avoids truncating integer division on python 2 when
        # the accumulated errors happen to be integers
        self.errList = [err / float(count)
                        for err, count in zip(self.errList, self.countList)]

    def SortTrees(self, sortOnError=1):
        """ sorts the list of trees (ascending)

          **Arguments**

            sortOnError: toggles sorting on the trees' errors rather than their counts

        """
        if sortOnError:
            order = numpy.argsort(self.errList)
        else:
            order = numpy.argsort(self.countList)

        # these elaborate contortions are required because, at the time this
        # code was written, Numeric arrays didn't unpickle so well...
        self.treeList = [self.treeList[x] for x in order]
        self.countList = [self.countList[x] for x in order]
        self.errList = [self.errList[x] for x in order]

    def GetTree(self, i):
        """ returns tree number i """
        return self.treeList[i]

    def SetTree(self, i, val):
        """ replaces tree number i with val """
        self.treeList[i] = val

    def GetCount(self, i):
        """ returns the count for tree number i """
        return self.countList[i]

    def SetCount(self, i, val):
        """ sets the count for tree number i to val """
        self.countList[i] = val

    def GetError(self, i):
        """ returns the error for tree number i """
        return self.errList[i]

    def SetError(self, i, val):
        """ sets the error for tree number i to val """
        self.errList[i] = val

    def GetDataTuple(self, i):
        """ returns all relevant data about a particular tree in the forest

          **Arguments**

            i: an integer indicating which tree should be returned

          **Returns**

            a 3-tuple consisting of:

              1) the tree

              2) its count

              3) its error

        """
        return (self.treeList[i], self.countList[i], self.errList[i])

    def SetDataTuple(self, i, tup):
        """ sets all relevant data for a particular tree in the forest

          **Arguments**

            - i: an integer indicating which tree should be set

            - tup: a 3-tuple consisting of:

              1) the tree

              2) its count

              3) its error

        """
        self.treeList[i], self.countList[i], self.errList[i] = tup

    def GetAllData(self):
        """ Returns everything we know

          **Returns**

            a 3-tuple consisting of:

              1) our list of trees

              2) our list of tree counts

              3) our list of tree errors

          **NOTE:** these are the forest's own lists, not copies.

        """
        return (self.treeList, self.countList, self.errList)

    def __len__(self):
        """ allows len(forest) to work

        """
        return len(self.treeList)

    def __getitem__(self, which):
        """ allows forest[i] to work.  return the data tuple

        """
        return self.GetDataTuple(which)

    def __str__(self):
        """ allows the forest to show itself as a string

        """
        lines = ['Forest\n']
        for i in range(len(self.treeList)):
            lines.append(' Tree % 4d: % 5d occurances %%% 5.2f average error\n' %
                         (i, self.countList[i], 100. * self.errList[i]))
        return ''.join(lines)
if __name__ == '__main__':
    # Minimal smoke test: add the same node twice so that it is stored once
    # with a count of two, average the accumulated error, sort, and show
    # the resulting forest.
    from rdkit.ML.DecTree import DecTree
    forest = Forest()
    node = DecTree.DecTreeNode(None, 'foo')
    forest.AddTree(node, 0.5)
    forest.AddTree(node, 0.5)
    forest.AverageErrors()
    forest.SortTrees()
    print(forest)