Package ML :: Package DecTree :: Module BuildSigTree
[hide private]
[frames] | [no frames]

Source Code for Module ML.DecTree.BuildSigTree

  1  # $Id: BuildSigTree.py 2 2006-05-06 22:54:39Z glandrum $ 
  2  # 
  3  #  Copyright (C) 2003-2005  Rational Discovery LLC 
  4  #    All Rights Reserved 
  5  # 
  6  """  
  7   
  8  """ 
  9   
 10  from Numeric import * 
 11  import RandomArray 
 12  from ML.DecTree import SigTree 
 13  from ML import InfoTheory 
 14  from ML.FeatureSelect import CMIM 
 15  from DataStructs.VectCollection import VectCollection 
 16  import copy 
 17  import random 
def _GenerateRandomEnsemble(nToInclude,nBits):
  """ Generates a random subset of a group of indices

    **Arguments**

      - nToInclude: the size of the desired set

      - nBits: the number of candidate indices; the subset is drawn
        from 0 .. nBits-1 (nBits itself is never included)

    **Returns**

      a list of nToInclude distinct indices, in random order

    **Raises**

      ValueError (from random.sample) if nToInclude is negative or
      larger than nBits

  """
  # random.sample() (Python 2.3+) draws without replacement, which is all
  # we need here. range() is used instead of xrange() so the function also
  # runs under Python 3 (where range is lazy).
  return random.sample(range(nBits),nToInclude)
36
def _SelectAvailBits(ensemble,nRandom,nBits):
  """ Chooses which bits the ranker may consider at a node.

    **Arguments**

      - ensemble: a sequence of allowed bit ids, or None/empty to allow
        every bit

      - nRandom: if nonzero, the number of bits to pick at random (from the
        ensemble when one is given, otherwise from all nBits bits)

      - nBits: the number of bits in the fingerprints

    **Returns**

      a list of bit ids to mask the ranker with, or None when no mask
      should be applied

  """
  if nRandom:
    if ensemble:
      if len(ensemble)>nRandom:
        picks = _GenerateRandomEnsemble(nRandom,len(ensemble))
        # map the picked *positions* back onto the ensemble's bit ids
        return [ensemble[pick] for pick in picks]
      # not enough candidates to subsample: use every ensemble bit.
      # (This must be the bit ids themselves, not range(len(ensemble)).)
      return list(ensemble)
    return _GenerateRandomEnsemble(nRandom,nBits)
  if ensemble:
    # no random selection: the ensemble alone restricts the candidate bits
    return list(ensemble)
  return None


def BuildSigTree(examples,nPossibleRes,ensemble=None,random=0,
                 metric=InfoTheory.InfoType.BIASENTROPY,
                 biasList=[1],
                 depth=0,maxDepth=-1,
                 useCMIM=0,allowCollections=False,
                 verbose=0,**kwargs):
  """
    **Arguments**

      - examples: the examples to be classified.  Each example
        should be a sequence at least three entries long, with
        entry 0 being a label, entry 1 a BitVector and entry -1
        an activity value

      - nPossibleRes: the number of result codes possible

      - ensemble: (optional) if this argument is provided, it
        should be a sequence which is used to limit the bits
        which are actually considered as potential descriptors.
        It is applied whether or not _random_ is set.
        The default is None (use all bits).

      - random: (optional) If this argument is nonzero, it
        specifies the number of bits to be randomly selected
        for consideration at this node (i.e. this toggles the
        growth of Random Trees).  When an ensemble is also given,
        the bits are drawn from the ensemble.
        The default is 0 (no random descriptor selection)

      - metric: (optional) This is an _InfoTheory.InfoType_ and
        sets the metric used to rank the bits.
        The default is _InfoTheory.InfoType.BIASENTROPY_

      - biasList: (optional) If provided, this provides a bias
        list for the bit ranker.
        See the _InfoTheory.InfoBitRanker_ docs for an explanation
        of bias.
        The default value is [1], which biases towards actives.

      - maxDepth: (optional) the maximum depth to which the tree
        will be grown.  Nodes whose depth exceeds maxDepth become
        majority-vote leaves.
        The default is -1 (no depth limit).

      - useCMIM: (optional) if this is >0, the CMIM algorithm
        (conditional mutual information maximization) will be
        used to select the descriptors used to build the trees.
        The value of the variable should be set to the number
        of descriptors to be used.  This option and the
        ensemble option are mutually exclusive (CMIM will not be
        used if the ensemble is set), but it happily coexists
        with the random argument (to only consider random subsets
        of the top N CMIM bits)
        The default is 0 (do not use CMIM)

      - allowCollections: (optional) if true and entry 1 of the
        examples is a _VectCollection_, each example sent down the
        "bit set" branch gets a copy of its collection with the vectors
        that do not set the split bit detached.  This setting is
        propagated to the whole tree.
        The default is False

      - depth: (optional) the current depth in the tree
        This is used in the recursion and should not be set
        by the client.

      - verbose: (optional) if nonzero, progress is printed

      - kwargs: accepted and ignored

    **Returns**

      a SigTree.SigTreeNode with the root of the decision tree

    NOTE: biasList's list default is shared between calls; this function
    only passes it on (to ranker.SetBiasList and the recursive calls).
    Whether SetBiasList keeps or mutates it is not visible here.

  """
  if verbose: print(' '*depth + ' Build')
  tree = SigTree.SigTreeNode(None,'node',level=depth)
  tree.SetData(-666)

  # tally the result codes (the activity is the last entry of each example)
  counts = [0]*nPossibleRes
  for example in examples:
    counts[int(example[-1])] += 1
  nzCounts = nonzero(counts)
  if verbose: print(' '*depth + ' \tcounts: ' + str(counts))

  if len(nzCounts) == 1:
    # bottomed out because there is only one result code left
    # with any counts (i.e. there's only one type of example
    # left... this is GOOD!).
    res = nzCounts[0]
    tree.SetLabel(res)
    tree.SetName(str(res))
    tree.SetTerminal(1)
    return tree

  if maxDepth>=0 and depth>maxDepth:
    # Bottomed out: max depth hit.  Use the heuristic of picking the
    # most prevalent result.
    v = argmax(counts)
    tree.SetLabel(v)
    tree.SetName('%d?'%v)
    tree.SetTerminal(1)
    return tree

  # find the variable which gives us the best improvement,
  # using an InfoBitRanker:
  nBits = examples[0][1].GetNumBits()
  ranker = InfoTheory.InfoBitRanker(nBits,nPossibleRes,metric)
  if biasList:
    ranker.SetBiasList(biasList)
  if useCMIM > 0 and not ensemble:
    # computed once at the root; children inherit it through `ensemble`
    ensemble = CMIM.SelectFeatures(examples,useCMIM,bvCol=1)
  availBits = _SelectAvailBits(ensemble,random,nBits)
  if availBits:
    ranker.SetMaskBits(availBits)

  useCollections = isinstance(examples[0][1],VectCollection)
  for example in examples:
    if useCollections:
      # collections vote with their OR-ed vector; Reset() is called first
      # so that vector reflects the collection's current contents
      example[1].Reset()
      ranker.AccumulateVotes(example[1].orVect,example[-1])
    else:
      ranker.AccumulateVotes(example[1],example[-1])

  try:
    bitInfo = ranker.GetTopN(1)[0]
    best = int(bitInfo[0])
    gain = bitInfo[1]
  except Exception:
    # best effort: report the failure and fall back to a majority leaf
    import traceback
    traceback.print_exc()
    print('get top n failed')
    gain = -1.0
  if gain <= 0.0:
    v = argmax(counts)
    tree.SetLabel(v)
    tree.SetName('?%d?'%v)
    tree.SetTerminal(1)
    return tree

  if verbose: print(' '*depth + ' \tbest: ' + str(bitInfo))
  # set some info at this node
  tree.SetName('Bit-%d'%(best))
  tree.SetLabel(best)
  tree.SetTerminal(0)

  # partition the examples on the chosen bit
  onExamples = []
  offExamples = []
  for example in examples:
    if example[1][best]:
      if allowCollections and useCollections:
        sig = copy.copy(example[1])
        sig.DetachVectsNotMatchingBit(best)
        example = [example[0],sig] + list(example[2:])
      onExamples.append(example)
    else:
      offExamples.append(example)

  # build a subtree for each value of the new variable (off first, then on)
  for subset in (offExamples,onExamples):
    if not subset:
      v = argmax(counts)
      tree.AddChild('%d??'%v,label=v,data=0.0,isTerminal=1)
      continue
    child = BuildSigTree(subset,nPossibleRes,random=random,
                         ensemble=ensemble,
                         metric=metric,biasList=biasList,
                         depth=depth+1,maxDepth=maxDepth,
                         allowCollections=allowCollections,
                         verbose=verbose)
    if child is None:
      v = argmax(counts)
      tree.AddChild('%d???'%v,label=v,data=0.0,isTerminal=1)
    else:
      tree.AddChildNode(child)
  return tree
221 222
def SigTreeBuilder(examples,attrs,nPossibleVals,initialVar=None,ensemble=None,
                   randomDescriptors=0,
                   **kwargs):
  """ Wrapper that calls BuildSigTree from an
    (examples, attrs, nPossibleVals, ...) style signature.

    **Arguments**

      - examples: passed straight to BuildSigTree

      - attrs, initialVar: accepted but unused

      - nPossibleVals: its last entry is used as the number of
        possible result codes

      - ensemble: accepted but NOT forwarded to BuildSigTree.
        NOTE(review): silently dropping this looks unintentional --
        confirm what callers pass here before forwarding it.

      - randomDescriptors: forwarded as BuildSigTree's _random_ argument

      - kwargs: forwarded unchanged to BuildSigTree

    **Returns**

      whatever BuildSigTree returns (the root SigTreeNode)

  """
  numResultCodes = nPossibleVals[-1]
  kwargs['random'] = randomDescriptors
  return BuildSigTree(examples,numResultCodes,**kwargs)
228
def testTrees1():
  """ Small self-test: builds signature trees on a five-point, 4-bit toy set.

    Builds a plain tree, a random-descriptor tree, a random tree restricted
    by an ensemble, and a CMIM tree, printing them and asserting that the
    random variants match the plain tree.

  """
  import DataStructs
  pts = [['p1',[3],0],
         ['p2',[1,3],1],
         ['p3',[2],1],
         ['p4',[0],1],
         ['p5',[],0],
         ]
  # replace each bit list with a real 4-bit ExplicitBitVect
  for pt in pts:
    bv = DataStructs.ExplicitBitVect(4)
    if pt[1]:
      bv.SetBitsFromList(pt[1])
    pt[1] = bv
  t = BuildSigTree(pts,2)
  t.Print()

  # the random bit picks are made with the stdlib random module
  # (random.sample), so that is the generator that has to be seeded for
  # reproducibility; RandomArray is seeded as well, as before.
  RandomArray.seed(23,42)
  random.seed(23)
  t2 = BuildSigTree(pts,2,random=4)
  t2.Print()
  # with random=4 and only 4 bits every bit is always available,
  # so these two come out the same:
  assert t==t2

  RandomArray.seed(23,42)
  random.seed(23)
  # this is just screwing around:
  t3 = BuildSigTree(pts,2,random=4,ensemble=range(4))
  assert t3==t2

  t3 = BuildSigTree(pts,2,useCMIM=3)
  print(t3)
if __name__ == "__main__":
  # running the module directly executes the testTrees1 self-test
  testTrees1()