1
2
3
4
5
6 """ handles doing cross validation with naive bayes models
7 and evaluation of individual models
8
9 """
10 from rdkit.ML.NaiveBayes.ClassificationModel import NaiveBayesClassifier
11 from rdkit.ML.Data import SplitData
12 from rdkit.ML.FeatureSelect import CMIM
13
def makeNBClassificationModel(trainExamples, attrs, nPossibleValues, nQuantBounds,
                              mEstimateVal=-1.0,
                              useSigs=False,
                              ensemble=None, useCMIM=0,
                              **kwargs):
    """ constructs a NaiveBayesClassifier and trains it on a set of examples

      **Arguments**

        - trainExamples: the examples handed to the classifier for training
          (and to CMIM.SelectFeatures when feature selection is requested)

        - attrs: the attributes passed to the classifier; replaced by
          _ensemble_ whenever that ends up non-empty

        - nPossibleValues, nQuantBounds: forwarded unchanged to the
          NaiveBayesClassifier constructor

        - mEstimateVal, useSigs: forwarded to the NaiveBayesClassifier
          constructor as keyword arguments

        - ensemble: optional pre-selected attributes to use instead of _attrs_

        - useCMIM: when positive (and _useSigs_ is set and no _ensemble_ was
          supplied) it is passed to CMIM.SelectFeatures as its second argument
          and the result is used as the ensemble

        - kwargs: accepted but not used here

      **Returns**

        the trained NaiveBayesClassifier

    """
    # feature selection only runs for signature-based models, and only when
    # the caller has not already provided an ensemble
    selectionRequested = useCMIM > 0 and useSigs
    if selectionRequested and not ensemble:
        ensemble = CMIM.SelectFeatures(trainExamples, useCMIM, bvCol=1)

    # a non-empty ensemble takes precedence over the attributes passed in
    activeAttrs = ensemble if ensemble else attrs

    classifier = NaiveBayesClassifier(activeAttrs, nPossibleValues, nQuantBounds,
                                      mEstimateVal=mEstimateVal,
                                      useSigs=useSigs)
    classifier.SetTrainingExamples(trainExamples)
    classifier.trainModel()
    return classifier
30
32
33 nTest = len(testExamples)
34 assert nTest,'no test examples: %s'%str(testExamples)
35 badExamples = []
36 nBad = 0
37 preds = NBmodel.ClassifyExamples(testExamples, appendExamples)
38 assert len(preds) == nTest
39
40 for i in range(nTest):
41 testEg = testExamples[i]
42 trueRes = testEg[-1]
43 res = preds[i]
44
45 if (trueRes != res) :
46 badExamples.append(testEg)
47 nBad += 1
48 return float(nBad)/nTest, badExamples
49
54 nTot = len(examples)
55 if not kwargs.get('replacementSelection',0):
56 testIndices,trainIndices = SplitData.SplitIndices(nTot,holdOutFrac,
57 silent=1,legacy=1,
58 replacement=0)
59 else :
60 testIndices,trainIndices = SplitData.SplitIndices(nTot,holdOutFrac,
61 silent=1,legacy=0,
62 replacement=1)
63
64 trainExamples = [examples[x] for x in trainIndices]
65 testExamples = [examples[x] for x in testIndices]
66
67 NBmodel = modelBuilder(trainExamples, attrs, nPossibleValues, nQuantBounds,
68 mEstimateVal,**kwargs)
69
70 if not calcTotalError:
71 xValError, badExamples = CrossValidate(NBmodel, testExamples,appendExamples=1)
72 else:
73 xValError,badExamples = CrossValidate(NBmodel, examples,appendExamples=0)
74
75 if not silent:
76 print 'Validation error was %%%4.2f'%(100*xValError)
77 NBmodel._trainIndices = trainIndices
78 return NBmodel, xValError
79