Package ML :: Package Neural :: Module CrossValidate
[hide private]
[frames] | [no frames]

Source Code for Module ML.Neural.CrossValidate

  1  # 
  2  #  Copyright (C) 2000  greg Landrum 
  3  # 
  4  """ handles doing cross validation with neural nets 
  5   
  6  This is, perhaps, a little misleading.  For the purposes of this module, 
  7  cross validation == evaluating the accuracy of a net. 
  8   
  9  """ 
 10   
 11  from ML.Neural import Network,Trainers 
 12  from ML.Data import SplitData 
 13  import math 
 14   
 15  import RDRandom 
 16   
def CrossValidate(net, testExamples, tolerance, appendExamples=0):
    """ Determines the classification error of a net on the testExamples

      **Arguments**

        - net: a neural net (or anything supporting a _ClassifyExample()_
          method that takes a single example and returns a number)

        - testExamples: a list of examples to be used for testing.  The last
          entry of each example is taken to be its true (target) value.

        - tolerance: an example counts as misclassified when the absolute
          difference between the net's output and the true value is greater
          than this

        - appendExamples: a toggle which is ignored, it's just here to maintain
          the same API as the decision tree code.

      **Returns**

        a 2-tuple consisting of:

          1) the fraction (0.0 - 1.0, *not* a percentage) of misclassified
             examples; 0.0 when _testExamples_ is empty

          2) a list of misclassified examples, in their original order

      **Note**
        At the moment, this is specific to nets with only one output

    """
    nTest = len(testExamples)
    if not nTest:
        # nothing to evaluate: report no error instead of dividing by zero
        return 0.0, []
    badExamples = [ex for ex in testExamples
                   if math.fabs(ex[-1] - net.ClassifyExample(ex)) > tolerance]
    return float(len(badExamples)) / nTest, badExamples
51
def CrossValidationDriver(examples, attrs=None, nPossibleVals=None, holdOutFrac=.3,
                          silent=0, tolerance=0.3, calcTotalError=0, hiddenSizes=None,
                          **kwargs):
    """ Trains a single-output net on part of _examples_ and measures its error

      **Arguments**

        - examples: the full set of examples.  The last entry of each example
          is its target value; the remaining entries are the net's inputs.

        - attrs: a list of attributes to consider in the tree building
          *This argument is ignored* (kept for API compatibility with the
          decision tree code)

        - nPossibleVals: a list of the number of possible values each variable
          can adopt
          *This argument is ignored*

        - holdOutFrac: the fraction of the data which should be reserved for
          the hold-out set (used to calculate the error)

        - silent: a toggle used to control how much visual noise this makes as
          it goes.

        - tolerance: used both as the error tolerance (_errTol_) given to the
          back-propagation trainer and as the misclassification threshold
          passed to CrossValidate

        - calcTotalError: if this is true the entire data set, rather than
          just the hold-out set, is used to calculate the accuracy of the net

        - hiddenSizes: a sequence (list or tuple) containing the size(s) of the
          hidden layers in the network.  If _hiddenSizes_ is None, one hidden
          layer containing the same number of nodes as the input layer will
          be used

        - kwargs: if _replacementSelection_ is true, SplitData.SplitIndices is
          called with replacement=1 (and legacy=0) instead of replacement=0
          (and legacy=1); presumably this draws the split with replacement.
          Other keywords are ignored.

      **Returns**

        a 2-tuple containing:

          1) the net; the indices of its training examples are stored in
             _net._trainIndices_

          2) the cross-validation error of the net, as a fraction (0.0 - 1.0)

      **Note**
        At the moment, this is specific to nets with only one output

    """
    nTot = len(examples)
    if kwargs.get('replacementSelection', 0):
        useReplacement, legacy = 1, 0
    else:
        useReplacement, legacy = 0, 1
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac, silent=1,
                                                       legacy=legacy,
                                                       replacement=useReplacement)
    trainExamples = [examples[x] for x in trainIndices]
    testExamples = [examples[x] for x in testIndices]

    nTrain = len(trainExamples)
    if not silent:
        print('Training with %d examples' % nTrain)

    # the last entry of each example is the target, so it is not an input
    nInput = len(examples[0]) - 1
    nOutput = 1
    if hiddenSizes is None:
        # default: a single hidden layer the same size as the input layer
        netSize = [nInput, nInput, nOutput]
    else:
        # list() so tuples (or other sequences) of layer sizes work too
        netSize = [nInput] + list(hiddenSizes) + [nOutput]
    net = Network.Network(netSize)
    trainer = Trainers.BackProp()
    trainer.TrainOnLine(trainExamples, net, errTol=tolerance, useAvgErr=0,
                        silent=silent)

    # evaluate on the hold-out set, or on everything if requested
    if calcTotalError:
        evalExamples = examples
    else:
        evalExamples = testExamples
    if not silent:
        print('Testing with %d examples' % len(evalExamples))
    xValError = CrossValidate(net, evalExamples, tolerance)[0]
    if not silent:
        print('Validation error was %%%4.2f' % (100 * xValError))
    net._trainIndices = trainIndices
    return net, xValError
131