Package rdkit :: Package ML :: Package Data :: Module SplitData
[hide private]
[frames] | no frames]

Source Code for Module rdkit.ML.Data.SplitData

  1  ## Automatically adapted for numpy.oldnumeric Jun 27, 2008 by -c 
  2   
  3  # 
  4  #  Copyright (C) 2003-2008 Greg Landrum and Rational Discovery LLC 
  5  #    All Rights Reserved 
  6  # 
  7  from rdkit import RDConfig,RDRandom 
  8  import random 
  9  import types,os.path,sys 
 10  SeqTypes=(types.ListType,types.TupleType) 
 11   
def SplitIndices(nPts, frac, silent=1, legacy=0, replacement=0):
  """ splits the indices 0..nPts-1 into two sets

    **Arguments**

      - nPts: the total number of points

      - frac: the fraction of the data to be put in the first data set;
        must lie in the closed interval [0.0, 1.0]

      - silent: (optional) toggles display of stats

      - legacy: (optional) use the legacy splitting approach

      - replacement: (optional) use selection with replacement

    **Returns**

      a 2-tuple containing the two sets of indices:
      (first/training set, hold-out set)

    **Raises**

      ValueError if _frac_ is outside [0.0, 1.0]

    **Notes**

      - the _legacy_ splitting approach uses randomly-generated floats
        and compares them to _frac_.  This is provided for
        backwards-compatibility reasons.

      - the default splitting approach uses a random permutation of
        indices which is split into two parts.

      - selection with replacement can generate duplicates.

      - the default approach draws from the standard library's _random_
        module, while the _legacy_ and _replacement_ approaches draw
        from _RDRandom_.


    **Usage**:

    We'll start with a set of indices and pick from them using
    the three different approaches:
    >>> from rdkit.ML.Data import DataUtils

    The base approach always returns the same number of indices in
    each set and has no duplicates:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitIndices(10,.5)
    >>> train
    [1, 5, 6, 4, 2]
    >>> test
    [3, 0, 7, 8, 9]

    >>> train,test = SplitIndices(10,.5)
    >>> train
    [5, 2, 9, 8, 7]
    >>> test
    [6, 0, 3, 1, 4]


    The legacy approach can return varying numbers, but still has no
    duplicates.  Note the indices come back ordered:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitIndices(10,.5,legacy=1)
    >>> train
    [0, 1, 2, 3, 4, 7, 9]
    >>> test
    [5, 6, 8]
    >>> train,test = SplitIndices(10,.5,legacy=1)
    >>> train
    [4, 5, 7, 8, 9]
    >>> test
    [0, 1, 2, 3, 6]

    The replacement approach returns a fixed number in the training set,
    a variable number in the test set and can contain duplicates in the
    training set.
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitIndices(10,.5,replacement=1)
    >>> train
    [1, 1, 3, 0, 1]
    >>> test
    [2, 4, 5, 6, 7, 8, 9]
    >>> train,test = SplitIndices(10,.5,replacement=1)
    >>> train
    [9, 5, 4, 8, 0]
    >>> test
    [1, 2, 3, 6, 7]

  """
  if frac < 0. or frac > 1.:
    raise ValueError('frac must be between 0.0 and 1.0 (frac=%f)' % (frac))

  if replacement:
    nTrain = int(nPts * frac)
    resData = [None] * nTrain
    for i in range(nTrain):
      val = int(RDRandom.random() * nPts)
      # defensive clamp in case the generator can ever return exactly 1.0
      if val == nPts:
        val = nPts - 1
      resData[i] = val
    # everything never drawn goes to the hold-out set; a set makes each
    # membership test O(1) instead of scanning resData every time
    chosen = set(resData)
    resTest = [i for i in range(nPts) if i not in chosen]
  elif legacy:
    resData = []
    resTest = []
    # one draw per index, in index order, so both lists come back sorted
    for i in range(nPts):
      if RDRandom.random() < frac:
        resData.append(i)
      else:
        resTest.append(i)
  else:
    # random.shuffle needs a mutable sequence, so materialize the range
    perm = list(range(nPts))
    random.shuffle(perm)
    nTrain = int(nPts * frac)
    resData = perm[:nTrain]
    resTest = perm[nTrain:]

  if not silent:
    print('Training with %d (of %d) points.' % (len(resData), nPts))
    print('\t%d points are in the hold-out set.' % (len(resTest)))
  return resData, resTest
130 131
def SplitDataSet(data, frac, silent=0):
  """ splits a data set into two pieces

    **Arguments**

      - data: a list of examples to be split

      - frac: the fraction of the data to be put in the first data set;
        must lie in the closed interval [0.0, 1.0]

      - silent: controls the amount of visual noise produced.

    **Returns**

      a 2-tuple containing the two new data sets:
      (first/training set, hold-out set)

    **Raises**

      ValueError if _frac_ is outside [0.0, 1.0]

  """
  # reject fractions outside the closed interval [0, 1]
  if frac < 0. or frac > 1.:
    raise ValueError('frac must be between 0.0 and 1.0')

  nOrig = len(data)
  # SplitIndices is run silently so the summary below is printed only once
  train, test = SplitIndices(nOrig, frac, silent=1)
  resData = [data[x] for x in train]
  resTest = [data[x] for x in test]

  if not silent:
    print('Training with %d (of %d) points.' % (len(resData), nOrig))
    print('\t%d points are in the hold-out set.' % (len(resTest)))
  return resData, resTest
160 161
def _normalizeWhere(where):
  """ returns _where_ stripped and prefixed with 'where ' (unless it already
  starts with 'where', checked case-insensitively), or '' if it is blank """
  whereTxt = where.strip() if where else ''
  if whereTxt and not whereTxt.lower().startswith('where'):
    whereTxt = 'where ' + whereTxt
  return whereTxt


def _activityClause(actCol, act, nActs, actBounds):
  """ returns the SQL condition selecting rows in activity class _act_:
  an equality test when _actBounds_ is empty, otherwise the half-open
  range actBounds[act-1] <= actCol < actBounds[act] (open-ended at the
  first and last classes) """
  if not actBounds:
    return '%s=%d' % (actCol, act)
  clauses = []
  if act != 0:
    clauses.append('%s>=%f' % (actCol, actBounds[act - 1]))
  if act < nActs - 1:
    clauses.append('%s<%f' % (actCol, actBounds[act]))
  return ' and '.join(clauses)


def SplitDbData(conn, fracs, table='', fields='*', where='', join='',
                labelCol='',
                useActs=0, nActs=2, actCol='', actBounds=(),
                silent=0):
  """ "splits" a data set held in a DB by returning lists of ids

    **Arguments**:

      - conn: a DbConnect object

      - fracs: the split fraction, i.e. the fraction of points placed in
        the first returned list.  When _useActs_ is set this can
        optionally be specified as a sequence with a different fraction
        for each activity value; otherwise only the first element of a
        sequence is used.

      - table,fields,where,join: (optional) SQL query parameters.
        _table_ defaults to _conn.tableName_; the first column selected
        by _fields_ is used as the ID column.  _where_ may be given with
        or without a leading "where".

      - labelCol: (optional) currently unused

      - useActs: (optional) toggles splitting based on activities
        (ensuring that the given fraction of each activity class ends
        up in the first returned list)
        Defaults to 0

      - nActs: (optional) number of possible activity values, used when
        _useActs_ is nonzero and to check the length of _actBounds_
        Defaults to 2

      - actCol: (optional) name of the activity column
        Defaults to use the last column returned by the query

      - actBounds: (optional) sequence of activity bounds
        (for cases where the activity isn't quantized in the db);
        must contain _nActs_-1 values
        Defaults to an empty sequence

      - silent: currently unused

    **Returns**

      a 2-tuple of lists of ids: (first/training set, hold-out set)

    **Raises**

      ValueError if _actBounds_ has the wrong length, if a fraction is
      outside [0.0, 1.0], or if fewer than _nActs_ fractions are given
      when _useActs_ is set

    **Usage**:

    Set up the db connection, the simple tables we're using have actives with even
    ids and inactives with odd ids:
    >>> from rdkit.ML.Data import DataUtils
    >>> from rdkit.Dbase.DbConnection import DbConnect
    >>> conn = DbConnect(RDConfig.RDTestDatabase)

    Pull a set of points from a simple table... take 33% of all points:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,1./3.,'basic_2class')
    >>> [str(x) for x in train]
    ['id-7', 'id-6', 'id-2', 'id-8']

    ...take 50% of actives and 50% of inactives:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,.5,'basic_2class',useActs=1)
    >>> [str(x) for x in train]
    ['id-5', 'id-3', 'id-1', 'id-4', 'id-10', 'id-8']


    Notice how the results came out sorted by activity

    We can be asymmetrical: take 33% of actives and 50% of inactives:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,[.5,1./3.],'basic_2class',useActs=1)
    >>> [str(x) for x in train]
    ['id-5', 'id-3', 'id-1', 'id-4', 'id-10']

    And we can pull from tables with non-quantized activities by providing
    activity quantization bounds:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,.5,'float_2class',useActs=1,actBounds=[1.0])
    >>> [str(x) for x in train]
    ['id-5', 'id-3', 'id-1', 'id-4', 'id-10', 'id-8']


  """
  if not table:
    table = conn.tableName
  if actBounds and len(actBounds) != nActs - 1:
    raise ValueError('activity bounds list length incorrect')
  if useActs:
    if not isinstance(fracs, (list, tuple)):
      fracs = tuple([fracs] * nActs)
    # fail before querying rather than with an IndexError mid-loop
    if len(fracs) < nActs:
      raise ValueError('one fraction is required for each activity value')
    for frac in fracs:
      if frac < 0.0 or frac > 1.0:
        raise ValueError('fractions must be between 0.0 and 1.0')
  else:
    if isinstance(fracs, (list, tuple)):
      frac = fracs[0]
    else:
      frac = fracs
    if frac < 0.0 or frac > 1.0:
      raise ValueError('fractions must be between 0.0 and 1.0')

  # the first column returned by the query is used as the ID column:
  colNames = conn.GetColumnNames(table=table, what=fields, join=join)
  idCol = colNames[0]

  if not useActs:
    # honor the caller's where clause here too (it used to be dropped)
    d = conn.GetData(table=table, fields=idCol, join=join,
                     where=_normalizeWhere(where))
    ids = [x[0] for x in d]
    train, test = SplitIndices(len(ids), frac, silent=1)
    trainPts = [ids[x] for x in train]
    testPts = [ids[x] for x in test]
  else:
    trainPts = []
    testPts = []
    if not actCol:
      actCol = colNames[-1]
    # each per-activity query is the caller's condition AND the class test
    whereBase = _normalizeWhere(where)
    whereBase = whereBase + ' and ' if whereBase else 'where '
    for act in range(nActs):
      whereTxt = whereBase + _activityClause(actCol, act, nActs, actBounds)
      d = conn.GetData(table=table, fields=idCol, join=join, where=whereTxt)
      ids = [x[0] for x in d]
      train, test = SplitIndices(len(ids), fracs[act], silent=1)
      trainPts.extend([ids[x] for x in train])
      testPts.extend([ids[x] for x in test])

  return trainPts, testPts
292
293 -def _test():
294 import doctest,sys 295 return doctest.testmod(sys.modules["__main__"])
if __name__ == '__main__':
  # run the module's doctests and use the failure count as the exit status
  import sys
  nFailed, _ = _test()
  sys.exit(nFailed)