1
2
3
4 """ handles doing cross validation with decision trees
5
6 This is, perhaps, a little misleading. For the purposes of this module,
7 cross validation == evaluating the accuracy of a tree.
8
9
10 """
11 from ML.DecTree import ID3
12 from ML.Data import SplitData
13 from Numeric import *
14
15 import RDRandom
16
def ChooseOptimalRoot(examples, trainExamples, testExamples, attrs,
                      nPossibleVals, treeBuilder, nQuantBounds=None,
                      **kwargs):
    """ loops through possible tree roots and chooses the one which produces the best tree

    **Arguments**

      - examples: the full set of examples

      - trainExamples: the training examples

      - testExamples: the testing examples

      - attrs: a list of attributes to consider in the tree building

      - nPossibleVals: a list of the number of possible values each variable can adopt

      - treeBuilder: the function to be used to actually build the tree

      - nQuantBounds: an optional list.  If present (and non-empty), it is passed
        to the builder as a fourth positional argument (for building QuantTrees).
        Any index _i_ with nQuantBounds[i] == -1 is dropped from the list of
        attributes before any tree is built.

      - kwargs: passed to the builder as keyword arguments, together with
        _initialVar_ (the root being tried).  An _initialVar_ entry in kwargs
        overrides the root chosen here.

    **Returns**

      The tree with the lowest error.  If no call to the builder produces a
      (truthy) tree, every error is the 1e6 sentinel and the entry returned is
      one of the failed results (normally None).

    **Raises**

      IndexError if no attributes are left after the nQuantBounds filtering.

    **Notes**

      1) Trees are built using _trainExamples_

      2) Testing of each tree (to determine which is best) is done using _CrossValidate_ and
         the entire set of data (i.e. all of _examples_)

      3) _testExamples_ is not used at all, which immediately raises the question of
         why it's even being passed in

      4) The first remaining attribute is never tried as a root: the search starts at
         index 1 and slot 0 is pinned to the sentinel error.
         NOTE(review): confirm that skipping attrs[0] is intentional.

    """
    # work on a copy so that the caller's attribute list is never modified
    attrs = attrs[:]
    if nQuantBounds:
        for idx, bound in enumerate(nQuantBounds):
            if bound == -1 and idx in attrs:
                attrs.remove(idx)
    nAttrs = len(attrs)
    trees = [None] * nAttrs
    errs = [0] * nAttrs
    # 1e6 marks "no usable tree"; real errors are fractions in [0, 1]
    errs[0] = 1e6

    # the positional arguments are the same for every candidate root
    builderArgs = (trainExamples, attrs, nPossibleVals)
    if not (nQuantBounds is None or nQuantBounds == []):
        builderArgs = builderArgs + (nQuantBounds,)

    for i in range(1, nAttrs):
        argD = {'initialVar': attrs[i]}
        argD.update(kwargs)
        trees[i] = treeBuilder(*builderArgs, **argD)
        if trees[i]:
            errs[i], _ = CrossValidate(trees[i], examples, appendExamples=0)
        else:
            errs[i] = 1e6
    best = argmin(errs)

    return trees[best]
80
def CrossValidate(tree, testExamples, appendExamples=0):
    """ Determines the classification error for the testExamples

    **Arguments**

      - tree: a decision tree (or anything supporting a _ClassifyExample()_ method)

      - testExamples: a list of examples to be used for testing.  The last
        element of each example is taken to be its true classification.

      - appendExamples: a toggle which is passed along to the tree as it does
        the classification.  The trees can use this to store the examples they
        classify locally.

    **Returns**

      a 2-tuple consisting of:

        1) the fractional error of the tree (misclassified / total, in [0, 1])

        2) a list of misclassified examples

      An empty set of test examples yields (0.0, []).

    """
    # NOTE(review): the default of 0 for appendExamples is an assumption; every
    # call in this module passes the argument explicitly.
    nTest = len(testExamples)
    if not nTest:
        # nothing to classify: avoid dividing by zero
        return 0.0, []

    badExamples = [ex for ex in testExamples
                   if ex[-1] != tree.ClassifyExample(ex, appendExamples)]

    return float(len(badExamples)) / nTest, badExamples
116
def CrossValidationDriver(examples, attrs, nPossibleVals, holdOutFrac=.3, silent=0,
                          calcTotalError=0, treeBuilder=ID3.ID3Boot, lessGreedy=0,
                          startAt=None,
                          nQuantBounds=None,
                          maxDepth=-1,
                          **kwargs):
    """ Driver function for building trees and doing cross validation

    **Arguments**

      - examples: the full set of examples

      - attrs: a list of attributes to consider in the tree building

      - nPossibleVals: a list of the number of possible values each variable can adopt

      - holdOutFrac: the fraction of the data which should be reserved for the hold-out set
        (used to calculate the error)

      - silent: a toggle used to control how much visual noise this makes as it goes.

      - calcTotalError: a toggle used to indicate whether the classification error
        of the tree should be calculated using the entire data set (when true) or just
        the training hold out set (when false)

      - treeBuilder: the function to call to build the tree

      - lessGreedy: toggles use of the less greedy tree growth algorithm (see
        _ChooseOptimalRoot_).

      - startAt: forces the tree to be rooted at this descriptor.  Ignored when
        _lessGreedy_ is set, since the root is then chosen by _ChooseOptimalRoot_.

      - nQuantBounds: an optional list.  If present (and non-empty), it's assumed that
        the builder algorithm takes this argument as well (for building QuantTrees)

      - maxDepth: an optional integer which is passed to the builder algorithm as
        a keyword argument

      - kwargs: passed along to the builder algorithm.  The entry
        _replacementSelection_ is also read here: when true the train/test split
        is done with replacement.

    **Returns**

       a 2-tuple containing:

         1) the tree

         2) the cross-validation error of the tree

    **Notes**

      Before it is returned, the tree is given the misclassified examples, the
      training examples, the test examples and the training indices
      (as _tree._trainIndices_).

    """
    nTot = len(examples)
    # selection with replacement goes through the non-legacy splitting code
    if kwargs.get('replacementSelection', 0):
        legacy, replacement = 0, 1
    else:
        legacy, replacement = 1, 0
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac,
                                                       silent=1, legacy=legacy,
                                                       replacement=replacement)
    trainExamples = [examples[x] for x in trainIndices]
    testExamples = [examples[x] for x in testIndices]

    nTrain = len(trainExamples)
    if not silent:
        print('Training with %d examples' % (nTrain))

    if lessGreedy:
        tree = ChooseOptimalRoot(examples, trainExamples, testExamples,
                                 attrs, nPossibleVals, treeBuilder, nQuantBounds,
                                 maxDepth=maxDepth, **kwargs)
    elif nQuantBounds is None or nQuantBounds == []:
        tree = treeBuilder(trainExamples, attrs, nPossibleVals,
                           initialVar=startAt, maxDepth=maxDepth, **kwargs)
    else:
        tree = treeBuilder(trainExamples, attrs, nPossibleVals, nQuantBounds,
                           initialVar=startAt, maxDepth=maxDepth, **kwargs)

    nTest = len(testExamples)
    if not silent:
        print('Testing with %d examples' % nTest)
    if not calcTotalError:
        xValError, badExamples = CrossValidate(tree, testExamples, appendExamples=1)
    else:
        xValError, badExamples = CrossValidate(tree, examples, appendExamples=0)
    if not silent:
        print('Validation error was %%%4.2f' % (100 * xValError))

    # keep the bookkeeping data with the tree so that it travels with it
    tree.SetBadExamples(badExamples)
    tree.SetTrainingExamples(trainExamples)
    tree.SetTestExamples(testExamples)
    tree._trainIndices = trainIndices
    return tree, xValError
206
207
def TestRun():
    """ testing code

    Builds a tree from 200 random examples using the default settings of
    _CrossValidationDriver_, pickles it to 'save.pkl' and then prints whether
    a deep copy of the tree compares equal to (and can be found in a list
    containing) the original.

    """
    from ML.DecTree import randomtest
    examples, attrs, nPossibleVals = randomtest.GenRandomExamples(nExamples=200)
    tree, frac = CrossValidationDriver(examples, attrs,
                                       nPossibleVals)

    tree.Pickle('save.pkl')

    import copy
    t2 = copy.deepcopy(tree)
    # single-argument print calls produce the same output on Python 2 and 3
    print('t1 == t2 %s' % (tree == t2,))
    treeList = [tree]
    print('t2 in [tree] %s %s' % (t2 in treeList, treeList.index(t2)))
224
# run the test code only when this module is executed as a script
if __name__ == "__main__":
    TestRun()
227