Package Chem :: Package Subshape :: Module SubshapeAligner
[hide private]
[frames] | [no frames]

Source Code for Module Chem.Subshape.SubshapeAligner

  1  # $Id: SubshapeAligner.py 290 2007-07-16 05:19:53Z glandrum $
 
  2  #
 
  3  # Copyright (C) 2007 by Greg Landrum 
 
  4  #  All rights reserved
 
  5  #
 
  6  import RDLogger 
  7  logger = RDLogger.logger() 
  8  import Chem,Geometry 
  9  import Numeric 
 10  from Numerics import Alignment 
 11  from Chem.Subshape import SubshapeObjects 
 12  
 
class SubshapeAlignment(object):
    """One candidate alignment of a query subshape onto a target subshape.

    Every field below is a class-level default; code that produces or
    filters alignments overwrites them on the individual instance.
    """
    # Index triples of the matched skeleton points in the target and query.
    targetTri = None
    queryTri = None
    # Transform and sum of squared deviations returned by
    # Alignment.GetAlignmentTransform for the matched triangles.
    transform = None
    triangleSSD = None
    # NOTE(review): not assigned anywhere in this module; presumably a
    # conformer id reserved for callers -- confirm.
    alignedConfId = -1
    # Summed absolute dot products of the shape directions (direction filter).
    dirMatch = 0.0
    # Shape distance measured on the finest grid (shape filter).
    shapeDist = 0.0
21
22 -def _getAllTriangles(pts,orderedTraversal=False):
23 for i in range(len(pts)): 24 if orderedTraversal: 25 jStart=i+1 26 else: 27 jStart=0 28 for j in range(jStart,len(pts)): 29 if j==i: 30 continue 31 if orderedTraversal: 32 kStart=j+1 33 else: 34 kStart=0 35 for k in range(j+1,len(pts)): 36 if k==i or k==j: 37 continue 38 yield (i,j,k)
39
class SubshapeDistanceMetric(object):
    """Integer codes selecting how GetShapeShapeDistance compares two shapes.

    TANIMOTO selects Geometry.TanimotoDistance (also used for any code other
    than PROTRUDE); PROTRUDE selects Geometry.ProtrudeDistance.
    """
    TANIMOTO, PROTRUDE = 0, 1
# Returns the distance between two shapes according to the provided metric.
def GetShapeShapeDistance(s1, s2, distMetric):
    """Return the distance between the grids of shapes *s1* and *s2*.

    distMetric == SubshapeDistanceMetric.PROTRUDE:
        Geometry.ProtrudeDistance is evaluated with the shape whose grid
        occupancy total (GetOccupancyVect().GetTotalVal()) is strictly
        smaller as the first argument; on a tie *s2* goes first.
    any other value:
        Geometry.TanimotoDistance(s1.grid, s2.grid).
    """
    if distMetric != SubshapeDistanceMetric.PROTRUDE:
        return Geometry.TanimotoDistance(s1.grid, s2.grid)
    occ1 = s1.grid.GetOccupancyVect().GetTotalVal()
    occ2 = s2.grid.GetOccupancyVect().GetTotalVal()
    if occ1 < occ2:
        smaller, other = s1, s2
    else:
        smaller, other = s2, s1
    return Geometry.ProtrudeDistance(smaller.grid, other.grid)
# Clusters a set of alignments and returns one representative per cluster.
def ClusterAlignments(mol, alignments, builder,
                      neighborTol=0.1,
                      distMetric=SubshapeDistanceMetric.PROTRUDE,
                      tempConfId=1001):
    """Cluster *alignments* by the shape distance of the transformed *mol*.

    Each alignment's transform is applied to *mol* (default conformer), a
    subshape is generated with *builder*, and the pairwise shape distances
    (lower triangle, row by row) are passed to Butina.ClusterData with
    *neighborTol* as the distance threshold.

    Returns a list holding, for each cluster reported by Butina, the
    alignment at that cluster's first index (presumably the cluster
    centroid -- confirm against the Butina documentation). An empty input
    returns an empty list.

    Conformer *tempConfId* is used as scratch space on *mol* and is removed
    again, even if shape generation raises.
    """
    from ML.Cluster import Butina
    if not alignments:
        return []

    # Generate each transformed shape once. The previous version regenerated
    # shape j for every pair (i, j), i.e. O(n**2) shape generations. The cost
    # is that all n shapes are held in memory at the same time.
    # NOTE(review): assumes a generated shape does not depend on the scratch
    # conformer still being present on mol -- confirm.
    shapes = []
    for alg in alignments:
        TransformMol(mol, alg.transform, newConfId=tempConfId)
        try:
            shapes.append(builder.GenerateSubshapeShape(mol, tempConfId,
                                                        addSkeleton=False))
        finally:
            mol.RemoveConformer(tempConfId)

    # Lower-triangle distance list in the order Butina expects:
    # d(1,0), d(2,0), d(2,1), ...
    dists = []
    for i in range(len(shapes)):
        for j in range(i):
            dists.append(GetShapeShapeDistance(shapes[i], shapes[j], distMetric))

    clusts = Butina.ClusterData(dists, len(alignments), neighborTol,
                                isDistData=True)
    return [alignments[clust[0]] for clust in clusts]
77
def TransformMol(mol, tform, confId=-1, newConfId=100):
    """Add a transformed copy of conformer *confId* of *mol* as *newConfId*.

    Each atom position [x, y, z] is extended to [x, y, z, 1.0] and multiplied
    by *tform* with Numeric.matrixmultiply; the first three components of
    the product become the new position (presumably *tform* is a 4x4
    homogeneous transform -- confirm).

    Any existing conformer with id *newConfId* is removed first and the new
    one is added with that id. Other conformers of *mol* are left in place:
    the molecule is NOT reduced to a single conformer.
    """
    outConf = Chem.Conformer()
    outConf.SetId(0)
    srcConf = mol.GetConformer(confId)
    nAtoms = srcConf.GetNumAtoms()
    for atomIdx in range(nAtoms):
        homogeneous = list(srcConf.GetAtomPosition(atomIdx)) + [1.0]
        moved = Numeric.matrixmultiply(tform, Numeric.array(homogeneous))
        outConf.SetAtomPosition(atomIdx, list(moved)[:3])
    outConf.SetId(newConfId)
    mol.RemoveConformer(newConfId)
    mol.AddConformer(outConf, assignId=False)
94
class SubshapeAligner(object):
    """Find rigid alignments of a query subshape onto a target subshape.

    Candidate alignments come from matching triangles of skeleton points
    (GetTriangleMatches). They are then filtered by:
      * molecular features at the triangle vertices (only when
        builder.featFactory is set),
      * agreement of the skeleton points' first shape direction, and
      * shape distance evaluated on coarse, medium, then fine grids.
    """
    # RMS tolerance for a triangle superposition; the SSD cutoff used in
    # GetTriangleMatches is 9 * triangleRMSTol**2.
    triangleRMSTol = 1.0
    # Metric passed to GetShapeShapeDistance.
    distMetric = SubshapeDistanceMetric.PROTRUDE
    # Maximum fine-grid shape distance; the coarse and medium grids use this
    # value times the corresponding multiplier below.
    shapeDistTol = 0.2
    # Number of triangle vertices whose features must agree.
    numFeatThresh = 3
    # Minimum summed |dot product| between rotated query and target directions.
    dirThresh = 2.6
    # Edge tolerance. NOTE: edgeTol**2 is compared against the difference of
    # *squared* edge lengths, not of lengths.
    edgeTol = 6.0
    # Tolerance multipliers for the coarse/medium grids (1.5 and 1.25 were
    # tried previously).
    coarseGridToleranceMult = 1.0
    medGridToleranceMult = 1.0

    @staticmethod
    def _bumpStat(pruneStats, key, count=1):
        """Add *count* to pruneStats[key]; no-op if pruneStats is None or count is 0."""
        if pruneStats is not None and count:
            pruneStats[key] = pruneStats.get(key, 0) + count

    @staticmethod
    def _squaredEdgeLengths(pts):
        """Map each index pair (i, j) with i < j to |pts[i].location - pts[j].location|**2."""
        lengths = {}
        for i in range(len(pts)):
            for j in range(i + 1, len(pts)):
                lengths[(i, j)] = (pts[i].location - pts[j].location).LengthSq()
        return lengths

    def GetTriangleMatches(self, target, query):
        """ this is a generator function returning the possible triangle
        matches between the two shapes

        Each target triangle (ascending indices) is paired with query
        triangles whose three edges are all length-compatible. The pair is
        superimposed with Alignment.GetAlignmentTransform, and a
        SubshapeAlignment is yielded when the resulting SSD is
        <= 9 * triangleRMSTol**2. Yielded alignments carry a running
        sequence number in `_seqNo`.
        """
        ssdTol = (self.triangleRMSTol ** 2) * 9
        tgtPts = target.skelPts
        queryPts = query.skelPts
        tgtLs = self._squaredEdgeLengths(tgtPts)
        queryLs = self._squaredEdgeLengths(queryPts)

        # (target edge, query edge) pairs whose squared lengths are close enough.
        tol2 = self.edgeTol * self.edgeTol
        queryItems = list(queryLs.items())
        compatEdges = set()
        for tk, tv in tgtLs.items():
            for qk, qv in queryItems:
                if abs(tv - qv) < tol2:
                    compatEdges.add((tk, qk))

        seqNo = 0
        for tgtTri in _getAllTriangles(tgtPts, orderedTraversal=True):
            tgtLocs = [tgtPts[x].location for x in tgtTri]
            for queryTri in _getAllTriangles(queryPts, orderedTraversal=False):
                # NOTE(review): query edge keys are stored only as (i, j) with
                # i < j, so these lookups succeed only when queryTri is in
                # ascending order. The other vertex orderings produced by the
                # unordered traversal can never match. Normalizing the query
                # pair (sorting it) would admit more correspondences but
                # changes results and cost -- left as-is pending validation.
                if (((tgtTri[0], tgtTri[1]), (queryTri[0], queryTri[1])) not in compatEdges or
                        ((tgtTri[0], tgtTri[2]), (queryTri[0], queryTri[2])) not in compatEdges or
                        ((tgtTri[1], tgtTri[2]), (queryTri[1], queryTri[2])) not in compatEdges):
                    continue
                queryLocs = [queryPts[x].location for x in queryTri]
                ssd, tf = Alignment.GetAlignmentTransform(tgtLocs, queryLocs)
                if ssd <= ssdTol:
                    alg = SubshapeAlignment()
                    alg.transform = tf
                    alg.triangleSSD = ssd
                    alg.targetTri = tgtTri
                    alg.queryTri = queryTri
                    alg._seqNo = seqNo
                    seqNo += 1
                    yield alg

    def _checkMatchFeatures(self, targetPts, queryPts, alignment):
        """Return True if at least numFeatThresh triangle vertices agree on features.

        A vertex pair agrees when both points have no molFeatures, or when any
        target feature is contained in the query point's molFeatures.
        """
        nMatched = 0
        for i in range(3):
            tgtFeats = targetPts[alignment.targetTri[i]].molFeatures
            qFeats = queryPts[alignment.queryTri[i]].molFeatures
            if not tgtFeats and not qFeats:
                nMatched += 1
            else:
                for feat in tgtFeats:
                    if feat in qFeats:
                        nMatched += 1
                        break
            if nMatched >= self.numFeatThresh:
                break
        return nMatched >= self.numFeatThresh

    def PruneMatchesUsingFeatures(self, target, query, alignments, pruneStats=None):
        """Remove, in place, alignments that fail _checkMatchFeatures.

        The number removed is added to pruneStats['features'] (when at least
        one is removed and pruneStats is given).
        """
        targetPts = target.skelPts
        queryPts = query.skelPts
        kept = []
        for alg in alignments:
            if self._checkMatchFeatures(targetPts, queryPts, alg):
                kept.append(alg)
        self._bumpStat(pruneStats, 'features', len(alignments) - len(kept))
        # Slice assignment keeps the caller's list object (O(n), unlike
        # repeated `del` in a loop).
        alignments[:] = kept

    def _checkMatchDirections(self, targetPts, queryPts, alignment):
        """Score how well the first shape directions agree after rotation.

        The 3x3 upper-left block of alignment.transform rotates each query
        point's shapeDirs[0]. The absolute dot products with the target
        points' shapeDirs[0] are summed, stopping early once dirThresh is
        reached. The sum is stored in alignment.dirMatch.

        Returns True if the sum is >= dirThresh.
        """
        tform = alignment.transform
        dot = 0.0
        for i in range(3):
            tv = targetPts[alignment.targetTri[i]].shapeDirs[0]
            qv = queryPts[alignment.queryTri[i]].shapeDirs[0]
            rotV = [tform[r, 0] * qv[0] + tform[r, 1] * qv[1] + tform[r, 2] * qv[2]
                    for r in range(3)]
            dot += abs(rotV[0] * tv[0] + rotV[1] * tv[1] + rotV[2] * tv[2])
            if dot >= self.dirThresh:
                # already above the threshold, no need to continue
                break
        alignment.dirMatch = dot
        return dot >= self.dirThresh

    def PruneMatchesUsingDirection(self, target, query, alignments, pruneStats=None):
        """Remove, in place, alignments that fail _checkMatchDirections.

        Every alignment examined gets its dirMatch set. The number removed is
        added to pruneStats['direction'].
        """
        tgtPts = target.skelPts
        queryPts = query.skelPts
        kept = []
        for alg in alignments:
            if self._checkMatchDirections(tgtPts, queryPts, alg):
                kept.append(alg)
        self._bumpStat(pruneStats, 'direction', len(alignments) - len(kept))
        alignments[:] = kept

    def _addCoarseAndMediumGrids(self, mol, tgt, confId, builder):
        """Attach medGrid (1.5x spacing) and coarseGrid (2x spacing) to *tgt*.

        With a molecule they are regenerated from conformer *confId*;
        otherwise *tgt* is resampled with builder.SampleSubshape.
        builder.gridSpacing is restored even if generation fails. The two
        attributes are set together so *tgt* is never left with only one.
        """
        oSpace = builder.gridSpacing
        if mol:
            try:
                builder.gridSpacing = oSpace * 1.5
                medGrid = builder.GenerateSubshapeShape(mol, confId, addSkeleton=False)
                builder.gridSpacing = oSpace * 2
                coarseGrid = builder.GenerateSubshapeShape(mol, confId, addSkeleton=False)
            finally:
                builder.gridSpacing = oSpace
        else:
            medGrid = builder.SampleSubshape(tgt, oSpace * 1.5)
            coarseGrid = builder.SampleSubshape(tgt, oSpace * 2.0)
        tgt.medGrid = medGrid
        tgt.coarseGrid = coarseGrid

    def _checkMatchShape(self, targetMol, target, queryMol, query, alignment, builder,
                         targetConf, queryConf, pruneStats=None, tConfId=1001):
        """Return True if the aligned query's shape is close enough to the target.

        The query conformer *queryConf* is transformed into scratch conformer
        *tConfId*. Its shape is compared at 2x, 1.5x, then 1x grid spacing,
        stopping at the first failure; the failing level is counted in
        pruneStats ('coarseGrid', 'medGrid' or 'fineGrid').
        alignment.shapeDist is set only when the fine grid is reached.

        The scratch conformer is removed and builder.gridSpacing restored even
        if shape generation raises. targetMol, query and targetConf are not
        used here.
        """
        TransformMol(queryMol, alignment.transform, confId=queryConf, newConfId=tConfId)
        oSpace = builder.gridSpacing
        try:
            builder.gridSpacing = oSpace * 2
            coarseGrid = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
            d = GetShapeShapeDistance(coarseGrid, target.coarseGrid, self.distMetric)
            if d > self.shapeDistTol * self.coarseGridToleranceMult:
                self._bumpStat(pruneStats, 'coarseGrid')
                return False

            builder.gridSpacing = oSpace * 1.5
            medGrid = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
            d = GetShapeShapeDistance(medGrid, target.medGrid, self.distMetric)
            if d > self.shapeDistTol * self.medGridToleranceMult:
                self._bumpStat(pruneStats, 'medGrid')
                return False

            builder.gridSpacing = oSpace
            fineGrid = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
            d = GetShapeShapeDistance(fineGrid, target, self.distMetric)
            alignment.shapeDist = d
            if d > self.shapeDistTol:
                self._bumpStat(pruneStats, 'fineGrid')
                return False
            return True
        finally:
            queryMol.RemoveConformer(tConfId)
            builder.gridSpacing = oSpace

    def PruneMatchesUsingShape(self, targetMol, target, queryMol, query, builder,
                               alignments, tgtConf=-1, queryConf=-1,
                               pruneStats=None):
        """Remove, in place, alignments that fail _checkMatchShape.

        Target medium/coarse grids are built first if *target* has no
        medGrid. Progress is logged every 100 alignments.
        """
        if not hasattr(target, 'medGrid'):
            self._addCoarseAndMediumGrids(targetMol, target, tgtConf, builder)

        logger.info("Shape-based Pruning")
        nOrig = len(alignments)
        kept = []
        nDone = 0
        for alg in alignments:
            nDone += 1
            if not nDone % 100:
                # Survivors so far plus the not-yet-processed entries
                # (including this one): the list length a delete-as-you-go
                # loop would report at this point.
                nLeft = len(kept) + nOrig - nDone + 1
                logger.info(' processed %d of %d. %d alignments remain' % ((nDone,
                                                                             nOrig,
                                                                             nLeft)))
            if self._checkMatchShape(targetMol, target, queryMol, query, alg, builder,
                                     targetConf=tgtConf, queryConf=queryConf,
                                     pruneStats=pruneStats):
                kept.append(alg)
        alignments[:] = kept

    def GetSubshapeAlignments(self, targetMol, target, queryMol, query, builder,
                              tgtConf=-1, queryConf=-1, pruneStats=None):
        """Return the list of alignments that survive all pruning stages.

        Timings are recorded in pruneStats ('gtm_time', 'feats_time',
        'direction_time', 'shape_time') along with the per-stage prune counts.
        A fresh dict is used when pruneStats is None.
        """
        import time
        if pruneStats is None:
            pruneStats = {}
        logger.info("Generating triangle matches")
        t1 = time.time()
        res = list(self.GetTriangleMatches(target, query))
        t2 = time.time()
        logger.info("Got %d possible alignments in %.1f seconds" % (len(res), t2 - t1))
        pruneStats['gtm_time'] = t2 - t1
        if builder.featFactory:
            logger.info("Doing feature pruning")
            t1 = time.time()
            self.PruneMatchesUsingFeatures(target, query, res, pruneStats=pruneStats)
            t2 = time.time()
            pruneStats['feats_time'] = t2 - t1
            logger.info("%d possible alignments remain. (%.1f seconds required)" % (len(res), t2 - t1))
        logger.info("Doing direction pruning")
        t1 = time.time()
        self.PruneMatchesUsingDirection(target, query, res, pruneStats=pruneStats)
        t2 = time.time()
        pruneStats['direction_time'] = t2 - t1
        logger.info("%d possible alignments remain. (%.1f seconds required)" % (len(res), t2 - t1))
        t1 = time.time()
        self.PruneMatchesUsingShape(targetMol, target, queryMol, query, builder, res,
                                    tgtConf=tgtConf, queryConf=queryConf,
                                    pruneStats=pruneStats)
        t2 = time.time()
        pruneStats['shape_time'] = t2 - t1
        return res

    def __call__(self, targetMol, target, queryMol, query, builder,
                 tgtConf=-1, queryConf=-1, pruneStats=None):
        """Lazily yield alignments that pass every filter, one candidate at a time.

        Same filters as GetSubshapeAlignments, applied per triangle match.
        Target grids are built on demand before the first shape check.
        """
        for alignment in self.GetTriangleMatches(target, query):
            if builder.featFactory and \
                    not self._checkMatchFeatures(target.skelPts, query.skelPts, alignment):
                self._bumpStat(pruneStats, 'features')
                continue
            if not self._checkMatchDirections(target.skelPts, query.skelPts, alignment):
                self._bumpStat(pruneStats, 'direction')
                continue

            if not hasattr(target, 'medGrid'):
                self._addCoarseAndMediumGrids(targetMol, target, tgtConf, builder)

            if not self._checkMatchShape(targetMol, target, queryMol, query, alignment, builder,
                                         targetConf=tgtConf, queryConf=queryConf,
                                         pruneStats=pruneStats):
                continue
            # if we made it this far, it's a good alignment
            yield alignment
if __name__ == '__main__':
    import cPickle

    def _loadPickle(fileName):
        """Unpickle and return the object stored in *fileName*, closing the file."""
        # NOTE(review): unpickling can execute arbitrary code; only load
        # files from a trusted source.
        inF = open(fileName, 'rb')
        try:
            return cPickle.load(inF)
        finally:
            inF.close()

    tgtMol, tgtShape = _loadPickle('target.pkl')
    queryMol, queryShape = _loadPickle('query.pkl')
    builder = _loadPickle('builder.pkl')
    aligner = SubshapeAligner()
    algs = aligner.GetSubshapeAlignments(tgtMol, tgtShape, queryMol, queryShape, builder)
    print(len(algs))

    from Chem.PyMol import MolViewer
    v = MolViewer()
    v.ShowMol(tgtMol, name='Target', showOnly=True)
    v.ShowMol(queryMol, name='Query', showOnly=False)
    SubshapeObjects.DisplaySubshape(v, tgtShape, 'target_shape', color=(.8, .2, .2))
    SubshapeObjects.DisplaySubshape(v, queryShape, 'query_shape', color=(.2, .2, .8))