1
2
3
4
5 import RDConfig,RDRandom
6 import Numeric
7 import RandomArray
8 import types,os.path,sys
# container types treated as explicit sequences (e.g. a per-activity list of
# fractions in SplitDbData); list/tuple work on both Python 2 and Python 3
SeqTypes = (list, tuple)
10
def SplitIndices(nPts, frac, silent=1, legacy=0, replacement=0):
    """ splits the indices of a data set into 2 pieces

    **Arguments**

      - nPts: the total number of points

      - frac: the fraction of the data to be put in the first data set;
        must be between 0.0 and 1.0

      - silent: (optional) when zero, the sizes of the two sets are
        printed

      - legacy: (optional) use the legacy splitting approach

      - replacement: (optional) use selection with replacement

    **Returns**

      a 2-tuple (train, test) containing the two sets of indices: the
      first is the selected (training) set, the second the hold-out set.

    **Raises**

      - ValueError: if _frac_ is outside [0.0, 1.0]

    **Notes**

      - the _legacy_ splitting approach uses randomly-generated floats
        and compares them to _frac_. This is provided for
        backwards-compatibility reasons.

      - the default splitting approach uses a random permutation of
        indices which is split into two parts.

      - selection with replacement can generate duplicates; the hold-out
        set then holds every index that was never selected.


    **Usage**:

      We'll start with a set of indices and pick from them using
      the three different approaches:
      >>> from ML.Data import DataUtils

      The base approach always puts int(nPts*frac) indices in the first
      set, the rest in the second, and has no duplicates:
      >>> DataUtils.InitRandomNumbers((23,42))
      >>> train,test = SplitIndices(10,.5)
      >>> train
      [9, 4, 3, 8, 2]
      >>> test
      [7, 6, 1, 5, 0]
      >>> train,test = SplitIndices(10,.5)
      >>> train
      [4, 6, 8, 2, 7]
      >>> test
      [5, 9, 0, 3, 1]

      The legacy approach can return varying numbers, but still has no
      duplicates. Note the indices come back ordered:
      >>> DataUtils.InitRandomNumbers((23,42))
      >>> train,test = SplitIndices(10,.5,legacy=1)
      >>> train
      [0, 1, 2, 3, 4, 7, 9]
      >>> test
      [5, 6, 8]
      >>> train,test = SplitIndices(10,.5,legacy=1)
      >>> train
      [4, 5, 7, 8, 9]
      >>> test
      [0, 1, 2, 3, 6]

      The replacement approach returns a fixed number in the training set,
      a variable number in the test set and can contain duplicates in the
      training set.
      >>> DataUtils.InitRandomNumbers((23,42))
      >>> train,test = SplitIndices(10,.5,replacement=1)
      >>> train
      [1, 1, 3, 0, 1]
      >>> test
      [2, 4, 5, 6, 7, 8, 9]
      >>> train,test = SplitIndices(10,.5,replacement=1)
      >>> train
      [9, 5, 4, 8, 0]
      >>> test
      [1, 2, 3, 6, 7]

    """
    if frac < 0. or frac > 1.:
        raise ValueError('frac must be between 0.0 and 1.0 (frac=%f)' % (frac))

    if replacement:
        # draw int(nPts*frac) indices one at a time; duplicates are allowed
        nTrain = int(nPts * frac)
        resData = [None] * nTrain
        for i in range(nTrain):
            val = int(RDRandom.random() * nPts)
            # defensive clamp in case the product rounds up to nPts
            if val == nPts:
                val = nPts - 1
            resData[i] = val
        # hold-out set: every index never drawn, in ascending order.
        # A set keeps the membership test O(1) instead of scanning resData.
        chosen = set(resData)
        resTest = [i for i in range(nPts) if i not in chosen]
    elif legacy:
        # one random draw per index decides which set it joins
        resData = []
        resTest = []
        for i in range(nPts):
            if RDRandom.random() < frac:
                resData.append(i)
            else:
                resTest.append(i)
    else:
        perm = RandomArray.permutation(nPts)
        nTrain = int(nPts * frac)
        resData = list(perm[:nTrain])
        resTest = list(perm[nTrain:])

    if not silent:
        print('Training with %d (of %d) points.' % (len(resData), nPts))
        print('\t%d points are in the hold-out set.' % (len(resTest)))
    return resData, resTest
126
127
def SplitDataSet(data, frac, silent=0):
    """ splits a data set into two pieces

    **Arguments**

      - data: a list of examples to be split

      - frac: the fraction of the data to be put in the first data set;
        must be between 0.0 and 1.0

      - silent: controls the amount of visual noise produced; when zero
        the sizes of the two new data sets are printed.

    **Returns**

      a 2-tuple containing the two new data sets: the first holds the
      selected (training) examples, the second the hold-out examples.

    **Raises**

      - ValueError: if _frac_ is outside [0.0, 1.0]

    """
    # reject only fractions outside the closed interval [0, 1]
    if frac < 0. or frac > 1.:
        raise ValueError('frac must be between 0.0 and 1.0')

    nOrig = len(data)
    # pick the indices first; SplitIndices' own stats output is suppressed
    # because this function reports the sizes itself
    train, test = SplitIndices(nOrig, frac, silent=1)
    resData = [data[x] for x in train]
    resTest = [data[x] for x in test]

    if not silent:
        print('Training with %d (of %d) points.' % (len(resData), nOrig))
        print('\t%d points are in the hold-out set.' % (len(resTest)))
    return resData, resTest
156
157
def SplitDbData(conn, fracs, table='', fields='*', where='', join='',
                labelCol='',
                useActs=0, nActs=2, actCol='', actBounds=None,
                silent=0):
    """ "splits" a data set held in a DB by returning lists of ids

    **Arguments**:

      - conn: a DbConnect object

      - fracs: the split fraction (the fraction of ids put in the first
        returned list). This can optionally be specified as a sequence
        with a different fraction for each activity value. When
        _useActs_ is zero only the first element of a sequence is used.

      - table,fields,where,join: (optional) SQL query parameters.
        _table_ defaults to _conn.tableName_. _where_ may be given with
        or without the leading "where" keyword and is applied both with
        and without _useActs_.
        NOTE: _where_ and _actCol_ are interpolated into the SQL text
        unescaped; only pass trusted values.

      - labelCol: (optional) not used by this function

      - useActs: (optional) toggles splitting based on activities
        (ensuring that a given fraction of each activity class ends
        up in the first returned list)
        Defaults to 0

      - nActs: (optional) number of possible activity values, only
        used if _useActs_ is nonzero
        Defaults to 2

      - actCol: (optional) name of the activity column
        Defaults to use the last column returned by the query

      - actBounds: (optional) sequence of nActs-1 activity bounds
        (for cases where the activity isn't quantized in the db);
        activity class i covers actBounds[i-1] <= activity < actBounds[i]
        Defaults to an empty sequence

      - silent: (optional) accepted for interface compatibility; this
        function produces no output

    **Returns**

      a 2-tuple (train, test) of id lists. The ids are the values of the
      first column reported by _conn.GetColumnNames_. With _useActs_ the
      ids are grouped by activity class.

    **Raises**

      - ValueError: if a fraction is outside [0.0, 1.0], if _actBounds_
        does not hold nActs-1 entries, or if fewer than nActs fractions
        are supplied while _useActs_ is set

    **Usage**:

      Set up the db connection, the simple tables we're using have actives with even
      ids and inactives with odd ids:
      >>> from ML.Data import DataUtils
      >>> from Dbase.DbConnection import DbConnect
      >>> conn = DbConnect(RDConfig.RDTestDatabase)

      Pull a set of points from a simple table... take 33% of all points:
      >>> DataUtils.InitRandomNumbers((23,42))
      >>> train,test = SplitDbData(conn,1./3.,'basic_2class')
      >>> [str(x) for x in train]
      ['id-10', 'id-5', 'id-4', 'id-9']

      ...take 50% of actives and 50% of inactives:
      >>> DataUtils.InitRandomNumbers((23,42))
      >>> train,test = SplitDbData(conn,.5,'basic_2class',useActs=1)
      >>> [str(x) for x in train]
      ['id-9', 'id-7', 'id-5', 'id-8', 'id-6', 'id-4']

      Notice how the results came out sorted by activity

      We can be asymmetrical: take 33% of actives and 50% of inactives:
      >>> DataUtils.InitRandomNumbers((23,42))
      >>> train,test = SplitDbData(conn,[.5,1./3.],'basic_2class',useActs=1)
      >>> [str(x) for x in train]
      ['id-9', 'id-7', 'id-5', 'id-8', 'id-6']

      And we can pull from tables with non-quantized activities by providing
      activity quantization bounds:
      >>> DataUtils.InitRandomNumbers((23,42))
      >>> train,test = SplitDbData(conn,.5,'float_2class',useActs=1,actBounds=[1.0])
      >>> [str(x) for x in train]
      ['id-9', 'id-7', 'id-5', 'id-8', 'id-6', 'id-4']

    """
    # None instead of a shared mutable [] default
    if actBounds is None:
        actBounds = []
    if not table:
        table = conn.tableName
    if actBounds and len(actBounds) != nActs - 1:
        raise ValueError('activity bounds list length incorrect')

    if useActs:
        if not isinstance(fracs, SeqTypes):
            # a single fraction is shared by every activity class
            fracs = tuple([fracs] * nActs)
        if len(fracs) < nActs:
            # fail before any query runs rather than with an IndexError mid-way
            raise ValueError('need one fraction per activity value (nActs=%d)' % nActs)
        for frac in fracs:
            if frac < 0.0 or frac > 1.0:
                raise ValueError('fractions must be between 0.0 and 1.0')
    else:
        if isinstance(fracs, SeqTypes):
            frac = fracs[0]
            if frac < 0.0 or frac > 1.0:
                raise ValueError('fractions must be between 0.0 and 1.0')
        else:
            frac = fracs

    colNames = conn.GetColumnNames(table=table, what=fields, join=join)
    idCol = colNames[0]

    # normalize the caller's filter so it always starts with the keyword
    # (checked case-insensitively); an empty/blank filter stays empty
    userWhere = where.strip()
    if userWhere and userWhere[:5].lower() != 'where':
        userWhere = 'where ' + userWhere

    if not useActs:
        if userWhere:
            rows = conn.GetData(table=table, fields=idCol, join=join, where=userWhere)
        else:
            rows = conn.GetData(table=table, fields=idCol, join=join)
        ids = [row[0] for row in rows]
        train, test = SplitIndices(len(ids), frac, silent=1)
        trainPts = [ids[x] for x in train]
        testPts = [ids[x] for x in test]
    else:
        trainPts = []
        testPts = []
        if not actCol:
            actCol = colNames[-1]
        # prefix onto which each activity-class condition is appended
        if userWhere:
            whereBase = userWhere + ' and '
        else:
            whereBase = 'where '
        for act in range(nActs):
            if not actBounds:
                # quantized activities: match the class value exactly
                actClause = '%s=%d' % (actCol, act)
            else:
                # class act spans [actBounds[act-1], actBounds[act]); the first
                # class has no lower bound and the last no upper bound
                conds = []
                if act != 0:
                    conds.append('%s>=%f' % (actCol, actBounds[act - 1]))
                if act < nActs - 1:
                    conds.append('%s<%f' % (actCol, actBounds[act]))
                actClause = ' and '.join(conds)
            rows = conn.GetData(table=table, fields=idCol, join=join,
                                where=whereBase + actClause)
            ids = [row[0] for row in rows]
            train, test = SplitIndices(len(ids), fracs[act], silent=1)
            trainPts.extend([ids[x] for x in train])
            testPts.extend([ids[x] for x in test])

    return trainPts, testPts
286
288 import doctest,sys
289 return doctest.testmod(sys.modules["__main__"])
290
if __name__ == '__main__':
    # run the doctests and use the failure count as the exit status
    import sys
    nFailed, _nTried = _test()
    sys.exit(nFailed)
295