Package rdkit :: Package ML :: Package Cluster :: Module Butina
[hide private]
[frames | no frames]

Source Code for Module rdkit.ML.Cluster.Butina

 1  # $Id: Butina.py 1029 2009-03-23 17:00:45Z glandrum $ 
 2  # 
 3  # Copyright (C) 2007-2008 Greg Landrum 
 4  #   All Rights Reserved 
 5  # 
 6  """ Implementation of the clustering algorithm published in: 
 7    Butina JCICS 39 747-750 (1999) 
 8   
 9  """ 
10  import numpy 
11  from rdkit import RDLogger 
12  logger=RDLogger.logger() 
13   
def EuclideanDist(pi, pj):
    """ returns the Euclidean distance between two points

      **Arguments**

        - pi, pj: the two points, as sequences of coordinates (anything
          numpy.asarray accepts); they must have matching shapes

      **Returns**

        - a float: the square root of the sum of the squared
          coordinate differences

    """
    dv = numpy.asarray(pi, dtype=float) - numpy.asarray(pj, dtype=float)
    # sum the squared components before taking the root: sqrt(dv*dv) on its
    # own would return the per-coordinate absolute differences (an array),
    # which cannot be compared against a scalar threshold
    return float(numpy.sqrt((dv * dv).sum()))
def ClusterData(data, nPts, distThresh, isDistData=False, distFunc=EuclideanDist):
    """ clusters the data points passed in and returns the list of clusters

      **Arguments**

        - data: a list of items with the input data
          (see discussion of _isDistData_ argument for the exception)

        - nPts: the number of points to be used

        - distThresh: elements within this range of each other are considered
          to be neighbors (the test is dist <= distThresh)

        - isDistData: set this toggle when the data passed in is a
          distance matrix.  The matrix must be passed as a flat list
          holding only its lower triangle (diagonal excluded), row by row,
          i.e. nPts*(nPts-1)/2 entries.  An example of how to build it:

              dists = []
              for i in range(nPts):
                  for j in range(i):
                      dists.append( distFunc(points[i], points[j]) )

        - distFunc: a function to calculate distances between points.
          Receives 2 points as arguments, should return a float.
          Not used when _isDistData_ is set.

      **Returns**

        - a tuple of tuples containing information about the clusters:
           ( (cluster1_elem1, cluster1_elem2, ...),
             (cluster2_elem1, cluster2_elem2, ...),
             ...
           )
          The first element for each cluster is its centroid.

    """
    nPairs = nPts * (nPts - 1) // 2
    if isDistData and len(data) > nPairs:
        logger.warning("Distance matrix is too long")
    # NOTE: a distance matrix that is too short still fails with IndexError
    # below, as it always has.

    # Each pair (i, j) with j < i is visited exactly once and both neighbor
    # lists are updated together, so the lists are always symmetric.
    nbrLists = [[] for _ in range(nPts)]
    dmIdx = 0
    for i in range(nPts):
        for j in range(i):
            if isDistData:
                dij = data[dmIdx]
                dmIdx += 1
            else:
                dij = distFunc(data[i], data[j])
            if dij <= distThresh:
                nbrLists[i].append(j)
                nbrLists[j].append(i)

    # Candidate centroids, most neighbors first; ties go to the larger index
    # (identical to the former sort() followed by reverse(), since indices
    # are unique).  Iterating the sorted list avoids O(n) pop(0) calls.
    candidates = sorted(
        ((len(nbrs), idx) for idx, nbrs in enumerate(nbrLists)), reverse=True
    )

    res = []
    seen = [False] * nPts
    for _, idx in candidates:
        if seen[idx]:
            continue
        # mark the centroid explicitly rather than relying on symmetry
        seen[idx] = True
        tRes = [idx]
        for nbr in nbrLists[idx]:
            if not seen[nbr]:
                tRes.append(nbr)
                seen[nbr] = True
        res.append(tuple(tRes))
    return tuple(res)