Package rdkit :: Package ML :: Package DecTree :: Module Tree
[hide private]
[frames | no frames]

Source Code for Module rdkit.ML.DecTree.Tree

  1  # 
  2  #  Copyright (C) 2000-2008  greg Landrum and Rational Discovery LLC 
  3  # 
  4  """ Implements a class used to represent N-ary trees 
  5   
  6  """ 
  7  import cPickle 
  8  import numpy 
  9   
 10  # FIX: the TreeNode class has not been updated to new-style classes 
 11  # (RD Issue380) because that would break all of our legacy pickled 
 12  # data. Until a solution is found for this breakage, an update is 
 13  # impossible. 
class TreeNode:
    """ This is your bog standard N-ary Tree class.

      The root of the tree is just a TreeNode like all other members.

      **Attributes**

        - children: list of child TreeNodes (set to None by Destroy())

        - parent: the parent TreeNode (or whatever was passed in as parent)

        - name, label, data: the payload carried by the node

        - terminalNode: flag marking the node as terminal

        - level: depth of the node, used when printing

        - examples: a list that starts out empty; nothing in this class
          fills it

      **Notes**

        - this is deliberately kept as a classic (old-style) class; see the
          comment preceding the class about legacy pickled data.

        - NOTE(review): the indentation string used by Print() and __str__()
          is a single space here because that is what the extracted listing
          shows; the listing's whitespace is collapsed, so confirm the
          intended width against the real module.

    """

    def __init__(self, parent, name, label=None, data=None, level=0, isTerminal=0):
        """ constructor

          **Arguments**

            - parent: the parent of this node in the tree

            - name: the name of the node

            - label: the node's label (should be an integer)

            - data: an optional data field

            - level: an integer indicating the level of this node in the
              hierarchy (used for printing)

            - isTerminal: flags a node as being terminal.

        """
        self.children = []
        self.parent = parent
        self.name = name
        self.data = data
        self.terminalNode = isTerminal
        self.label = label
        self.level = level
        self.examples = []

    def NameTree(self, varNames):
        """ Set the names of each node in the tree from a list of variable names.

          **Arguments**

            - varNames: a list of names to be assigned

          **Notes**

            1) this works by recursively traversing all children

            2) the varNames list must be indexable by the labels of the
               non-terminal tree nodes

            3) terminal nodes are left untouched

        """
        if self.GetTerminal():
            return
        for child in self.GetChildren():
            child.NameTree(varNames)
        self.SetName(varNames[self.GetLabel()])

    # alternate name for NameTree
    NameModel = NameTree

    def AddChildNode(self, node):
        """ Adds a TreeNode to the local list of children

          **Arguments**

            - node: the node to be added

          **Note**

            the level of the node (used in printing) is set as well; the
            node's parent is not modified.

        """
        node.SetLevel(self.level + 1)
        self.children.append(node)

    def AddChild(self, name, label=None, data=None, isTerminal=0):
        """ Creates a new TreeNode and adds it as a child of this node

          **Arguments**

            - name: the name of the new node

            - label: the label of the new node (should be an integer)

            - data: the data to be stored in the new node

            - isTerminal: a toggle to indicate whether or not the new node is
              a terminal (leaf) node.

          **Returns**

            the _TreeNode_ which is constructed

        """
        child = TreeNode(self, name, label, data,
                         level=self.level + 1, isTerminal=isTerminal)
        self.children.append(child)
        return child

    def PruneChild(self, child):
        """ Removes the child node

          **Arguments**

            - child: a TreeNode

          **Note**

            this uses list.remove(), so a ValueError is raised if the child
            is not present.

        """
        self.children.remove(child)

    def ReplaceChildIndex(self, index, newChild):
        """ Replaces a given child with a new one

          **Arguments**

            - index: an integer, the position of the child to be replaced

            - newChild: the TreeNode to store at that position

        """
        self.children[index] = newChild

    def GetChildren(self):
        """ Returns a python list of the children of this node

        """
        return self.children

    def Destroy(self):
        """ Destroys this node and all of its children

          **Notes**

            - afterwards both _children_ and _parent_ are None

            - calling this more than once is harmless

        """
        # Detach first, recurse second: a repeated call (children already
        # None) or a node that is reachable twice then terminates instead of
        # raising or recursing forever.
        doomed = self.children or ()
        self.children = None
        # clean up circular references
        self.parent = None
        for child in doomed:
            child.Destroy()

    def GetName(self):
        """ Returns the name of this node

        """
        return self.name

    def SetName(self, name):
        """ Sets the name of this node

        """
        self.name = name

    def GetData(self):
        """ Returns the data stored at this node

        """
        return self.data

    def SetData(self, data):
        """ Sets the data stored at this node

        """
        self.data = data

    def GetTerminal(self):
        """ Returns whether or not this node is terminal

        """
        return self.terminalNode

    def SetTerminal(self, isTerminal):
        """ Sets whether or not this node is terminal

        """
        self.terminalNode = isTerminal

    def GetLabel(self):
        """ Returns the label of this node

        """
        return self.label

    def SetLabel(self, label):
        """ Sets the label of this node (should be an integer)

        """
        self.label = label

    def GetLevel(self):
        """ Returns the level of this node

        """
        return self.level

    def SetLevel(self, level):
        """ Sets the level of this node

          **Note**

            only this node is changed, not its descendants

        """
        self.level = level

    def GetParent(self):
        """ Returns the parent of this node

        """
        return self.parent

    def SetParent(self, parent):
        """ Sets the parent of this node

        """
        self.parent = parent

    def Print(self, level=0, showData=0):
        """ Pretty prints the tree

          **Arguments**

            - level: sets the number of indentation units added at the
              beginning of the output

            - showData: if this is nonzero, the node's _data_ value will be
              printed as well

          **Note**

            this works recursively

        """
        # A single parenthesised argument prints identically whether print
        # is a statement or a function.
        if showData:
            print('%s%s: %s' % (' ' * level, self.name, str(self.data)))
        else:
            print('%s%s' % (' ' * level, self.name))

        for child in self.children:
            child.Print(level + 1, showData=showData)

    def Pickle(self, fileName='foo.pkl'):
        """ Pickles the tree and writes it to disk

          **Arguments**

            - fileName: the name of the file to be (over)written

        """
        # Binary mode keeps the pickle bytes intact on every platform, and
        # try/finally guarantees the handle is closed even if dump() fails
        # (written without `with` so very old interpreters still parse it).
        pFile = open(fileName, 'wb+')
        try:
            cPickle.dump(self, pFile)
        finally:
            pFile.close()

    def __str__(self):
        """ returns a string representation of the tree

          **Note**

            this works recursively; each node contributes one line, indented
            according to its own _level_

        """
        pieces = ['%s%s\n' % (' ' * self.level, self.name)]
        for child in self.children:
            pieces.append(str(child))
        return ''.join(pieces)

    def __cmp__(self, other):
        """ allows tree1 == tree2

          Nodes are ordered by type, then name, then label, then number of
          children, then pairwise by their children.

          **Returns**

            a negative number, zero or a positive number, as for cmp().
            If _other_ lacks one of the attributes being compared, -1.

          **Notes**

            - this works recursively

            - this relies on the builtin cmp() and is only consulted by
              interpreters that support the __cmp__ protocol (Python 2)

        """
        try:
            nChildren = len(self.children)
            # Each comparison is evaluated lazily and in this order, so
            # attributes of `other` are only touched once everything before
            # them compared equal.
            res = cmp(type(self), type(other))
            if res:
                return res
            res = cmp(self.name, other.name)
            if res:
                return res
            res = cmp(self.label, other.label)
            if res:
                return res
            nOther = len(other.children)
            if nChildren < nOther:
                return -1
            if nChildren > nOther:
                return 1
            for mine, theirs in zip(self.children, other.children):
                res = cmp(mine, theirs)
                if res != 0:
                    return res
        except AttributeError:
            return -1

        return 0
283 284 285 if __name__ == '__main__': 286 tree = TreeNode(None,'root') 287 for i in xrange(3): 288 child = tree.AddChild('child %d'%i) 289 print tree 290 tree.GetChildren()[1].AddChild('grandchild') 291 tree.GetChildren()[1].AddChild('grandchild2') 292 tree.GetChildren()[1].AddChild('grandchild3') 293 print tree 294 tree.Pickle('save.pkl') 295 print 'prune' 296 tree.PruneChild(tree.GetChildren()[1]) 297 print 'done' 298 print tree 299 300 import copy 301 tree2 = copy.deepcopy(tree) 302 print 'tree==tree2', tree==tree2 303 304 foo = [tree] 305 print 'tree in [tree]:', tree in foo,foo.index(tree) 306 print 'tree2 in [tree]:', tree2 in foo, foo.index(tree2) 307 308 tree2.GetChildren()[1].AddChild('grandchild4') 309 print 'tree==tree2', tree==tree2 310 tree.Destroy() 311