1
2
3
4
5
6
7 from rdkit import RDConfig,RDRandom
8 import random
9 import types,os.path,sys
10 SeqTypes=(types.ListType,types.TupleType)
11
def SplitIndices(nPts,frac,silent=1,legacy=0,replacement=0):
    """ splits a set of indices into a data set into 2 pieces

    **Arguments**

      - nPts: the total number of points

      - frac: the fraction of the data to be put in the first data set

      - silent: (optional) toggles display of stats

      - legacy: (optional) use the legacy splitting approach

      - replacement: (optional) use selection with replacement;
        takes precedence over _legacy_ when both are set

    **Returns**

      a 2-tuple containing the two sets of indices: the first (training)
      set, then the second (hold-out) set.

    **Raises**

      ValueError if _frac_ is not between 0.0 and 1.0

    **Notes**

      - the _legacy_ splitting approach uses randomly-generated floats
        and compares them to _frac_.  This is provided for
        backwards-compatibility reasons.

      - the default splitting approach uses a random permutation of
        indices which is split into two parts.

      - selection with replacement can generate duplicates.


    **Usage**:

    We'll start with a set of indices and pick from them using
    the three different approaches:
    >>> from rdkit.ML.Data import DataUtils

    The base approach always returns the same number of compounds in
    each set and has no duplicates:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitIndices(10,.5)
    >>> train
    [1, 5, 6, 4, 2]
    >>> test
    [3, 0, 7, 8, 9]

    >>> train,test = SplitIndices(10,.5)
    >>> train
    [5, 2, 9, 8, 7]
    >>> test
    [6, 0, 3, 1, 4]


    The legacy approach can return varying numbers, but still has no
    duplicates.  Note the indices come back ordered:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitIndices(10,.5,legacy=1)
    >>> train
    [0, 1, 2, 3, 4, 7, 9]
    >>> test
    [5, 6, 8]
    >>> train,test = SplitIndices(10,.5,legacy=1)
    >>> train
    [4, 5, 7, 8, 9]
    >>> test
    [0, 1, 2, 3, 6]

    The replacement approach returns a fixed number in the training set,
    a variable number in the test set and can contain duplicates in the
    training set.
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitIndices(10,.5,replacement=1)
    >>> train
    [1, 1, 3, 0, 1]
    >>> test
    [2, 4, 5, 6, 7, 8, 9]
    >>> train,test = SplitIndices(10,.5,replacement=1)
    >>> train
    [9, 5, 4, 8, 0]
    >>> test
    [1, 2, 3, 6, 7]

    """
    if frac<0. or frac>1.:
        raise ValueError('frac must be between 0.0 and 1.0 (frac=%f)'%(frac))

    if replacement:
        nTrain = int(nPts*frac)
        # draw nTrain indices independently; the min() guards against a
        # rounded product landing exactly on nPts
        resData = [min(int(RDRandom.random()*nPts),nPts-1) for _ in range(nTrain)]
        # a set makes the "was it drawn?" check constant-time per index
        picked = set(resData)
        resTest = [i for i in range(nPts) if i not in picked]
    elif legacy:
        resData = []
        resTest = []
        # one random draw per index decides which side it lands on
        for i in range(nPts):
            if RDRandom.random() < frac:
                resData.append(i)
            else:
                resTest.append(i)
    else:
        # shuffle needs a real (mutable) list, not a lazy range object
        perm = list(range(nPts))
        random.shuffle(perm)
        nTrain = int(nPts*frac)
        resData = perm[:nTrain]
        resTest = perm[nTrain:]

    if not silent:
        # single-argument print(...) produces the same output whether print
        # is a statement or a function
        print('Training with %d (of %d) points.'%(len(resData),nPts))
        print('\t%d points are in the hold-out set.'%(len(resTest)))
    return resData,resTest
130
131
def SplitDataSet(data,frac,silent=0):
    """ splits a data set into two pieces

    **Arguments**

      - data: a list of examples to be split

      - frac: the fraction of the data to be put in the first data set

      - silent: controls the amount of visual noise produced.

    **Returns**

      a 2-tuple containing the two new data sets.

    **Raises**

      ValueError if _frac_ is not between 0.0 and 1.0

    """
    # reject only values outside [0,1]
    if frac<0. or frac>1.:
        raise ValueError('frac must be between 0.0 and 1.0')

    nOrig = len(data)
    # the index split is always run silently; the stats are reported below
    train,test = SplitIndices(nOrig,frac,silent=1)
    resData = [data[x] for x in train]
    resTest = [data[x] for x in test]

    if not silent:
        print('Training with %d (of %d) points.'%(len(resData),nOrig))
        print('\t%d points are in the hold-out set.'%(len(resTest)))
    return resData,resTest
160
161
def SplitDbData(conn,fracs,table='',fields='*',where='',join='',
                labelCol='',
                useActs=0,nActs=2,actCol='',actBounds=[],
                silent=0):
    """ "splits" a data set held in a DB by returning lists of ids

    **Arguments**:

      - conn: a DbConnect object

      - fracs: the split fraction. This can optionally be specified as a
        sequence with a different fraction for each activity value.
        Without _useActs_ only the first entry of a sequence is used.

      - table,fields,where,join: (optional) SQL query parameters.
        _table_ defaults to the connection's table name; the first column
        returned for _fields_ is used as the id column; _where_ may be
        given with or without the leading 'where'.

      - labelCol: (optional) currently not used

      - useActs: (optional) toggles splitting based on activities
        (the split fraction is applied to each activity class separately)
        Defaults to 0

      - nActs: (optional) number of possible activity values, only
        used if _useActs_ is nonzero
        Defaults to 2

      - actCol: (optional) name of the activity column
        Defaults to use the last column returned by the query

      - actBounds: (optional) sequence of activity bounds
        (for cases where the activity isn't quantized in the db)
        Defaults to an empty sequence

      - silent: controls the amount of visual noise produced.
        (currently not used)

    **Returns**

      a 2-tuple of id lists: the first set and the hold-out set

    **Raises**

      ValueError if _actBounds_ does not have nActs-1 entries or if a
      fraction is not between 0.0 and 1.0

    **Usage**:

    Set up the db connection, the simple tables we're using have actives with even
    ids and inactives with odd ids:
    >>> from rdkit.ML.Data import DataUtils
    >>> from rdkit.Dbase.DbConnection import DbConnect
    >>> conn = DbConnect(RDConfig.RDTestDatabase)

    Pull a set of points from a simple table... take 33% of all points:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,1./3.,'basic_2class')
    >>> [str(x) for x in train]
    ['id-7', 'id-6', 'id-2', 'id-8']

    ...take 50% of actives and 50% of inactives:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,.5,'basic_2class',useActs=1)
    >>> [str(x) for x in train]
    ['id-5', 'id-3', 'id-1', 'id-4', 'id-10', 'id-8']


    Notice how the results came out sorted by activity

    We can be asymmetrical: take 33% of actives and 50% of inactives:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,[.5,1./3.],'basic_2class',useActs=1)
    >>> [str(x) for x in train]
    ['id-5', 'id-3', 'id-1', 'id-4', 'id-10']

    And we can pull from tables with non-quantized activities by providing
    activity quantization bounds:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,.5,'float_2class',useActs=1,actBounds=[1.0])
    >>> [str(x) for x in train]
    ['id-5', 'id-3', 'id-1', 'id-4', 'id-10', 'id-8']


    """
    if not table:
        table=conn.tableName
    if actBounds and len(actBounds)!=nActs-1:
        raise ValueError('activity bounds list length incorrect')

    # validate every fraction before the database is touched
    fracsAreSeq = isinstance(fracs,(list,tuple))
    if useActs:
        if not fracsAreSeq:
            # a single fraction applies to every activity class
            fracs = tuple([fracs]*nActs)
        for frac in fracs:
            if frac<0.0 or frac>1.0:
                raise ValueError('fractions must be between 0.0 and 1.0')
    else:
        if fracsAreSeq:
            frac = fracs[0]
        else:
            frac = fracs
        if frac<0.0 or frac>1.0:
            raise ValueError('fractions must be between 0.0 and 1.0')

    colNames = conn.GetColumnNames(table=table,what=fields,join=join)
    idCol = colNames[0]

    # normalise the caller's filter so that, when present, it always starts
    # with 'where'; testing the stripped text keeps a whitespace-only
    # argument from producing a dangling ' and '
    userWhere = where.strip()
    if userWhere and userWhere.find('where')!=0:
        userWhere = 'where '+userWhere

    if not useActs:
        # honour the caller's filter here as well; it is only passed when
        # one was actually supplied
        if userWhere:
            d = conn.GetData(table=table,fields=idCol,join=join,where=userWhere)
        else:
            d = conn.GetData(table=table,fields=idCol,join=join)
        ids = [x[0] for x in d]
        nRes = len(ids)
        train,test = SplitIndices(nRes,frac,silent=1)
        trainPts = [ids[x] for x in train]
        testPts = [ids[x] for x in test]
    else:
        trainPts = []
        testPts = []
        if not actCol:
            actCol = colNames[-1]
        if userWhere:
            whereBase = userWhere+' and '
        else:
            whereBase = 'where '
        # query and split each activity class on its own, so the results
        # come back grouped by activity
        for act in range(nActs):
            frac = fracs[act]
            if not actBounds:
                # quantized activities: select the class by equality
                whereTxt = whereBase + '%s=%d'%(actCol,act)
            else:
                # continuous activities: class _act_ covers
                # [actBounds[act-1], actBounds[act]), open-ended at the extremes
                whereTxt = whereBase
                if act!=0:
                    whereTxt += '%s>=%f '%(actCol,actBounds[act-1])
                if act < nActs-1:
                    if act!=0:
                        whereTxt += 'and '
                    whereTxt += '%s<%f'%(actCol,actBounds[act])
            d = conn.GetData(table=table,fields=idCol,join=join,where=whereTxt)
            ids = [x[0] for x in d]
            nRes = len(ids)
            train,test = SplitIndices(nRes,frac,silent=1)
            trainPts.extend([ids[x] for x in train])
            testPts.extend([ids[x] for x in test])

    return trainPts,testPts
292
294 import doctest,sys
295 return doctest.testmod(sys.modules["__main__"])
296
# script entry point: run the doctests and use the failure count as the
# process exit status
if '__main__' == __name__:
    import sys
    nFailed,nTried = _test()
    sys.exit(nFailed)
301