1
2
3
""" Handles cross validation of decision trees.

This is, perhaps, a little misleading: for the purposes of this module,
"cross validation" simply means evaluating the accuracy of a tree on a
set of examples.

"""
11 from rdkit.ML.DecTree import ID3
12 from rdkit.ML.Data import SplitData
13 import numpy
14 from rdkit import RDRandom
15
def ChooseOptimalRoot(examples, trainExamples, testExamples, attrs,
                      nPossibleVals, treeBuilder, nQuantBounds=None,
                      **kwargs):
    """ loops through possible tree roots and chooses the one which produces the best tree

    **Arguments**

      - examples: the full set of examples

      - trainExamples: the training examples

      - testExamples: the testing examples (currently unused)

      - attrs: a list of attributes to consider in the tree building.
        It is copied, never modified in place.

      - nPossibleVals: a list of the number of possible values each variable can adopt

      - treeBuilder: the function to be used to actually build the tree

      - nQuantBounds: an optional list. If present (non-empty), it's assumed that the
        builder algorithm takes this argument as well (for building QuantTrees).
        Attributes whose entry in this list is -1 are dropped from consideration.

      - any additional keyword arguments are passed to _treeBuilder_.  They are
        applied after _initialVar_ is set, so an explicit _initialVar_ keyword
        overrides the candidate root.

    **Returns**

      The best tree found.  None is returned if no candidate produced a tree,
      including when fewer than two attributes remain to be considered.

    **Notes**

      1) Trees are built using _trainExamples_

      2) Testing of each tree (to determine which is best) is done using _CrossValidate_ and
         the entire set of data (i.e. all of _examples_)

      3) _testExamples_ is not used at all, which immediately raises the question of
         why it's even being passed in

    """
    attrs = list(attrs)
    if nQuantBounds:
        # Variables flagged with -1 bounds cannot be used by the quant builder.
        for idx, nBounds in enumerate(nQuantBounds):
            if nBounds == -1 and idx in attrs:
                attrs.remove(idx)
    nAttrs = len(attrs)
    if not nAttrs:
        # Nothing to try; previously this crashed with an IndexError below.
        return None

    trees = [None] * nAttrs
    errs = [0] * nAttrs
    # NOTE(review): the candidate at position 0 of attrs is never tried (its
    # error is pinned to 1e6 and the loop starts at 1).  This mirrors the
    # historical behavior; confirm it is intentional (e.g. an ID column).
    errs[0] = 1e6

    for i in range(1, nAttrs):
        argD = {'initialVar': attrs[i]}
        argD.update(kwargs)
        builderArgs = [trainExamples, attrs, nPossibleVals]
        if nQuantBounds:
            builderArgs.append(nQuantBounds)
        tree = treeBuilder(*builderArgs, **argD)
        trees[i] = tree
        if tree:
            # Score against the full data set, without storing examples on the tree.
            errs[i], _ = CrossValidate(tree, examples, appendExamples=0)
        else:
            errs[i] = 1e6
    best = numpy.argmin(errs)

    return trees[best]
79
def CrossValidate(tree, testExamples, appendExamples=0):
    """ Determines the classification error for the testExamples

    **Arguments**

      - tree: a decision tree (or anything supporting a _ClassifyExample()_ method)

      - testExamples: a list of examples to be used for testing.  The last
        entry of each example is taken as its true classification.

      - appendExamples: a toggle which is passed along to the tree as it does
        the classification. The trees can use this to store the examples they
        classify locally.

    **Returns**

      a 2-tuple consisting of:

        1) the fractional error of the tree (misclassified / total, 0.0 to 1.0)

        2) a list of misclassified examples

    **Raises**

      ZeroDivisionError if _testExamples_ is empty

    """
    nTest = len(testExamples)
    badExamples = []
    for testEx in testExamples:
        trueRes = testEx[-1]
        res = tree.ClassifyExample(testEx, appendExamples)
        # numpy.any works for scalar labels (a plain Python bool has no .any()
        # method) as well as for array-valued labels compared element-wise.
        if numpy.any(trueRes != res):
            badExamples.append(testEx)

    return float(len(badExamples)) / nTest, badExamples
115
def CrossValidationDriver(examples, attrs, nPossibleVals, holdOutFrac=.3, silent=0,
                          calcTotalError=0, treeBuilder=ID3.ID3Boot, lessGreedy=0,
                          startAt=None,
                          nQuantBounds=None,
                          maxDepth=-1,
                          **kwargs):
    """ Driver function for building trees and doing cross validation

    **Arguments**

      - examples: the full set of examples

      - attrs: a list of attributes to consider in the tree building

      - nPossibleVals: a list of the number of possible values each variable can adopt

      - holdOutFrac: the fraction of the data which should be reserved for the hold-out set
         (used to calculate the error)

      - silent: a toggle used to control how much visual noise this makes as it goes.

      - calcTotalError: a toggle used to indicate whether the classification error
        of the tree should be calculated using the entire data set (when true) or just
        the hold-out (test) set (when false)

      - treeBuilder: the function to call to build the tree

      - lessGreedy: toggles use of the less greedy tree growth algorithm (see
        _ChooseOptimalRoot_).

      - startAt: forces the tree to be rooted at this descriptor
        (ignored when _lessGreedy_ is set)

      - nQuantBounds: an optional list. If present (non-empty), it's assumed that the
        builder algorithm takes this argument as well (for building QuantTrees)

      - maxDepth: an optional integer. If present, it's assumed that the builder
         algorithm takes this argument as well

      - any additional keyword arguments are forwarded to the tree builder.
        A true _replacementSelection_ keyword makes the train/test split
        sample with replacement.

    **Returns**

       a 2-tuple containing:

         1) the tree.  Its bad, training and test examples are set on it, and
            the training indices are stored as _trainIndices_.

         2) the cross-validation error of the tree, as a fraction (0.0 to 1.0)

    """
    nTot = len(examples)
    # Selection with replacement uses the non-legacy splitter.
    useReplacement = bool(kwargs.get('replacementSelection', 0))
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac,
                                                       silent=1,
                                                       legacy=0 if useReplacement else 1,
                                                       replacement=1 if useReplacement else 0)
    trainExamples = [examples[x] for x in trainIndices]
    testExamples = [examples[x] for x in testIndices]

    nTrain = len(trainExamples)
    if not silent:
        print('Training with %d examples' % nTrain)

    if lessGreedy:
        tree = ChooseOptimalRoot(examples, trainExamples, testExamples,
                                 attrs, nPossibleVals, treeBuilder, nQuantBounds,
                                 maxDepth=maxDepth, **kwargs)
    else:
        builderArgs = [trainExamples, attrs, nPossibleVals]
        if nQuantBounds:
            # Quant-tree builders take the bounds list as an extra positional arg.
            builderArgs.append(nQuantBounds)
        tree = treeBuilder(*builderArgs, initialVar=startAt, maxDepth=maxDepth, **kwargs)

    nTest = len(testExamples)
    if not silent:
        print('Testing with %d examples' % nTest)
    if calcTotalError:
        xValError, badExamples = CrossValidate(tree, examples, appendExamples=0)
    else:
        xValError, badExamples = CrossValidate(tree, testExamples, appendExamples=1)
    if not silent:
        print('Validation error was %%%4.2f' % (100 * xValError))
    tree.SetBadExamples(badExamples)
    tree.SetTrainingExamples(trainExamples)
    tree.SetTestExamples(testExamples)
    tree._trainIndices = trainIndices
    return tree, xValError
205
206
def TestRun():
    """ testing code

    Builds a tree from 200 random examples via _CrossValidationDriver_ and
    pickles it to 'save.pkl' in the current directory.  It then prints
    whether a deep copy of the tree compares equal to the original and is
    found in a list containing the original.

    """
    import copy
    from rdkit.ML.DecTree import randomtest
    examples, attrs, nPossibleVals = randomtest.GenRandomExamples(nExamples=200)
    tree, _frac = CrossValidationDriver(examples, attrs,
                                        nPossibleVals)

    tree.Pickle('save.pkl')

    t2 = copy.deepcopy(tree)
    # Output matches the original Python 2 print statements.
    print('t1 == t2 %s' % (tree == t2))
    trees = [tree]
    print('t2 in [tree] %s %s' % (t2 in trees, trees.index(t2)))
223
if '__main__' == __name__:
    # Executed only when this module is run as a script: run the self-test.
    TestRun()
226