1
2
3
4
5
"""Implementation of the clustering algorithm published in:

D. Butina, J. Chem. Inf. Comput. Sci. (JCICS) 39, 747-750 (1999)
"""
10 import numpy
11 from rdkit import RDLogger
# Module-level logger; ClusterData uses it to warn about an oversized
# distance matrix.
logger = RDLogger.logger()
13
def EuclideanDist(pi, pj):
    """Return the Euclidean distance between two points.

    **Arguments**

      - pi, pj: the two points; anything ``numpy.asarray`` accepts, e.g. a
        scalar or a sequence of coordinates (the shapes must be compatible)

    **Returns**

      the distance as a Python float
    """
    dv = numpy.asarray(pi) - numpy.asarray(pj)
    # Sum the squared differences before the square root so vector points
    # give a single scalar distance rather than per-coordinate |diff| values.
    return float(numpy.sqrt(numpy.sum(dv * dv)))


def ClusterData(data, nPts, distThresh, isDistData=False, distFunc=EuclideanDist):
    """Cluster the data points passed in and return the clusters (Butina).

    Every pair of points is compared once; two points are neighbors when their
    distance is less than or equal to ``distThresh``.  Points are then visited
    in order of decreasing neighbor count (ties go to the higher index), and
    each point not yet assigned to a cluster becomes the centroid of a new
    cluster that also takes all of its still-unassigned neighbors.

    **Arguments**

      - data: a list of items with the input data
        (see discussion of _isDistData_ argument for the exception)

      - nPts: the number of points to be used; only the first ``nPts``
        items of ``data`` are clustered

      - distThresh: elements whose distance from each other is less than or
        equal to this value are considered to be neighbors

      - isDistData: set this toggle when the data passed in is a
        distance matrix.  The matrix must be a flat sequence holding the
        lower triangle in row order, i.e. nPts*(nPts-1)/2 entries built
        like this:

          dists = []
          for i in range(nPts):
            for j in range(i):
              dists.append( distFunc(points[i], points[j]) )

      - distFunc: a function to calculate distances between points.
        Receives 2 points as arguments, should return a float.
        Not used when _isDistData_ is set.

    **Returns**

      - a tuple of tuples containing information about the clusters:
         ( (cluster1_elem1, cluster1_elem2, ...),
           (cluster2_elem1, cluster2_elem2, ...),
           ...
         )
        Elements are point indices.  The first element for each cluster is
        its centroid; clusters appear in the order they were created.

    **Raises**

      - IndexError: if _isDistData_ is set and ``data`` holds fewer than
        nPts*(nPts-1)/2 distances (a longer matrix only logs a warning)
    """
    if isDistData:
        nExpected = nPts * (nPts - 1) // 2
        if len(data) > nExpected:
            logger.warning("Distance matrix is too long")
        elif len(data) < nExpected:
            # Fail up front with a clear message instead of partway through
            # the neighbor scan; IndexError is what a short matrix raised anyway.
            raise IndexError(
                "Distance matrix is too short: expected %d entries, got %d"
                % (nExpected, len(data)))

    # Symmetric neighbor lists: every pair (i, j) with j < i is examined once,
    # in the same order the flat distance matrix is laid out.
    nbrLists = [[] for _ in range(nPts)]
    dmIdx = 0
    for i in range(nPts):
        for j in range(i):
            if isDistData:
                dij = data[dmIdx]
                dmIdx += 1
            else:
                dij = distFunc(data[i], data[j])
            if dij <= distThresh:
                nbrLists[i].append(j)
                nbrLists[j].append(i)

    # Candidate centroids ordered by (neighbor count, index), both descending.
    # Iterating the sorted list once avoids the O(n) cost of list.pop(0).
    candidates = sorted(
        ((len(nbrs), idx) for idx, nbrs in enumerate(nbrLists)), reverse=True)

    res = []
    seen = [False] * nPts
    for _, idx in candidates:
        if seen[idx]:
            continue
        seen[idx] = True
        cluster = [idx]
        for nbr in nbrLists[idx]:
            if not seen[nbr]:
                cluster.append(nbr)
                seen[nbr] = True
        res.append(tuple(cluster))
    return tuple(res)
89