Package ML :: Package DecTree :: Module CrossValidate
[hide private]
[frames | no frames]

Source Code for Module ML.DecTree.CrossValidate

  1  # 
  2  #  Copyright (C) 2000  greg Landrum 
  3  # 
  4  """ handles doing cross validation with decision trees 
  5   
  6  The term is, perhaps, a little misleading.  For the purposes of this module, 
  7  "cross validation" simply means evaluating the accuracy of a tree. 
  8   
  9   
 10  """ 
 11  from ML.DecTree import ID3 
 12  from ML.Data import SplitData 
 13  from Numeric import * 
 14   
 15  import RDRandom 
 16   
def ChooseOptimalRoot(examples, trainExamples, testExamples, attrs,
                      nPossibleVals, treeBuilder, nQuantBounds=None,
                      **kwargs):
    """ loops through the possible tree roots and chooses the one which produces the best tree

    **Arguments**

      - examples: the full set of examples

      - trainExamples: the training examples

      - testExamples: the testing examples

      - attrs: a list of attributes to consider in the tree building

      - nPossibleVals: a list of the number of possible values each variable can adopt

      - treeBuilder: the function to be used to actually build the tree

      - nQuantBounds: an optional list.  If present (neither None nor empty), it is
        passed to the builder as a fourth positional argument (for building QuantTrees).
        Attributes whose entry in this list is -1 are not considered at all.

      - kwargs: passed through to _treeBuilder_ as keyword arguments

    **Returns**

      The best tree found, or None if there are no attributes to try or if no
      candidate root produced a tree

    **Notes**

      1) Trees are built using _trainExamples_

      2) Testing of each tree (to determine which is best) is done using _CrossValidate_ and
         the entire set of data (i.e. all of _examples_)

      3) _testExamples_ is not used at all, which immediately raises the question of
         why it's even being passed in

      4) The first entry of the (filtered) attribute list is never tried as a root:
         the search starts at index 1.
         NOTE(review): presumably deliberate -- confirm before changing.

      5) _kwargs_ are applied after the per-root 'initialVar' setting, so an
         'initialVar' entry in _kwargs_ overrides the root chosen here.

    """
    # work on a copy so that the caller's attribute list is never modified
    candidates = list(attrs)
    if nQuantBounds:
        for idx, bound in enumerate(nQuantBounds):
            if bound == -1 and idx in candidates:
                candidates.remove(idx)

    nAttrs = len(candidates)
    if not nAttrs:
        # nothing to root a tree at
        return None

    # error assigned to any root for which no tree is available
    noTreeErr = 1e6
    trees = [None] * nAttrs
    errs = [noTreeErr] * nAttrs

    if nQuantBounds is None or nQuantBounds == []:
        builderArgs = (trainExamples, candidates, nPossibleVals)
    else:
        builderArgs = (trainExamples, candidates, nPossibleVals, nQuantBounds)

    for i in range(1, nAttrs):
        argD = {'initialVar': candidates[i]}
        argD.update(kwargs)
        trees[i] = treeBuilder(*builderArgs, **argD)
        if trees[i]:
            # only the error is needed here, not the list of misclassified examples
            errs[i] = CrossValidate(trees[i], examples, appendExamples=0)[0]

    # index of the first minimum; if every entry is still noTreeErr this is 0 and
    # the result is trees[0], i.e. None
    best = min(range(nAttrs), key=errs.__getitem__)
    return trees[best]
80
def CrossValidate(tree, testExamples, appendExamples=0):
    """ Determines the classification error for the testExamples

    **Arguments**

      - tree: a decision tree (or anything supporting a _ClassifyExample()_ method)

      - testExamples: a list of examples to be used for testing.  The last
        element of each example is taken to be its true classification.

      - appendExamples: a toggle which is passed along to the tree as it does
        the classification.  The trees can use this to store the examples they
        classify locally.

    **Returns**

      a 2-tuple consisting of:

        1) the error of the tree, as a fraction of the examples tested
           (i.e. between 0.0 and 1.0, not a percentage)

        2) a list of misclassified examples

    **Notes**

      - if _testExamples_ is empty the result is (0.0, []) rather than a
        ZeroDivisionError

    """
    nTest = len(testExamples)
    if not nTest:
        # nothing to classify, so nothing can be misclassified
        return 0.0, []

    badExamples = [ex for ex in testExamples
                   if ex[-1] != tree.ClassifyExample(ex, appendExamples)]
    return float(len(badExamples)) / nTest, badExamples
116
def CrossValidationDriver(examples, attrs, nPossibleVals, holdOutFrac=.3, silent=0,
                          calcTotalError=0, treeBuilder=ID3.ID3Boot, lessGreedy=0,
                          startAt=None,
                          nQuantBounds=None,
                          maxDepth=-1,
                          **kwargs):
    """ Driver function for building trees and doing cross validation

    **Arguments**

      - examples: the full set of examples

      - attrs: a list of attributes to consider in the tree building

      - nPossibleVals: a list of the number of possible values each variable can adopt

      - holdOutFrac: the fraction of the data which should be reserved for the hold-out set
        (used to calculate the error)

      - silent: a toggle used to control how much visual noise this makes as it goes.

      - calcTotalError: a toggle used to indicate whether the classification error
        of the tree should be calculated using the entire data set (when true) or just
        the hold-out set (when false)

      - treeBuilder: the function to call to build the tree

      - lessGreedy: toggles use of the less greedy tree growth algorithm (see
        _ChooseOptimalRoot_).

      - startAt: forces the tree to be rooted at this descriptor.  Only used
        when _lessGreedy_ is false.

      - nQuantBounds: an optional list.  If present (neither None nor empty), it's
        assumed that the builder algorithm takes this argument as well
        (for building QuantTrees)

      - maxDepth: an optional integer, passed to the builder as a keyword argument

      - kwargs: passed through to the builder.  If it contains a true
        'replacementSelection' entry, the train/test split is done with
        replacement; that entry is passed on to the builder along with the rest.

    **Returns**

      a 2-tuple containing:

        1) the tree

        2) the cross-validation error of the tree

    """
    nTot = len(examples)
    if not kwargs.get('replacementSelection', 0):
        testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac,
                                                           silent=1, legacy=1,
                                                           replacement=0)
    else:
        testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac,
                                                           silent=1, legacy=0,
                                                           replacement=1)
    trainExamples = [examples[x] for x in trainIndices]
    testExamples = [examples[x] for x in testIndices]

    if not silent:
        print('Training with %d examples' % len(trainExamples))

    if lessGreedy:
        # NOTE: ChooseOptimalRoot returns None when no candidate root yields a
        # tree; that case is not handled here.
        tree = ChooseOptimalRoot(examples, trainExamples, testExamples,
                                 attrs, nPossibleVals, treeBuilder, nQuantBounds,
                                 maxDepth=maxDepth, **kwargs)
    elif nQuantBounds is None or nQuantBounds == []:
        tree = treeBuilder(trainExamples, attrs, nPossibleVals,
                           initialVar=startAt, maxDepth=maxDepth, **kwargs)
    else:
        tree = treeBuilder(trainExamples, attrs, nPossibleVals, nQuantBounds,
                           initialVar=startAt, maxDepth=maxDepth, **kwargs)

    if not silent:
        print('Testing with %d examples' % len(testExamples))
    if calcTotalError:
        xValError, badExamples = CrossValidate(tree, examples, appendExamples=0)
    else:
        # the hold-out examples are handed to the tree with appendExamples on
        xValError, badExamples = CrossValidate(tree, testExamples, appendExamples=1)
    if not silent:
        print('Validation error was %%%4.2f' % (100 * xValError))

    # stash the bookkeeping data on the tree itself
    tree.SetBadExamples(badExamples)
    tree.SetTrainingExamples(trainExamples)
    tree.SetTestExamples(testExamples)
    tree._trainIndices = trainIndices
    return tree, xValError
206 207
def TestRun():
    """ testing code

    Builds a tree from 200 random examples using _CrossValidationDriver_,
    pickles it to 'save.pkl' in the current directory, then prints the results
    of comparing the tree with a deep copy of itself.

    """
    from ML.DecTree import randomtest
    import copy

    examples, attrs, nPossibleVals = randomtest.GenRandomExamples(nExamples=200)
    tree, xValError = CrossValidationDriver(examples, attrs,
                                            nPossibleVals)

    tree.Pickle('save.pkl')

    t2 = copy.deepcopy(tree)
    # single-argument print calls give the same output on Python 2 and 3
    print('t1 == t2 %s' % (tree == t2,))
    treeList = [tree]
    print('t2 in [tree] %s %s' % (t2 in treeList, treeList.index(t2)))
# run the test code when this module is executed as a script
if __name__ == '__main__':
    TestRun()