Package rdkit :: Package ML :: Package DecTree :: Module BuildSigTree
[hide private]
[frames | no frames]

Source Code for Module rdkit.ML.DecTree.BuildSigTree

  1  ## Automatically adapted for numpy.oldnumeric Jun 27, 2008 by -c 
  2   
  3  # $Id$ 
  4  # 
  5  #  Copyright (C) 2003-2008  Greg Landrum and Rational Discovery LLC 
  6  #    All Rights Reserved 
  7  # 
  8  """  
  9   
 10  """ 
 11  from __future__ import print_function 
 12  import numpy 
 13  from rdkit.ML.DecTree import SigTree 
 14  from rdkit.ML import InfoTheory 
 15  try: 
 16    from rdkit.ML.FeatureSelect import CMIM 
 17  except ImportError: 
 18    CMIM = None 
 19  from rdkit.DataStructs.VectCollection import VectCollection 
 20  import copy 
 21  import random 
 22   
 23   
24 -def _GenerateRandomEnsemble(nToInclude, nBits):
25 """ Generates a random subset of a group of indices 26 27 **Arguments** 28 29 - nToInclude: the size of the desired set 30 31 - nBits: the maximum index to be included in the set 32 33 **Returns** 34 35 a list of indices 36 37 """ 38 # Before Python 2.3 added the random.sample() function, this was 39 # way more complicated: 40 res = random.sample(range(nBits), nToInclude) 41 return res
42 43
def _SelectCandidateBits(ensemble, nRandom, nBits):
  """  Picks the bits the ranker may consider at a single node

  **Arguments**

    - ensemble: a sequence of bit ids to choose from, or None/empty to
      choose from all bits

    - nRandom: the number of bits to select at random; 0 disables masking

    - nBits: the total number of bits in the fingerprints

  **Returns**

    a list of bit ids, or None if no mask should be applied

  """
  if not nRandom:
    return None
  if not ensemble:
    # never ask for more bits than exist; random.sample() would raise
    return _GenerateRandomEnsemble(min(nRandom, nBits), nBits)
  if len(ensemble) > nRandom:
    # the picks are positions within the ensemble; map them back to bit ids
    picks = _GenerateRandomEnsemble(nRandom, len(ensemble))
    return [ensemble[idx] for idx in picks]
  # the ensemble is no larger than the requested sample: use all of its bits
  return list(ensemble)


def BuildSigTree(examples, nPossibleRes, ensemble=None, random=0,
                 metric=InfoTheory.InfoType.BIASENTROPY, biasList=[1], depth=0, maxDepth=-1,
                 useCMIM=0, allowCollections=False, verbose=0, **kwargs):
  """  Recursively builds a decision tree that splits on fingerprint bits

  **Arguments**

    - examples: the examples to be classified.  Each example
      should be a sequence at least three entries long, with
      entry 0 being a label, entry 1 a BitVector and entry -1
      an activity value

    - nPossibleRes: the number of result codes possible

    - ensemble: (optional) if this argument is provided, it
      should be a sequence which is used to limit the bits
      which are actually considered as potential descriptors.
      The default is None (use all bits).

    - random: (optional) If this argument is nonzero, it
      specifies the number of bits to be randomly selected
      for consideration at this node (i.e. this toggles the
      growth of Random Trees).
      The default is 0 (no random descriptor selection)

    - metric: (optional) This is an _InfoTheory.InfoType_ and
      sets the metric used to rank the bits.
      The default is _InfoTheory.InfoType.BIASENTROPY_

    - biasList: (optional) If provided, this provides a bias
      list for the bit ranker.
      See the _InfoTheory.InfoBitRanker_ docs for an explanation
      of bias.
      The default value is [1], which biases towards actives.

    - maxDepth: (optional) the maximum depth to which the tree
      will be grown
      The default is -1 (no depth limit).

    - useCMIM: (optional) if this is >0, the CMIM algorithm
      (conditional mutual information maximization) will be
      used to select the descriptors used to build the trees.
      The value of the variable should be set to the number
      of descriptors to be used.  This option and the
      ensemble option are mutually exclusive (CMIM will not be
      used if the ensemble is set), but it happily coexists
      with the random argument (to only consider random subsets
      of the top N CMIM bits)
      The default is 0 (do not use CMIM)

    - allowCollections: (optional) if this is true and the examples
      carry VectCollections, each example that has the chosen bit set
      is passed down with a copy of its collection from which the
      vectors not matching that bit have been detached.
      The default is False.

    - verbose: (optional) if nonzero, progress information is printed

    - depth: (optional) the current depth in the tree
      This is used in the recursion and should not be set
      by the client.

    - kwargs: accepted for signature compatibility and otherwise ignored

  **Returns**

    a SigTree.SigTreeNode with the root of the decision tree

  """
  if verbose:
    print(' ' * depth, 'Build')
  tree = SigTree.SigTreeNode(None, 'node', level=depth)
  tree.SetData(-666)

  # counts of each result code:
  resCodes = [int(x[-1]) for x in examples]
  counts = [0] * nPossibleRes
  for res in resCodes:
    counts[res] += 1

  nzCounts = numpy.nonzero(counts)[0]
  if verbose:
    print(' ' * depth, '\tcounts:', counts)
  if len(nzCounts) == 1:
    # bottomed out because there is only one result code left
    # with any counts (i.e. there's only one type of example
    # left... this is GOOD!).
    res = nzCounts[0]
    tree.SetLabel(res)
    tree.SetName(str(res))
    tree.SetTerminal(1)
  elif maxDepth >= 0 and depth > maxDepth:
    # Bottomed out: max depth hit
    # We don't really know what to do here, so
    # use the heuristic of picking the most prevalent
    # result
    v = numpy.argmax(counts)
    tree.SetLabel(v)
    tree.SetName('%d?' % v)
    tree.SetTerminal(1)
  else:
    # find the variable which gives us the best improvement
    # We do this with an InfoBitRanker:
    fp = examples[0][1]
    nBits = fp.GetNumBits()
    ranker = InfoTheory.InfoBitRanker(nBits, nPossibleRes, metric)
    if biasList:
      ranker.SetBiasList(biasList)
    if CMIM is not None and useCMIM > 0 and not ensemble:
      ensemble = CMIM.SelectFeatures(examples, useCMIM, bvCol=1)
    # NOTE(review): when random is 0 no mask is applied, so an ensemble
    # (supplied or CMIM-derived) has no effect on the ranking; the docstring
    # suggests it should limit the bits -- confirm the intended behavior.
    availBits = _SelectCandidateBits(ensemble, random, nBits)
    if availBits:
      ranker.SetMaskBits(availBits)

    useCollections = isinstance(examples[0][1], VectCollection)
    for example in examples:
      if not useCollections:
        ranker.AccumulateVotes(example[1], example[-1])
      else:
        example[1].Reset()
        ranker.AccumulateVotes(example[1].orVect, example[-1])

    try:
      bitInfo = ranker.GetTopN(1)[0]
      best = int(bitInfo[0])
      gain = bitInfo[1]
    except Exception:
      # best effort: report the failure and fall through to the
      # "no useful split" handling below
      import traceback
      traceback.print_exc()
      print('get top n failed')
      gain = -1.0
    if gain <= 0.0:
      # no bit improves things; label with the most prevalent result
      v = numpy.argmax(counts)
      tree.SetLabel(v)
      tree.SetName('?%d?' % v)
      tree.SetTerminal(1)
      return tree
    if verbose:
      print(' ' * depth, '\tbest:', bitInfo)
    # set some info at this node
    tree.SetName('Bit-%d' % (best))
    tree.SetLabel(best)
    tree.SetTerminal(0)

    # split the examples on the chosen bit and
    # build a subtree for each half
    onExamples = []
    offExamples = []
    for example in examples:
      if example[1][best]:
        if allowCollections and useCollections:
          # work on a copy so the caller's collection is left untouched
          sig = copy.copy(example[1])
          sig.DetachVectsNotMatchingBit(best)
          ex = [example[0], sig]
          if len(example) > 2:
            ex.extend(example[2:])
          example = ex
        onExamples.append(example)
      else:
        offExamples.append(example)
    # the child for "bit off" is added first, then the one for "bit on"
    for ex in (offExamples, onExamples):
      if not ex:
        v = numpy.argmax(counts)
        tree.AddChild('%d??' % v, label=v, data=0.0, isTerminal=1)
      else:
        # NOTE(review): allowCollections and useCMIM are not forwarded here.
        # The (possibly CMIM-derived) ensemble is, so CMIM is not re-run;
        # allowCollections reverts to its default in the children --
        # confirm that this is intended.
        child = BuildSigTree(ex, nPossibleRes, random=random, ensemble=ensemble, metric=metric,
                             biasList=biasList, depth=depth + 1, maxDepth=maxDepth,
                             verbose=verbose)
        if child is None:
          v = numpy.argmax(counts)
          tree.AddChild('%d???' % v, label=v, data=0.0, isTerminal=1)
        else:
          tree.AddChildNode(child)
  return tree
226 227
def SigTreeBuilder(examples, attrs, nPossibleVals, initialVar=None, ensemble=None,
                   randomDescriptors=0, **kwargs):
  """  Adapter that calls BuildSigTree with a tree-builder style argument list

  **Arguments**

    - examples: the examples to be classified (passed through to BuildSigTree)

    - attrs: accepted but not used here

    - nPossibleVals: a sequence whose last entry is the number of possible
      result codes

    - initialVar: accepted but not used here

    - ensemble: accepted but not used here (it is not forwarded)

    - randomDescriptors: forwarded to BuildSigTree as its _random_ argument

    - kwargs: any other keyword arguments are forwarded to BuildSigTree

  **Returns**

    whatever BuildSigTree returns (the root of the tree)

  """
  return BuildSigTree(examples, nPossibleVals[-1], random=randomDescriptors, **kwargs)
232