Package rdkit :: Package ML :: Package KNN :: Module CrossValidate
[hide private]
[frames] | no frames]

Source Code for Module rdkit.ML.KNN.CrossValidate

  1  # 
  2  #  Copyright (C) 2000-2008  greg Landrum 
  3  # 
  4  """ handles doing cross validation with k-nearest neighbors model 
  5   
  6  and evaluation of individual models 
  7   
  8  """ 
  9   
 10  from rdkit.ML.KNN.KNNClassificationModel import KNNClassificationModel 
 11  from rdkit.ML.KNN.KNNRegressionModel import KNNRegressionModel 
 12  from rdkit import RDRandom 
 13  from rdkit.ML.KNN import DistFunctions 
 14  from rdkit.ML.Data import SplitData 
 15   
def makeClassificationModel(numNeigh, attrs, distFunc):
    """ Builds a KNN classification model.

      **Arguments**

        - numNeigh: number of neighbors (the k in k-NN)

        - attrs: passed straight through to the model constructor

        - distFunc: distance function passed straight through to the model
          constructor

      **Returns**

        a new KNNClassificationModel
    """
    model = KNNClassificationModel(numNeigh, attrs, distFunc)
    return model
def makeRegressionModel(numNeigh, attrs, distFunc):
    """ Builds a KNN regression model.

      **Arguments**

        - numNeigh: number of neighbors (the k in k-NN)

        - attrs: passed straight through to the model constructor

        - distFunc: distance function passed straight through to the model
          constructor

      **Returns**

        a new KNNRegressionModel
    """
    model = KNNRegressionModel(numNeigh, attrs, distFunc)
    return model
20
def CrossValidate(knnMod, testExamples, appendExamples=0):
    """ Determines the error of a KNN model on a set of test examples.

      **Arguments**

        - knnMod: the model to evaluate. It must be a KNNClassificationModel
          (anything else is checked next) or a KNNRegressionModel.

        - testExamples: a sequence of examples to be used for testing. The
          last entry of each example is taken as its true value.

        - appendExamples: a toggle which is passed along to the model's
          ClassifyExample() / PredictExample() call. Presumably the model can
          use this to store the examples it sees locally; confirm against the
          model classes.

      **Returns**

        a 2-tuple consisting of:

          - for a classification model: the fraction of misclassified test
            examples, and the list of misclassified examples

          - for a regression model: the mean absolute deviation between the
            true and predicted values, and None

      **Raises**

        - ValueError: if knnMod is neither a classification nor a regression
          model

        - ZeroDivisionError: if testExamples is empty, because the error is
          averaged over len(testExamples)
    """
    nTest = len(testExamples)

    if isinstance(knnMod, KNNClassificationModel):
        badExamples = []
        for testEx in testExamples:
            trueRes = testEx[-1]
            res = knnMod.ClassifyExample(testEx, appendExamples)
            if trueRes != res:
                badExamples.append(testEx)
        # float() keeps the division from truncating under Python 2
        return float(len(badExamples)) / nTest, badExamples
    elif isinstance(knnMod, KNNRegressionModel):
        devSum = 0.0
        for testEx in testExamples:
            trueRes = testEx[-1]
            res = knnMod.PredictExample(testEx, appendExamples)
            devSum += abs(trueRes - res)
        return devSum / nTest, None
    raise ValueError("Unrecognized Model Type")
61
def CrossValidationDriver(examples, attrs, nPossibleValues, numNeigh,
                          modelBuilder=makeClassificationModel,
                          distFunc=DistFunctions.EuclideanDist,
                          holdOutFrac=0.3,
                          silent=0,
                          calcTotalError=0,
                          **kwargs):
    """ Driver function for building and evaluating a KNN model of a specified type.

      The examples are split into training and hold-out sets. A model is built
      and given both sets, and its error is computed with CrossValidate().

      **Arguments**

        - examples: the full set of examples. Presumably the last entry of each
          example is its true value, as CrossValidate() assumes.

        - attrs: passed to modelBuilder along with numNeigh and distFunc

        - nPossibleValues: not used by this function; kept for interface
          compatibility

        - numNeigh: number of neighbors for the KNN model (basically k in k-NN)

        - modelBuilder: a callable taking (numNeigh, attrs, distFunc) and
          returning the model to train. This selects a classification or a
          regression model; it defaults to makeClassificationModel.

        - distFunc: the distance function handed to modelBuilder

        - holdOutFrac: the fraction of the data which should be reserved for
          the hold-out set (used to calculate error)

        - silent: a toggle used to control how much visual noise this makes as
          it goes

        - calcTotalError: a toggle used to indicate whether the error of the
          model should be calculated using the entire data set (when true) or
          just the hold-out set (when false)

        - kwargs: only 'replacementSelection' is consulted. When it is true,
          the split is done with replacement=1 and legacy=0; otherwise with
          replacement=0 and legacy=1.

      **Returns**

        a 2-tuple (knnMod, xValError). The model's _trainIndices attribute is
        set to the indices of the training examples.
    """
    nTot = len(examples)
    if not kwargs.get('replacementSelection', 0):
        splitOpts = {'legacy': 1, 'replacement': 0}
    else:
        splitOpts = {'legacy': 0, 'replacement': 1}
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac,
                                                       silent=1, **splitOpts)
    trainExamples = [examples[x] for x in trainIndices]
    testExamples = [examples[x] for x in testIndices]

    nTrain = len(trainExamples)

    if not silent:
        print("Training with %d examples" % (nTrain))

    knnMod = modelBuilder(numNeigh, attrs, distFunc)

    knnMod.SetTrainingExamples(trainExamples)
    knnMod.SetTestExamples(testExamples)

    if not calcTotalError:
        xValError, _ = CrossValidate(knnMod, testExamples, appendExamples=1)
    else:
        # evaluate on every example (training + hold-out) without appending
        xValError, _ = CrossValidate(knnMod, examples, appendExamples=0)

    if not silent:
        # report the validation error as a percentage; previously this string
        # was built as a bare expression and silently discarded
        print('Validation error was %%%4.2f' % (100 * xValError))

    knnMod._trainIndices = trainIndices
    return knnMod, xValError
122