Package ML :: Package DecTree :: Module PruneTree
[hide private]
[frames] | [no frames]

Source Code for Module ML.DecTree.PruneTree

  1  # 
  2  #  Copyright (C) 2000,2003  greg Landrum and Rational Discovery LLC 
  3  # 
  4  """ Contains functionality for doing tree pruning 
  5   
  6  """ 
  7  from Numeric import * 
  8  from ML.DecTree import CrossValidate, DecTree 
  9  import copy 
 10   
# when nonzero, _Pruner() prints a trace of every pruning decision it makes
_verbose = 0
 12   
def MaxCount(examples):
  """ given a set of examples, returns the most common result code

   **Arguments**

     examples: a non-empty sequence of examples to be counted; the
       result code of each example is its last element

   **Returns**

     the most common result code.  Ties are broken in favor of the
     smallest code.

   **Raises**

     ValueError if _examples_ is empty

   **Notes**

     - the examples are counted in a single pass, so the cost is linear
       in the number of examples and independent of the size of the
       result codes (negative codes are counted like any other)

  """
  counts = {}
  for example in examples:
    code = example[-1]
    counts[code] = counts.get(code,0) + 1
  if not counts:
    raise ValueError('MaxCount() requires at least one example')

  bestCode = None
  bestCount = 0
  for code,count in counts.items():
    # every count is >= 1, so the first item always wins here and bestCode
    # is never compared while still None.  Equal counts go to the smaller
    # code, which is what argmax() over a list indexed by code returned.
    if count > bestCount or (count == bestCount and code < bestCode):
      bestCode,bestCount = code,count
  return bestCode
32
33 -def _GetLocalError(node):
34 nWrong = 0 35 for example in node.GetExamples(): 36 pred = node.ClassifyExample(example,appendExamples=0) 37 if pred != example[-1]: 38 nWrong +=1 39 #if _verbose: print '------------------>MISS:',example,pred 40 return nWrong
41
def _Pruner(node,level=0):
  """Recursively finds and removes the nodes whose removal does not hurt
     classification of the examples stored in the tree

     **Arguments**

       - node: the tree to be pruned.  The pruning data should already be
         contained within node (i.e. node.GetExamples() should return the
         pruning data)

       - level: (optional) the level of recursion, used only in _verbose
         printing


     **Returns**

       the pruned version of node.  All replacements are made on deep
       copies (bestTree/workTree); node itself is not passed to
       ReplaceChildIndex()


     **Notes**

      - This uses a greedy algorithm which basically does a DFS traversal
        of the tree.  For every non-terminal child that holds examples it
        tries, in order:
          1) replacing the child with its own recursively pruned version
          2) replacing the child with a leaf predicting MaxCount() of the
             child's examples

      - A candidate is kept when its local error (_GetLocalError) is <= the
        best seen so far, so if removing a node does not affect the
        accuracy, it *will be* removed.  We favor smaller trees.

      - Terminal children and children without examples are left alone.

  """
  if _verbose: print ' '*level,'<%d> '%level,'>>> Pruner'
  children = node.GetChildren()[:]

  bestTree = copy.deepcopy(node)
  # NOTE(review): bestErr starts as a large sentinel instead of the error of
  #  the unpruned tree, so the first candidate tried is accepted
  #  unconditionally (as long as it has fewer than 1e6 errors).  This relies
  #  on the recursively pruned child never being worse than the original
  #  one -- confirm.
  bestErr = 1e6
  # NOTE(review): emptyChildren is never used below
  emptyChildren=[]
  #
  # Loop over the children of this node, removing them when doing so
  #  either improves the local error or leaves it unchanged (we're
  #  introducing a bias for simpler trees).
  #
  for i in range(len(children)):
    child = children[i]
    examples = child.GetExamples()
    if _verbose:
      print ' '*level,'<%d> '%level,' Child:',i,child.GetLabel()
      bestTree.Print()
      print
    if len(examples):
      if _verbose: print ' '*level,'<%d> '%level,' Examples',len(examples)
      if not child.GetTerminal():
        if _verbose: print ' '*level,'<%d> '%level,' Nonterminal'

        # every attempt starts from the best tree found so far, so the
        #  replacements accepted for earlier children are kept
        workTree = copy.deepcopy(bestTree)
        #
        # First recurse on the child (try removing things below it)
        #
        newNode = _Pruner(child,level=level+1)
        workTree.ReplaceChildIndex(i,newNode)
        tempErr = _GetLocalError(workTree)
        if tempErr<=bestErr:
          bestErr = tempErr
          bestTree = copy.deepcopy(workTree)
          if _verbose:
            print ' '*level,'<%d> '%level,'>->->->->->'
            print ' '*level,'<%d> '%level,'replacing:',i,child.GetLabel()
            child.Print()
            print ' '*level,'<%d> '%level,'with:'
            newNode.Print()
            print ' '*level,'<%d> '%level,'<-<-<-<-<-<'
        else:
          # the pruned child was worse: restore the original child before
          #  trying the leaf replacement below
          workTree.ReplaceChildIndex(i,child)
        #
        # Now try replacing the child entirely with a leaf that predicts
        #  the most common result code among the child's examples
        #
        bestGuess = MaxCount(child.GetExamples())
        newNode = DecTree.DecTreeNode(workTree,'L:%d'%(bestGuess),
                                      label=bestGuess,isTerminal=1)
        newNode.SetExamples(child.GetExamples())
        workTree.ReplaceChildIndex(i,newNode)
        if _verbose:
          print ' '*level,'<%d> '%level,'ATTEMPT:'
          workTree.Print()
        newErr = _GetLocalError(workTree)
        if _verbose: print ' '*level,'<%d> '%level,'---> ',newErr,bestErr
        if newErr <= bestErr:
          bestErr = newErr
          bestTree = copy.deepcopy(workTree)
          if _verbose:
            print ' '*level,'<%d> '%level,'PRUNING:'
            workTree.Print()
        else:
          if _verbose: print ' '*level,'<%d> '%level,'FAIL'
          # whoops... put the child back in:
          #  (workTree is not used again after this point -- it is rebuilt
          #  from bestTree on the next pass -- so this does not affect the
          #  result)
          workTree.ReplaceChildIndex(i,child)
      else:
        if _verbose: print ' '*level,'<%d> '%level,' Terminal'
    else:
      if _verbose: print ' '*level,'<%d> '%level,' No Examples',len(examples)
      #
      # FIX: we need to figure out what to do here (nodes that contain
      #   no examples in the testing set).  I can concoct arguments for
      #   leaving them in and for removing them.  At the moment they are
      #   left intact.
      #
      pass

  if _verbose: print ' '*level,'<%d> '%level,'<<< out'
  return bestTree
146
def PruneTree(tree,trainExamples,testExamples,minimizeTestErrorOnly=1):
  """ reduced-error pruning of a decision tree

    This algorithm is described on page 69 of Mitchell's book.  The actual
    pruning is done by _Pruner(), which greedily replaces subtrees whenever
    doing so does not increase the error on the evaluation set.

    **Arguments**

      - tree: the tree to be pruned.  As a side effect its stored examples
        are cleared and replaced by the evaluation set.

      - trainExamples: the examples used to train the tree

      - testExamples: the examples held out for testing the tree (the
        validation set)

      - minimizeTestErrorOnly: (optional) when nonzero (the default) only
        _testExamples_ are used to evaluate the error; when zero,
        _trainExamples_ + _testExamples_ are used.

    **Returns**

      a 2-tuple containing:

        1) the pruned tree (its bad examples are recorded with
           SetBadExamples())

        2) the error of the pruned tree on the evaluation set, as returned
           by CrossValidate.CrossValidate()

  """
  if not minimizeTestErrorOnly:
    evalSet = trainExamples + testExamples
  else:
    evalSet = testExamples

  # start from a clean slate, then screen the evaluation set through the
  #  tree so that every node holds the points that reach it; _Pruner()
  #  reads them back with GetExamples()
  tree.ClearExamples()
  screenErr,screenBad = CrossValidate.CrossValidate(tree,evalSet,
                                                    appendExamples=1)

  prunedTree = _Pruner(tree)

  # score the pruned tree and remember which examples it still misses
  prunedErr,prunedBad = CrossValidate.CrossValidate(prunedTree,evalSet)
  prunedTree.SetBadExamples(prunedBad)
  return prunedTree,prunedErr


# -------
#  testing code
# -------
208 -def _testRandom():
209 from ML.DecTree import randomtest 210 #examples,attrs,nPossibleVals = randomtest.GenRandomExamples(nVars=20,randScale=0.25,nExamples = 200) 211 examples,attrs,nPossibleVals = randomtest.GenRandomExamples(nVars=10,randScale=0.5,nExamples = 200) 212 tree,frac = CrossValidate.CrossValidationDriver(examples,attrs,nPossibleVals) 213 tree.Print() 214 tree.Pickle('orig.pkl') 215 print 'original error is:', frac 216 217 print '----Pruning' 218 newTree,frac2 = PruneTree(tree,tree.GetTrainingExamples(),tree.GetTestExamples()) 219 newTree.Print() 220 print 'pruned error is:',frac2 221 newTree.Pickle('prune.pkl')
222 223
224 -def _testSpecific():
225 from ML.DecTree import ID3 226 oPts= [ \ 227 [0,0,1,0], 228 [0,1,1,1], 229 [1,0,1,1], 230 [1,1,0,0], 231 [1,1,1,1], 232 ] 233 tPts = oPts+[[0,1,1,0],[0,1,1,0]] 234 235 tree = ID3.ID3Boot(oPts,attrs=range(3),nPossibleVals=[2]*4) 236 tree.Print() 237 err,badEx = CrossValidate.CrossValidate(tree,oPts) 238 print 'original error:',err 239 240 241 err,badEx = CrossValidate.CrossValidate(tree,tPts) 242 print 'original holdout error:',err 243 newTree,frac2 = PruneTree(tree,oPts,tPts) 244 newTree.Print() 245 err,badEx = CrossValidate.CrossValidate(newTree,tPts) 246 print 'pruned holdout error is:',err 247 print badEx 248 249 print len(tree),len(newTree)
250
251 -def _testChain():
252 from ML.DecTree import ID3 253 oPts= [ \ 254 [1,0,0,0,1], 255 [1,0,0,0,1], 256 [1,0,0,0,1], 257 [1,0,0,0,1], 258 [1,0,0,0,1], 259 [1,0,0,0,1], 260 [1,0,0,0,1], 261 [0,0,1,1,0], 262 [0,0,1,1,0], 263 [0,0,1,1,1], 264 [0,1,0,1,0], 265 [0,1,0,1,0], 266 [0,1,0,0,1], 267 ] 268 tPts = oPts 269 270 tree = ID3.ID3Boot(oPts,attrs=range(len(oPts[0])-1),nPossibleVals=[2]*len(oPts[0])) 271 tree.Print() 272 err,badEx = CrossValidate.CrossValidate(tree,oPts) 273 print 'original error:',err 274 275 276 err,badEx = CrossValidate.CrossValidate(tree,tPts) 277 print 'original holdout error:',err 278 newTree,frac2 = PruneTree(tree,oPts,tPts) 279 newTree.Print() 280 err,badEx = CrossValidate.CrossValidate(newTree,tPts) 281 print 'pruned holdout error is:',err 282 print badEx
if __name__ == '__main__':
  # switch on the pruning trace printed by _Pruner()
  _verbose = 1
  # _testRandom() and _testSpecific() are the other available demos
  _testChain()