Package ML :: Package KNN :: Module CrossValidate
[hide private]
[frames] | no frames]

Source Code for Module ML.KNN.CrossValidate

  1  # 
  2  #  Copyright (C) 2000  greg Landrum 
  3  # 
  4  """ handles doing cross validation with k-nearest neighbors model 
  5   
  6  and evaluation of individual models 
  7   
  8  """ 
  9   
 10  from Numeric import * 
 11  from ML.KNN.KNNClassificationModel import KNNClassificationModel 
 12  from ML.KNN.KNNRegressionModel import KNNRegressionModel 
 13  import RDRandom 
 14  from ML.KNN import DistFunctions 
 15  from ML.Data import SplitData 
 16   
def makeClassificationModel(numNeigh, attrs, distFunc):
    """ Builds a KNN classification model

    This is the default _modelBuilder_ used by CrossValidationDriver.

    **Arguments**

      - numNeigh: the number of neighbors to use (k in k-NN)

      - attrs: passed unchanged to the KNNClassificationModel constructor

      - distFunc: passed unchanged to the KNNClassificationModel constructor

    **Returns**

      a freshly constructed KNNClassificationModel

    """
    model = KNNClassificationModel(numNeigh, attrs, distFunc)
    return model
def makeRegressionModel(numNeigh, attrs, distFunc):
    """ Builds a KNN regression model

    Suitable as the _modelBuilder_ argument of CrossValidationDriver.

    **Arguments**

      - numNeigh: the number of neighbors to use (k in k-NN)

      - attrs: passed unchanged to the KNNRegressionModel constructor

      - distFunc: passed unchanged to the KNNRegressionModel constructor

    **Returns**

      a freshly constructed KNNRegressionModel

    """
    model = KNNRegressionModel(numNeigh, attrs, distFunc)
    return model
21
def CrossValidate(knnMod, testExamples, appendExamples=0):
    """ Determines the error of a KNN model on a set of test examples

    **Arguments**

      - knnMod: a KNNClassificationModel (evaluated via its
        _ClassifyExample()_ method) or a KNNRegressionModel (evaluated via
        its _PredictExample()_ method)

      - testExamples: a sequence of examples to be used for testing; the
        last entry of each example is taken as its true value

      - appendExamples: a toggle which is passed along to the model's
        _ClassifyExample()_ / _PredictExample()_ call. Presumably, as with
        the decision-tree code, models can use this to store the examples
        they see locally -- confirm against the model classes.

    **Returns**

      a 2-tuple consisting of:

        - the error: for classification models, the fraction of test
          examples whose predicted class differs from the true value; for
          regression models, the mean absolute deviation between the
          prediction and the true value

        - for classification models, the list of misclassified examples
          (in input order); for regression models, None

    **Raises**

      - ValueError: if _knnMod_ is neither a KNNClassificationModel nor a
        KNNRegressionModel, or if _testExamples_ is empty

    """
    nTest = len(testExamples)

    # classification is checked first, so a model that is somehow both
    # kinds is treated as a classifier
    isClassifier = isinstance(knnMod, KNNClassificationModel)
    if not isClassifier and not isinstance(knnMod, KNNRegressionModel):
        raise ValueError("Unrecognized Model Type")

    # both error measures below are averages over the test set; without
    # this guard an empty set fails with an opaque ZeroDivisionError
    if nTest == 0:
        raise ValueError("CrossValidate requires at least one test example")

    if isClassifier:
        badExamples = []
        for testEx in testExamples:
            # read the true value before handing the example to the model
            trueRes = testEx[-1]
            res = knnMod.ClassifyExample(testEx, appendExamples)
            if trueRes != res:
                badExamples.append(testEx)
        return float(len(badExamples)) / nTest, badExamples

    devSum = 0.0
    for testEx in testExamples:
        trueRes = testEx[-1]
        res = knnMod.PredictExample(testEx, appendExamples)
        devSum += abs(trueRes - res)
    return devSum / nTest, None
62
def CrossValidationDriver(examples, attrs, nPossibleValues, numNeigh,
                          modelBuilder=makeClassificationModel,
                          distFunc=DistFunctions.EuclideanDist,
                          holdOutFrac=0.3,
                          silent=0,
                          calcTotalError=0,
                          **kwargs):
    """ Driver function for building and evaluating a KNN model

    The examples are split into training and hold-out sets. A model is
    built with _modelBuilder_, given both sets, and its error is
    computed with CrossValidate.

    **Arguments**

      - examples: the full set of examples (the last entry of each example
        is its true value)

      - attrs: passed unchanged to _modelBuilder_

      - nPossibleValues: accepted for interface compatibility; not used by
        this function

      - numNeigh: number of neighbors for the KNN model (basically k in k-NN)

      - modelBuilder: a callable taking (numNeigh, attrs, distFunc) and
        returning the KNN model to be trained, e.g.
        makeClassificationModel (the default) or makeRegressionModel

      - distFunc: the distance function passed to _modelBuilder_

      - holdOutFrac: the fraction of the data which should be reserved for
        the hold-out set (used to calculate error)

      - silent: a toggle used to control how much visual noise this makes
        as it goes

      - calcTotalError: a toggle used to indicate whether the error of the
        model should be calculated using the entire data set (when true) or
        just the hold-out set (when false)

      - replacementSelection (keyword): if true, the split is done with
        SplitData.SplitIndices(..., legacy=0, replacement=1); otherwise
        with legacy=1, replacement=0

    **Returns**

      a 2-tuple consisting of the model and its error as returned by
      CrossValidate. The indices of the training examples are stored on the
      model as _trainIndices_.

    """
    nTot = len(examples)

    # the two split modes differ only in these flags
    if kwargs.get('replacementSelection', 0):
        legacy, replacement = 0, 1
    else:
        legacy, replacement = 1, 0
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac,
                                                       silent=1,
                                                       legacy=legacy,
                                                       replacement=replacement)
    trainExamples = [examples[x] for x in trainIndices]
    testExamples = [examples[x] for x in testIndices]

    if not silent:
        print("Training with %d examples" % (len(trainExamples)))

    knnMod = modelBuilder(numNeigh, attrs, distFunc)
    knnMod.SetTrainingExamples(trainExamples)
    knnMod.SetTestExamples(testExamples)

    # only the error is used here; the misclassified examples are discarded
    if not calcTotalError:
        xValError, _ = CrossValidate(knnMod, testExamples, appendExamples=1)
    else:
        xValError, _ = CrossValidate(knnMod, examples, appendExamples=0)

    if not silent:
        # the message must be printed, not merely formatted;
        # '%%' emits a literal percent sign ahead of the value
        print('Validation error was %%%4.2f' % (100 * xValError))

    knnMod._trainIndices = trainIndices
    return knnMod, xValError
123