Source Code for Module rdkit.ML.DecTree.BuildSigTree

## Automatically adapted for numpy.oldnumeric Jun 27, 2008 by -c

# $Id$
#
#  Copyright (C) 2003-2008  Greg Landrum and Rational Discovery LLC
#    All Rights Reserved
#
"""

"""
from __future__ import print_function
import numpy
from rdkit.ML.DecTree import SigTree
from rdkit.ML import InfoTheory
try:
  from rdkit.ML.FeatureSelect import CMIM
except ImportError:
  CMIM=None
from rdkit.DataStructs.VectCollection import VectCollection
import copy
import random
def _GenerateRandomEnsemble(nToInclude,nBits):
  """ Generates a random subset of a group of indices

  **Arguments**

    - nToInclude: the size of the desired set

    - nBits: the maximum index to be included in the set

  **Returns**

    a list of indices

  """
  # Before Python 2.3 added the random.sample() function, this was
  # way more complicated:
  res = random.sample(range(nBits),nToInclude)
  return res

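# Illustrative sketch (not part of the original module): a quick check of
# _GenerateRandomEnsemble().  random.sample() returns distinct indices, so the
# result is nToInclude distinct values below nBits.
def _exampleRandomEnsemble():
  picks = _GenerateRandomEnsemble(3,10)
  assert len(picks) == 3
  assert len(set(picks)) == 3
  assert all(0 <= x < 10 for x in picks)
  return picks
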
def BuildSigTree(examples,nPossibleRes,ensemble=None,random=0,
                 metric=InfoTheory.InfoType.BIASENTROPY,
                 biasList=[1],
                 depth=0,maxDepth=-1,
                 useCMIM=0,allowCollections=False,
                 verbose=0,**kwargs):
  """
    **Arguments**

      - examples: the examples to be classified.  Each example
        should be a sequence at least three entries long, with
        entry 0 being a label, entry 1 a BitVector and entry -1
        an activity value

      - nPossibleRes: the number of result codes possible

      - ensemble: (optional) if this argument is provided, it
        should be a sequence which is used to limit the bits
        which are actually considered as potential descriptors.
        The default is None (use all bits).

      - random: (optional) If this argument is nonzero, it
        specifies the number of bits to be randomly selected
        for consideration at this node (i.e. this toggles the
        growth of Random Trees).
        The default is 0 (no random descriptor selection).

      - metric: (optional) This is an _InfoTheory.InfoType_ and
        sets the metric used to rank the bits.
        The default is _InfoTheory.InfoType.BIASENTROPY_.

      - biasList: (optional) If provided, this provides a bias
        list for the bit ranker.
        See the _InfoTheory.InfoBitRanker_ docs for an explanation
        of bias.
        The default value is [1], which biases towards actives.

      - maxDepth: (optional) the maximum depth to which the tree
        will be grown.
        The default is -1 (no depth limit).

      - useCMIM: (optional) if this is >0, the CMIM algorithm
        (conditional mutual information maximization) will be
        used to select the descriptors used to build the trees.
        The value of the variable should be set to the number
        of descriptors to be used.  This option and the
        ensemble option are mutually exclusive (CMIM will not be
        used if the ensemble is set), but it happily coexists
        with the random argument (to only consider random subsets
        of the top N CMIM bits).
        The default is 0 (do not use CMIM).

      - depth: (optional) the current depth in the tree.
        This is used in the recursion and should not be set
        by the client.

    **Returns**

      a SigTree.SigTreeNode with the root of the decision tree

  """
  if verbose: print(' '*depth,'Build')
  tree=SigTree.SigTreeNode(None,'node',level=depth)
  tree.SetData(-666)
  #tree.SetExamples(examples)

  # counts of each result code:
  #resCodes = map(lambda x:int(x[-1]),examples)
  resCodes = [int(x[-1]) for x in examples]
  #print('resCodes:',resCodes)
  counts = [0]*nPossibleRes
  for res in resCodes:
    counts[res] += 1
  #print(' '*depth,'counts:',counts)

  nzCounts = numpy.nonzero(counts)[0]
  if verbose: print(' '*depth,'\tcounts:',counts)
  if len(nzCounts) == 1:
    # bottomed out because there is only one result code left
    # with any counts (i.e. there's only one type of example
    # left... this is GOOD!).
    res = nzCounts[0]
    tree.SetLabel(res)
    tree.SetName(str(res))
    tree.SetTerminal(1)
  elif maxDepth>=0 and depth>maxDepth:
    # Bottomed out: max depth hit
    # We don't really know what to do here, so
    # use the heuristic of picking the most prevalent
    # result
    v = numpy.argmax(counts)
    tree.SetLabel(v)
    tree.SetName('%d?'%v)
    tree.SetTerminal(1)
  else:
    # find the variable which gives us the best improvement
    # We do this with an InfoBitRanker:
    fp = examples[0][1]
    nBits = fp.GetNumBits()
    ranker = InfoTheory.InfoBitRanker(nBits,nPossibleRes,metric)
    if biasList: ranker.SetBiasList(biasList)
    if CMIM is not None and useCMIM > 0 and not ensemble:
      ensemble = CMIM.SelectFeatures(examples,useCMIM,bvCol=1)
    if random:
      if ensemble:
        if len(ensemble)>random:
          picks = _GenerateRandomEnsemble(random,len(ensemble))
          # map the random picks back onto the ensemble bits
          # (the old numpy.oldnumeric take() is no longer imported):
          availBits = [ensemble[x] for x in picks]
        else:
          # fewer ensemble bits than requested: use them all
          availBits = list(ensemble)
      else:
        availBits = _GenerateRandomEnsemble(random,nBits)
    else:
      availBits=None
    if availBits:
      ranker.SetMaskBits(availBits)
    #print(' 2:'*depth,availBits)

    useCollections=isinstance(examples[0][1],VectCollection)
    for example in examples:
      #print(' '*depth,example[1].ToBitString(),example[-1])
      if not useCollections:
        ranker.AccumulateVotes(example[1],example[-1])
      else:
        example[1].Reset()
        ranker.AccumulateVotes(example[1].orVect,example[-1])

    try:
      bitInfo = ranker.GetTopN(1)[0]
      best = int(bitInfo[0])
      gain = bitInfo[1]
    except Exception:
      import traceback
      traceback.print_exc()
      print('get top n failed')
      gain = -1.0
    if gain <= 0.0:
      v = numpy.argmax(counts)
      tree.SetLabel(v)
      tree.SetName('?%d?'%v)
      tree.SetTerminal(1)
      return tree

    best = int(bitInfo[0])
    #print(' '*depth,'\tbest:',bitInfo)
    if verbose: print(' '*depth,'\tbest:',bitInfo)
    # set some info at this node
    tree.SetName('Bit-%d'%(best))
    tree.SetLabel(best)
    #tree.SetExamples(examples)
    tree.SetTerminal(0)

    # loop over possible values of the new variable and
    # build a subtree for each one
    onExamples = []
    offExamples = []
    for example in examples:
      if example[1][best]:
        if allowCollections and useCollections:
          sig = copy.copy(example[1])
          sig.DetachVectsNotMatchingBit(best)
          ex = [example[0],sig]
          if len(example)>2:
            ex.extend(example[2:])
          example = ex
        onExamples.append(example)
      else:
        offExamples.append(example)
    #print(' '*depth,len(offExamples),len(onExamples))
    for ex in (offExamples,onExamples):
      if len(ex) == 0:
        v = numpy.argmax(counts)
        tree.AddChild('%d??'%v,label=v,data=0.0,isTerminal=1)
      else:
        child = BuildSigTree(ex,nPossibleRes,random=random,
                             ensemble=ensemble,
                             metric=metric,biasList=biasList,
                             depth=depth+1,maxDepth=maxDepth,
                             verbose=verbose)
        if child is None:
          v = numpy.argmax(counts)
          tree.AddChild('%d???'%v,label=v,data=0.0,isTerminal=1)
        else:
          tree.AddChildNode(child)
  return tree

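# Illustrative usage sketch (not part of the original module).  It shows how
# example entries of the form [label, fingerprint, activity] can be fed to
# BuildSigTree(); the 8-bit fingerprints and activity values below are made up.
def _exampleBuildSigTree():
  from rdkit.DataStructs import ExplicitBitVect
  data = [((0,3),1), ((1,2),0), ((0,5),1), ((2,4),0)]
  examples = []
  for i,(onBits,act) in enumerate(data):
    bv = ExplicitBitVect(8)
    for bit in onBits:
      bv.SetBit(bit)
    examples.append([i,bv,act])   # [label, BitVector, activity]
  tree = BuildSigTree(examples,nPossibleRes=2,maxDepth=3)
  return tree
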
def SigTreeBuilder(examples,attrs,nPossibleVals,initialVar=None,ensemble=None,
                   randomDescriptors=0,
                   **kwargs):
  nRes = nPossibleVals[-1]
  return BuildSigTree(examples,nRes,random=randomDescriptors,**kwargs)

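# Illustrative note (not part of the original module): SigTreeBuilder simply
# forwards to BuildSigTree(), taking the last entry of nPossibleVals as the
# number of result codes.  With randomDescriptors=0 the two calls below should
# build equivalent trees; the nPossibleVals values are made up.
def _exampleSigTreeBuilder(examples):
  nPossibleVals = [0,0,2]   # hypothetical; only the final entry (nRes=2) is used
  t1 = SigTreeBuilder(examples,None,nPossibleVals)
  t2 = BuildSigTree(examples,nPossibleRes=2)
  return t1,t2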