1
2
3
""" Handles cross validation with neural nets.

This name is, perhaps, a little misleading: for the purposes of this module,
"cross validation" simply means evaluating the accuracy of a trained net.

"""
10
11 from ML.Neural import Network,Trainers
12 from ML.Data import SplitData
13 import math
14
15 import RDRandom
16
def CrossValidate(net, testExamples, tolerance, appendExamples=0):
    """ Determines the classification error of a net on the testExamples

    **Arguments**

      - net: a neural net (or anything supporting a _ClassifyExample()_ method)

      - testExamples: a list of examples to be used for testing; the last
        entry of each example is taken as its true (target) value

      - tolerance: an example counts as misclassified when the absolute
        difference between the net's output and the true value exceeds this

      - appendExamples: a toggle which is ignored, it's just here to maintain
        the same API as the decision tree code.

    **Returns**

      a 2-tuple consisting of:

        1) the fraction (0.0 - 1.0) of the examples misclassified by the net;
           0.0 if _testExamples_ is empty

        2) a list of misclassified examples

    **Note**
      At the moment, this is specific to nets with only one output
    """
    badExamples = []
    for testEx in testExamples:
        trueRes = testEx[-1]
        res = net.ClassifyExample(testEx)
        # any deviation beyond the tolerance counts as a misclassification
        if math.fabs(trueRes - res) > tolerance:
            badExamples.append(testEx)

    nTest = len(testExamples)
    if not nTest:
        # no examples to test on: report no error rather than dividing by zero
        return 0.0, badExamples
    return float(len(badExamples)) / nTest, badExamples
51
def CrossValidationDriver(examples, attrs=None, nPossibleVals=None, holdOutFrac=.3,
                          silent=0, tolerance=0.3, calcTotalError=0, hiddenSizes=None,
                          **kwargs):
    """ Trains a net on part of the examples and measures its error

    **Arguments**

      - examples: the full set of examples; the last entry of each example
        is its target value, the remaining entries are the inputs to the net

      - attrs: a list of attributes to consider in the tree building
        *This argument is ignored*

      - nPossibleVals: a list of the number of possible values each variable can adopt
        *This argument is ignored*

      - holdOutFrac: the fraction of the data which should be reserved for the hold-out set
        (used to calculate the error)

      - silent: a toggle used to control how much visual noise this makes as it goes.

      - tolerance: used both as the error tolerance (_errTol_) when training
        the net and as the tolerance used to decide whether an example is
        misclassified when the error is calculated

      - calcTotalError: if this is true the entire data set (not just the
        hold-out set) is used to calculate the accuracy of the net

      - hiddenSizes: a sequence containing the size(s) of the hidden layers in the network.
        if _hiddenSizes_ is None, one hidden layer containing the same number of nodes
        as the input layer will be used

      - replacementSelection (keyword argument): if true, the data is split by
        _SplitData.SplitIndices()_ with replacement=1 (and legacy=0) instead
        of replacement=0 (and legacy=1)

    **Returns**

      a 2-tuple containing:

        1) the net; the indices of the training examples are stored in its
           _trainIndices_ attribute

        2) the cross-validation error of the net (the fraction of
           misclassified examples, as computed by _CrossValidate()_)

    **Raises**

      ValueError: if _examples_ is empty

    **Note**
      At the moment, this is specific to nets with only one output

    """
    nTot = len(examples)
    if not nTot:
        # the size of the input layer is taken from the first example below,
        # so there is nothing sensible to do with an empty data set
        raise ValueError('CrossValidationDriver requires at least one example')

    if not kwargs.get('replacementSelection', 0):
        testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac,
                                                           silent=1, legacy=1,
                                                           replacement=0)
    else:
        testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac,
                                                           silent=1, legacy=0,
                                                           replacement=1)
    trainExamples = [examples[x] for x in trainIndices]
    testExamples = [examples[x] for x in testIndices]

    if not silent:
        print('Training with %d examples' % (len(trainExamples)))

    # every entry but the last (the target value) is an input to the net
    nInput = len(examples[0]) - 1
    nOutput = 1
    if hiddenSizes is None:
        # default: a single hidden layer as wide as the input layer
        netSize = [nInput, nInput, nOutput]
    else:
        # list() so tuples (or other sequences) of layer sizes work too
        netSize = [nInput] + list(hiddenSizes) + [nOutput]
    net = Network.Network(netSize)
    t = Trainers.BackProp()
    t.TrainOnLine(trainExamples, net, errTol=tolerance, useAvgErr=0, silent=silent)

    if not silent:
        print('Testing with %d examples' % len(testExamples))
    # evaluate on the hold-out set, or on every example if calcTotalError is set
    if not calcTotalError:
        evalExamples = testExamples
    else:
        evalExamples = examples
    xValError, _ = CrossValidate(net, evalExamples, tolerance)
    if not silent:
        print('Validation error was %%%4.2f' % (100 * xValError))
    net._trainIndices = trainIndices
    return net, xValError