1
2
3
4 """ Handles cross validation with k-nearest-neighbors (KNN) models
5
6 and the evaluation of individual models.
7
8 """
9
10 from rdkit.ML.KNN.KNNClassificationModel import KNNClassificationModel
11 from rdkit.ML.KNN.KNNRegressionModel import KNNRegressionModel
12 from rdkit import RDRandom
13 from rdkit.ML.KNN import DistFunctions
14 from rdkit.ML.Data import SplitData
15
20
22 """
23 Determines the classification error for the testExamples
24
25 **Arguments**
26
27 - tree: a decision tree (or anything supporting a _ClassifyExample()_ method)
28
29 - testExamples: a list of examples to be used for testing
30
31 - appendExamples: a toggle which is passed along to the tree as it does
32 the classification. The trees can use this to store the examples they
33 classify locally.
34
35 **Returns**
36
37 a 2-tuple consisting of:
38 """
39 nTest = len(testExamples)
40
41 if isinstance(knnMod,KNNClassificationModel):
42 badExamples = []
43 nBad = 0
44 for i in xrange(nTest):
45 testEx = testExamples[i]
46 trueRes = testEx[-1]
47 res = knnMod.ClassifyExample(testEx, appendExamples)
48 if (trueRes != res) :
49 badExamples.append(testEx)
50 nBad += 1
51 return float(nBad)/nTest, badExamples
52 elif isinstance(knnMod,KNNRegressionModel):
53 devSum=0.0
54 for i in xrange(nTest):
55 testEx = testExamples[i]
56 trueRes = testEx[-1]
57 res = knnMod.PredictExample(testEx, appendExamples)
58 devSum += abs(trueRes-res)
59 return devSum/nTest,None
60 raise ValueError,"Unrecognized Model Type"
61
69 """ Driver function for building a KNN model of a specified type
70
71 **Arguments**
72
73 - examples: the full set of examples
74
75 - numNeigh: number of neighbors for the KNN model (basically k in k-NN)
76
77 - knnModel: the type of KNN model (a classification vs regression model)
78
79 - holdOutFrac: the fraction of the data which should be reserved for the hold-out set
80 (used to calculate error)
81
82 - silent: a toggle used to control how much visual noise this makes as it goes
83
84 - calcTotalError: a toggle used to indicate whether the classification error
85 of the tree should be calculated using the entire data set (when true) or just
86 the training hold out set (when false)
87 """
88
89 nTot = len(examples)
90 if not kwargs.get('replacementSelection',0):
91 testIndices,trainIndices = SplitData.SplitIndices(nTot,holdOutFrac,
92 silent=1,legacy=1,
93 replacement=0)
94 else:
95 testIndices,trainIndices = SplitData.SplitIndices(nTot,holdOutFrac,
96 silent=1,legacy=0,
97 replacement=1)
98 trainExamples = [examples[x] for x in trainIndices]
99 testExamples = [examples[x] for x in testIndices]
100
101
102 nTrain = len(trainExamples)
103
104 if not silent:
105 print "Training with %d examples"%(nTrain)
106
107 knnMod = modelBuilder(numNeigh, attrs, distFunc)
108
109 knnMod.SetTrainingExamples(trainExamples)
110 knnMod.SetTestExamples(testExamples)
111
112 if not calcTotalError:
113 xValError,badExamples = CrossValidate(knnMod, testExamples,appendExamples=1)
114 else:
115 xValError,badExamples = CrossValidate(knnMod, examples,appendExamples=0)
116
117 if not silent :
118 'Validation error was %%%4.2f'%(100*xValError)
119
120 knnMod._trainIndices = trainIndices
121 return knnMod, xValError
122