Package ML :: Package NaiveBayes :: Module CrossValidate
[hide private]
[frames | no frames]

Source Code for Module ML.NaiveBayes.CrossValidate

 1  # $Id: CrossValidate.py 2 2006-05-06 22:54:39Z glandrum $
 
 2  #
 
 3  #  Copyright (C) 2004-2005 Rational Discovery LLC.
 
 4  #   All Rights Reserved
 
 5  #
 
 6  """ handles doing cross validation with naive bayes models
 
 7  and evaluation of individual models
 
 8  
 
 9  """ 
10  from ML.NaiveBayes.ClassificationModel import NaiveBayesClassifier 
11  from ML.Data import SplitData 
12  from ML.FeatureSelect import CMIM 
13  
 
def makeNBClassificationModel(trainExamples, attrs, nPossibleValues, nQuantBounds,
                              mEstimateVal=-1.0,
                              useSigs=False,
                              ensemble=None, useCMIM=0,
                              **kwargs):
    """ Builds a naive bayes classification model and trains it.

      **Arguments**

        - trainExamples: the examples the model is trained on

        - attrs: the attributes the model should use; replaced by the
          ensemble when one is supplied or selected via CMIM

        - nPossibleValues, nQuantBounds: passed unchanged to the
          NaiveBayesClassifier constructor

        - mEstimateVal: (optional) passed to NaiveBayesClassifier

        - useSigs: (optional) passed to NaiveBayesClassifier; CMIM feature
          selection is only attempted when this is true

        - ensemble: (optional) an explicit attribute list to use instead of
          attrs

        - useCMIM: (optional) when positive (and useSigs is set and no ensemble
          was provided) this value is handed to CMIM.SelectFeatures and the
          result is used as the ensemble

        - kwargs: accepted so callers can forward extra options; ignored here

      **Returns**

        the trained NaiveBayesClassifier

    """
    # only run CMIM selection when the caller did not already pick features
    selectWithCMIM = useCMIM > 0 and useSigs and not ensemble
    if selectWithCMIM:
        ensemble = CMIM.SelectFeatures(trainExamples, useCMIM, bvCol=1)

    # a (non-empty) ensemble overrides the attribute list we were given
    modelAttrs = ensemble if ensemble else attrs

    classifier = NaiveBayesClassifier(modelAttrs, nPossibleValues, nQuantBounds,
                                      mEstimateVal=mEstimateVal, useSigs=useSigs)
    classifier.SetTrainingExamples(trainExamples)
    classifier.trainModel()
    return classifier
30
def CrossValidate(NBmodel, testExamples, appendExamples=0):
    """ Determines the classification error of a model on a set of examples.

      **Arguments**

        - NBmodel: the model to evaluate; it must provide a
          ClassifyExamples(examples, appendExamples) method that returns one
          prediction per example

        - testExamples: the examples to classify; the last element of each
          example is taken as its true class

        - appendExamples: (optional) passed through to NBmodel.ClassifyExamples

      **Returns**

        a 2-tuple:

          1) the fraction of examples that were misclassified (a float)

          2) a list of the misclassified examples, in input order

      **Raises**

        AssertionError if testExamples is empty or if the model does not
        return exactly one prediction per example.  These are explicit raises
        rather than `assert` statements so the checks still run under
        ``python -O``.

    """
    nTest = len(testExamples)
    if not nTest:
        # explicit raise (not `assert`): under -O an empty set would otherwise
        # fall through to a ZeroDivisionError below
        raise AssertionError('no test examples: %s' % str(testExamples))

    preds = NBmodel.ClassifyExamples(testExamples, appendExamples)
    if len(preds) != nTest:
        raise AssertionError('expected %d predictions, got %d' % (nTest, len(preds)))

    # an example is "bad" when its true class (last element) differs from the prediction
    badExamples = [example for example, pred in zip(testExamples, preds)
                   if example[-1] != pred]
    return float(len(badExamples)) / nTest, badExamples
49
def CrossValidationDriver(examples, attrs, nPossibleValues, nQuantBounds,
                          mEstimateVal=0.0,
                          holdOutFrac=0.3, modelBuilder=makeNBClassificationModel,
                          silent=0, calcTotalError=0, **kwargs):
    """ Splits the examples, trains a model on one part and measures its error.

      **Arguments**

        - examples: the full list of examples

        - attrs, nPossibleValues, nQuantBounds, mEstimateVal: passed
          positionally to modelBuilder

        - holdOutFrac: (optional) passed to SplitData.SplitIndices; presumably
          the fraction of examples held out for testing

        - modelBuilder: (optional) callable invoked as
          modelBuilder(trainExamples, attrs, nPossibleValues, nQuantBounds,
          mEstimateVal, **kwargs) to build the model

        - silent: (optional) when false, the validation error is printed

        - calcTotalError: (optional) when true, the error is computed over all
          examples (training examples included) instead of only the held-out
          ones

        - kwargs: forwarded to modelBuilder.  The key 'replacementSelection' is
          also read here: when true the split uses legacy=0, replacement=1,
          otherwise legacy=1, replacement=0

      **Returns**

        a 2-tuple: (the trained model, its error as returned by CrossValidate).
        The model also gets a _trainIndices attribute holding the indices of
        the training examples.

    """
    if kwargs.get('replacementSelection', 0):
        splitOpts = dict(legacy=0, replacement=1)
    else:
        splitOpts = dict(legacy=1, replacement=0)
    testIndices, trainIndices = SplitData.SplitIndices(len(examples), holdOutFrac,
                                                       silent=1, **splitOpts)

    trainExamples = [examples[idx] for idx in trainIndices]
    testExamples = [examples[idx] for idx in testIndices]

    NBmodel = modelBuilder(trainExamples, attrs, nPossibleValues, nQuantBounds,
                           mEstimateVal, **kwargs)

    if calcTotalError:
        # score against every example, training set included
        xValError, _ = CrossValidate(NBmodel, examples, appendExamples=0)
    else:
        # score only against the held-out examples
        xValError, _ = CrossValidate(NBmodel, testExamples, appendExamples=1)

    if not silent:
        # single parenthesized argument: identical output as a Py2 print statement
        print('Validation error was %%%4.2f' % (100 * xValError))
    NBmodel._trainIndices = trainIndices
    return NBmodel, xValError
79