Package rdkit :: Package ML :: Package DecTree :: Module BuildSigTree
[hide private]
[frames] | no frames]

Source Code for Module rdkit.ML.DecTree.BuildSigTree

  1  # $Id$ 
  2  # 
  3  #  Copyright (C) 2003-2008  Greg Landrum and Rational Discovery LLC 
  4  #    All Rights Reserved 
  5  # 
  6  """ 
  7   
  8  """ 
  9  from __future__ import print_function 
 10   
 11  import copy 
 12  import random 
 13   
 14  import numpy 
 15   
 16  from rdkit.DataStructs.VectCollection import VectCollection 
 17  from rdkit.ML import InfoTheory 
 18  from rdkit.ML.DecTree import SigTree 
 19   
 20  try: 
 21    from rdkit.ML.FeatureSelect import CMIM 
 22  except ImportError: 
 23    CMIM = None 
 24   
 25   
def _GenerateRandomEnsemble(nToInclude, nBits):
  """ Draws a random subset of bit indices, without replacement

    **Arguments**

      - nToInclude: how many indices to draw

      - nBits: indices are drawn from range(nBits), so the largest
        index that can appear is nBits-1

    **Returns**

      a list of nToInclude distinct indices, in random order

  """
  candidates = range(nBits)
  # random.sample() raises ValueError if nToInclude exceeds nBits
  return random.sample(candidates, k=nToInclude)
43 44
def _SelectAvailableBits(ensemble, nRandom, nBits):
  """ Chooses the bits the ranker may consider at a node

    **Arguments**

      - ensemble: a sequence of candidate bit ids, or None (or an empty
        sequence) to allow every bit

      - nRandom: if nonzero, the number of bits to draw at random from
        the candidates

      - nBits: the number of bits in the fingerprints being split on

    **Returns**

      a list of bit ids to hand to the ranker's SetMaskBits(), or None
      if no mask should be applied

  """
  # explicit length test so numpy arrays work as well as lists/tuples
  hasEnsemble = ensemble is not None and len(ensemble) > 0
  if nRandom:
    if not hasEnsemble:
      # random.sample() raises ValueError when asked for more items than
      # exist, so never request more than nBits
      return _GenerateRandomEnsemble(min(nRandom, nBits), nBits)
    if len(ensemble) > nRandom:
      picks = _GenerateRandomEnsemble(nRandom, len(ensemble))
      return [int(ensemble[i]) for i in picks]
    # no more candidates than requested: use all of them. These must be
    # the ensemble's bit ids, not their positions within the ensemble.
    return [int(bit) for bit in ensemble]
  if hasEnsemble:
    # no random selection: the ensemble alone limits the candidate bits
    return [int(bit) for bit in ensemble]
  return None


def BuildSigTree(examples, nPossibleRes, ensemble=None, random=0,
                 metric=InfoTheory.InfoType.BIASENTROPY, biasList=[1], depth=0, maxDepth=-1,
                 useCMIM=0, allowCollections=False, verbose=0, **kwargs):
  """ Recursively builds a signature (bit-based) decision tree

    **Arguments**

      - examples: the examples to be classified. Each example
        should be a sequence at least three entries long, with
        entry 0 being a label, entry 1 a BitVector (or a
        VectCollection) and entry -1 an activity value. The
        activity value is converted with int() and used as the
        result code. This must not be empty.

      - nPossibleRes: the number of result codes possible

      - ensemble: (optional) if this argument is provided, it
        should be a sequence of bit ids which is used to limit the
        bits which are actually considered as potential descriptors.
        This applies whether or not _random_ is set.
        The default is None (use all bits).

      - random: (optional) If this argument is nonzero, it
        specifies the number of bits to be randomly selected
        for consideration at this node (i.e. this toggles the
        growth of Random Trees). If an ensemble is also given,
        the random bits are drawn from the ensemble.
        The default is 0 (no random descriptor selection)

      - metric: (optional) This is an _InfoTheory.InfoType_ and
        sets the metric used to rank the bits.
        The default is _InfoTheory.InfoType.BIASENTROPY_

      - biasList: (optional) If provided, this provides a bias
        list for the bit ranker.
        See the _InfoTheory.InfoBitRanker_ docs for an explanation
        of bias.
        The default value is [1], which biases towards actives.
        The list is only read, never modified.

      - depth: (optional) the current depth in the tree
        This is used in the recursion and should not be set
        by the client.

      - maxDepth: (optional) nodes deeper than this become leaves
        labelled with the most common result code.
        The default is -1 (no depth limit).

      - useCMIM: (optional) if this is >0, the CMIM algorithm
        (conditional mutual information maximization) will be
        used to select the descriptors used to build the trees.
        The value of the variable should be set to the number
        of descriptors to be used. This option and the
        ensemble option are mutually exclusive (CMIM will not be
        used if the ensemble is set), but it happily coexists
        with the random argument (to only consider random subsets
        of the top N CMIM bits).
        The default is 0 (do not use CMIM)

      - allowCollections: (optional) if true and the examples hold
        VectCollections, every collection sent down the "on" branch
        of a split is copied and has the vectors lacking the split
        bit detached. This is applied at every level of the tree.

      - verbose: (optional) if nonzero, progress is printed

      - any other keyword arguments are ignored

    **Returns**

      a SigTree.SigTreeNode with the root of the decision tree

  """
  if verbose:
    print(' ' * depth, 'Build')
  tree = SigTree.SigTreeNode(None, 'node', level=depth)
  tree.SetData(-666)

  # counts of each result code (taken from the last entry of each example)
  counts = [0] * nPossibleRes
  for example in examples:
    counts[int(example[-1])] += 1

  nzCounts = numpy.nonzero(counts)[0]
  if verbose:
    print(' ' * depth, '\tcounts:', counts)
  if len(nzCounts) == 1:
    # bottomed out because only one result code is left
    # (i.e. there's only one type of example left... this is GOOD!)
    res = nzCounts[0]
    tree.SetLabel(res)
    tree.SetName(str(res))
    tree.SetTerminal(1)
    return tree
  if maxDepth >= 0 and depth > maxDepth:
    # bottomed out because the max depth was hit: fall back to the
    # heuristic of picking the most prevalent result
    v = numpy.argmax(counts)
    tree.SetLabel(v)
    tree.SetName('%d?' % v)
    tree.SetTerminal(1)
    return tree

  # find the bit which gives the best improvement using an InfoBitRanker
  fp = examples[0][1]
  nBits = fp.GetNumBits()
  ranker = InfoTheory.InfoBitRanker(nBits, nPossibleRes, metric)
  if biasList:
    ranker.SetBiasList(biasList)
  if CMIM is not None and useCMIM > 0 and (ensemble is None or len(ensemble) == 0):
    ensemble = CMIM.SelectFeatures(examples, useCMIM, bvCol=1)
  availBits = _SelectAvailableBits(ensemble, random, nBits)
  if availBits:
    ranker.SetMaskBits(availBits)

  useCollections = isinstance(fp, VectCollection)
  for example in examples:
    if useCollections:
      # collections vote with their orVect; Reset() is called first
      # (presumably to refresh orVect after detaching -- confirm)
      example[1].Reset()
      ranker.AccumulateVotes(example[1].orVect, example[-1])
    else:
      ranker.AccumulateVotes(example[1], example[-1])

  try:
    bitInfo = ranker.GetTopN(1)[0]
    best = int(bitInfo[0])
    gain = bitInfo[1]
  except Exception:
    # best-effort: report the failure and turn this node into a leaf below
    import traceback
    traceback.print_exc()
    print('get top n failed')
    gain = -1.0
  if gain <= 0.0:
    # no bit improves the split: make a leaf with the most common result
    v = numpy.argmax(counts)
    tree.SetLabel(v)
    tree.SetName('?%d?' % v)
    tree.SetTerminal(1)
    return tree
  if verbose:
    print(' ' * depth, '\tbest:', bitInfo)
  # this node splits on the chosen bit
  tree.SetName('Bit-%d' % (best))
  tree.SetLabel(best)
  tree.SetTerminal(0)

  # partition the examples on the value of the chosen bit
  detach = allowCollections and useCollections
  onExamples = []
  offExamples = []
  for example in examples:
    if not example[1][best]:
      offExamples.append(example)
      continue
    if detach:
      sig = copy.copy(example[1])
      sig.DetachVectsNotMatchingBit(best)
      example = [example[0], sig] + list(example[2:])
    onExamples.append(example)

  # build a subtree for each branch (off first, then on)
  for subset in (offExamples, onExamples):
    if not subset:
      v = numpy.argmax(counts)
      tree.AddChild('%d??' % v, label=v, data=0.0, isTerminal=1)
      continue
    child = BuildSigTree(subset, nPossibleRes, random=random, ensemble=ensemble, metric=metric,
                         biasList=biasList, depth=depth + 1, maxDepth=maxDepth,
                         allowCollections=allowCollections, verbose=verbose)
    if child is None:
      v = numpy.argmax(counts)
      tree.AddChild('%d???' % v, label=v, data=0.0, isTerminal=1)
    else:
      tree.AddChildNode(child)
  return tree
227 228
def SigTreeBuilder(examples, attrs, nPossibleVals, initialVar=None, ensemble=None,
                   randomDescriptors=0, **kwargs):
  """ Builds a signature tree by delegating to BuildSigTree

    **Arguments**

      - examples: the examples to be classified (see BuildSigTree)

      - attrs: unused here (presumably kept so the signature matches
        other tree builders -- confirm)

      - nPossibleVals: a sequence whose last entry is the number of
        possible result codes

      - initialVar: unused here

      - ensemble: (optional) a sequence of bit ids limiting the candidate
        descriptors; this is now forwarded to BuildSigTree (previously it
        was accepted but silently dropped)

      - randomDescriptors: (optional) passed to BuildSigTree as _random_

      - any other keyword arguments are passed through to BuildSigTree

    **Returns**

      the root SigTree.SigTreeNode of the new tree

  """
  nRes = nPossibleVals[-1]
  return BuildSigTree(examples, nRes, ensemble=ensemble, random=randomDescriptors, **kwargs)
233