Package ML :: Package DecTree :: Module Tree
[hide private]
[frames] | [no frames]

Source Code for Module ML.DecTree.Tree

  1  # 
  2  #  Copyright (C) 2000,2003  greg Landrum and Rational Discovery LLC 
  3  # 
  4  """ Implements a class used to represent N-ary trees 
  5   
  6  """ 
  7  import cPickle 
  8   
  9  # FIX: the TreeNode class has not been updated to new-style classes 
 10  # (RD Issue380) because that would break all of our legacy pickled 
 11  # data. Until a solution is found for this breakage, an update is 
 12  # impossible. 
class TreeNode:
    """ This is your bog standard Tree class.

      The root of the tree is just a TreeNode like all other members.

      **Notes**

        - this is deliberately an old-style class (no ``object`` base) so
          that legacy pickled trees keep loading.

        - equality and ordering are structural (see __cmp__); because
          __cmp__/__eq__ are defined without __hash__, instances are not
          hashable.

    """

    def __init__(self, parent, name, label=None, data=None, level=0,
                 isTerminal=0):
        """ constructor

          **Arguments**

            - parent: the parent of this node in the tree (None for a root)

            - name: the name of the node

            - label: the node's label (should be an integer)

            - data: an optional data field

            - level: an integer indicating the level of this node in the
              hierarchy (used for printing)

            - isTerminal: flags a node as being terminal (a leaf). This is
              useful for those times when it's useful to know such things.

        """
        self.children = []
        self.parent = parent
        self.name = name
        self.data = data
        self.terminalNode = isTerminal
        self.label = label
        self.level = level
        # not read anywhere in this class; presumably populated by callers or
        # subclasses -- NOTE(review): confirm
        self.examples = []

    def NameTree(self, varNames):
        """ Set the names of each node in the tree from a list of variable names.

          **Arguments**

            - varNames: a list of names to be assigned

          **Notes**

            1) this works its magic by recursively traversing all children

            2) The assumption is made here that the varNames list can be
               indexed by the labels of tree nodes

            3) terminal nodes (and anything below them) keep their names

        """
        if self.GetTerminal():
            return
        for child in self.GetChildren():
            child.NameTree(varNames)
        self.SetName(varNames[self.GetLabel()])

    # alternate name for NameTree
    NameModel = NameTree

    def AddChildNode(self, node):
        """ Adds an existing TreeNode to the local list of children

          **Arguments**

            - node: the node to be added

          **Note**

            the level of the node (used in printing) is set as well.
            NOTE(review): the node's parent pointer and the levels of its
            own descendants are left untouched.

        """
        node.SetLevel(self.level + 1)
        self.children.append(node)

    def AddChild(self, name, label=None, data=None, isTerminal=0):
        """ Creates a new TreeNode and adds it as a child of this node

          **Arguments**

            - name: the name of the new node

            - label: the label of the new node (should be an integer)

            - data: the data to be stored in the new node

            - isTerminal: a toggle to indicate whether or not the new node is
              a terminal (leaf) node.

          **Returns**

            the _TreeNode_ which is constructed (its parent is this node and
            its level is one deeper than this node's)

        """
        child = TreeNode(self, name, label, data, level=self.level + 1,
                         isTerminal=isTerminal)
        self.children.append(child)
        return child

    def PruneChild(self, child):
        """ Removes the child node

          **Arguments**

            - child: a TreeNode

          **Note**

            removal uses list.remove, i.e. == (structural equality, see
            __cmp__): the first child that compares equal to _child_ is
            removed, and ValueError is raised if there is none.

        """
        self.children.remove(child)

    def ReplaceChildIndex(self, index, newChild):
        """ Replaces a given child with a new one

          **Arguments**

            - index: an integer position in the list of children

            - newChild: a TreeNode

          **Note**

            the level and parent of _newChild_ are not updated

        """
        self.children[index] = newChild

    def GetChildren(self):
        """ Returns a python list of the children of this node

          **Note**

            this is the node's own list, not a copy; mutating it changes
            the tree

        """
        return self.children

    def Destroy(self):
        """ Destroys this node and all of its children

          Each descendant is destroyed recursively, the list of children is
          emptied and the parent pointer is cleared. Calling this more than
          once is harmless. This node is *not* removed from its parent's list
          of children.

        """
        for child in self.children:
            child.Destroy()
        # an empty list (rather than None) still drops the references to the
        # children, but leaves the node usable: a repeated Destroy(), str(),
        # Print() or comparison no longer raises TypeError
        self.children = []
        # clean up circular references (children point back at their parent)
        self.parent = None

    def GetName(self):
        """ Returns the name of this node

        """
        return self.name

    def SetName(self, name):
        """ Sets the name of this node

        """
        self.name = name

    def GetData(self):
        """ Returns the data stored at this node

        """
        return self.data

    def SetData(self, data):
        """ Sets the data stored at this node

        """
        self.data = data

    def GetTerminal(self):
        """ Returns whether or not this node is terminal

        """
        return self.terminalNode

    def SetTerminal(self, isTerminal):
        """ Sets whether or not this node is terminal

        """
        self.terminalNode = isTerminal

    def GetLabel(self):
        """ Returns the label of this node

        """
        return self.label

    def SetLabel(self, label):
        """ Sets the label of this node (should be an integer)

        """
        self.label = label

    def GetLevel(self):
        """ Returns the level of this node

        """
        return self.level

    def SetLevel(self, level):
        """ Sets the level of this node

        """
        self.level = level

    def GetParent(self):
        """ Returns the parent of this node

        """
        return self.parent

    def SetParent(self, parent):
        """ Sets the parent of this node

        """
        self.parent = parent

    def Print(self, level=0, showData=0):
        """ Pretty prints the tree to stdout

          **Arguments**

            - level: the indentation depth of this node; one space is added
              at the beginning of the output per level

            - showData: if this is nonzero, the node's _data_ value will be
              printed as well

          **Note**

            this works recursively

        """
        # single-argument print(...) produces identical output under
        # Python 2 (print statement) and Python 3 (print function)
        if showData:
            print('%s%s: %s' % (' ' * level, self.name, str(self.data)))
        else:
            print('%s%s' % (' ' * level, self.name))

        for child in self.children:
            child.Print(level + 1, showData=showData)

    def Pickle(self, fileName='foo.pkl'):
        """ Pickles the tree and writes it to disk

          **Arguments**

            - fileName: the name of the file to (over)write

        """
        # binary mode keeps the pickle stream free of platform newline
        # translation; the finally clause guarantees the data is flushed and
        # the handle released even if pickling fails
        pFile = open(fileName, 'wb')
        try:
            cPickle.dump(self, pFile)
        finally:
            pFile.close()

    def __str__(self):
        """ returns a string representation of the tree

          Each node contributes one line: its name indented by its stored
          level.

          **Note**

            this works recursively

        """
        pieces = ['%s%s\n' % (' ' * self.level, self.name)]
        pieces.extend([str(child) for child in self.children])
        return ''.join(pieces)

    def _compareValues(a, b):
        """ three-way comparison returning -1, 0 or 1 (like Python 2's cmp)

          The ``cmp`` builtin does not exist in Python 3, and Python 3 refuses
          to order unlike types (e.g. a None label against an integer one).
          In that case an arbitrary but stable ordering on the type name and
          object identity is used so that unequal values still give a
          non-zero result.

        """
        if a == b:
            return 0
        try:
            if a < b:
                return -1
            return 1
        except TypeError:
            if (type(a).__name__, id(a)) < (type(b).__name__, id(b)):
                return -1
            return 1
    _compareValues = staticmethod(_compareValues)

    def __cmp__(self, other):
        """ three-way structural comparison; allows tree1 == tree2

          Nodes are compared on type, then name, then label, then number of
          children, then the children pairwise in order. The _data_ and
          terminal-flag fields are not consulted.

          **Returns**

            -1, 0 or 1; -1 if _other_ lacks the name/label/children
            attributes

          **Note**

            This works recursively

        """
        try:
            nChildren = len(self.children)
            res = self._compareValues(type(self), type(other))
            if res:
                return res
            res = self._compareValues(self.name, other.name)
            if res:
                return res
            res = self._compareValues(self.label, other.label)
            if res:
                return res
            res = self._compareValues(nChildren, len(other.children))
            if res:
                return res
            for mine, theirs in zip(self.children, other.children):
                # call the child's own __cmp__ directly: going through ==/<
                # would evaluate each subtree comparison twice per level
                res = mine.__cmp__(theirs)
                if res != 0:
                    return res
        except AttributeError:
            # other is not tree-like
            return -1

        return 0

    # Python 3 ignores __cmp__, so the rich comparisons delegate to it
    # explicitly; under Python 2 they give the same answers __cmp__ did.
    def __eq__(self, other):
        """ structural equality (see __cmp__) """
        return self.__cmp__(other) == 0

    def __ne__(self, other):
        """ structural inequality (see __cmp__) """
        return self.__cmp__(other) != 0

    def __lt__(self, other):
        return self.__cmp__(other) < 0

    def __le__(self, other):
        return self.__cmp__(other) <= 0

    def __gt__(self, other):
        return self.__cmp__(other) > 0

    def __ge__(self, other):
        return self.__cmp__(other) >= 0
if __name__ == '__main__':
    # Demo: build a small tree, print it, pickle it, prune a branch, then
    # exercise deep copying and structural comparison.
    import copy

    tree = TreeNode(None, 'root')
    for idx in range(3):
        tree.AddChild('child %d' % idx)
    print(tree)

    # hang three grandchildren off the middle child
    middle = tree.GetChildren()[1]
    for suffix in ('', '2', '3'):
        middle.AddChild('grandchild' + suffix)
    print(tree)
    tree.Pickle('save.pkl')

    print('prune')
    tree.PruneChild(middle)
    print('done')
    print(tree)

    tree2 = copy.deepcopy(tree)
    print('tree==tree2 %s' % (tree == tree2))

    holder = [tree]
    print('tree in [tree]: %s %s' % (tree in holder, holder.index(tree)))
    print('tree2 in [tree]: %s %s' % (tree2 in holder, holder.index(tree2)))

    tree2.GetChildren()[1].AddChild('grandchild4')
    print('tree==tree2 %s' % (tree == tree2))
    tree.Destroy()