Package ML :: Package Data :: Module SplitData
[hide private]
[frames] | no frames]

Source Code for Module ML.Data.SplitData

  1  #
 
  2  #  Copyright (C) 2003-2004 Rational Discovery LLC
 
  3  #    All Rights Reserved
 
  4  #
 
  5  import RDConfig,RDRandom 
  6  import Numeric 
  7  import RandomArray 
  8  import types,os.path,sys 
  9  SeqTypes=(types.ListType,types.TupleType) 
 10  
 
def SplitIndices(nPts, frac, silent=1, legacy=0, replacement=0):
    """ splits a set of indices into a data set into 2 pieces

    **Arguments**

      - nPts: the total number of points

      - frac: the fraction of the data to be put in the first data set

      - silent: (optional) toggles display of stats

      - legacy: (optional) use the legacy splitting approach

      - replacement: (optional) use selection with replacement

    **Returns**

      a 2-tuple containing the two sets of indices: the first (training)
      set followed by the second (hold-out) set.

    **Raises**

      ValueError if _frac_ is outside the range [0.0, 1.0]

    **Notes**

      - the _legacy_ splitting approach uses randomly-generated floats
        and compares them to _frac_.  This is provided for
        backwards-compatibility reasons.

      - the default splitting approach uses a random permutation of
        indices which is split into two parts.

      - selection with replacement can generate duplicates.

      - if both _replacement_ and _legacy_ are set, _replacement_ wins.


    **Usage**:

    We'll start with a set of indices and pick from them using
    the three different approaches:
    >>> from ML.Data import DataUtils

    The base approach always returns the same number of compounds in
    each set and has no duplicates:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitIndices(10,.5)
    >>> train
    [9, 4, 3, 8, 2]
    >>> test
    [7, 6, 1, 5, 0]
    >>> train,test = SplitIndices(10,.5)
    >>> train
    [4, 6, 8, 2, 7]
    >>> test
    [5, 9, 0, 3, 1]

    The legacy approach can return varying numbers, but still has no
    duplicates.  Note the indices come back ordered:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitIndices(10,.5,legacy=1)
    >>> train
    [0, 1, 2, 3, 4, 7, 9]
    >>> test
    [5, 6, 8]
    >>> train,test = SplitIndices(10,.5,legacy=1)
    >>> train
    [4, 5, 7, 8, 9]
    >>> test
    [0, 1, 2, 3, 6]

    The replacement approach returns a fixed number in the training set,
    a variable number in the test set and can contain duplicates in the
    training set.
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitIndices(10,.5,replacement=1)
    >>> train
    [1, 1, 3, 0, 1]
    >>> test
    [2, 4, 5, 6, 7, 8, 9]
    >>> train,test = SplitIndices(10,.5,replacement=1)
    >>> train
    [9, 5, 4, 8, 0]
    >>> test
    [1, 2, 3, 6, 7]

    """
    if frac < 0. or frac > 1.:
        raise ValueError('frac must be between 0.0 and 1.0 (frac=%f)' % (frac))

    if replacement:
        nTrain = int(nPts * frac)
        resData = [None] * nTrain
        for i in range(nTrain):
            val = int(RDRandom.random() * nPts)
            # clamp in case the scaled random value lands exactly on nPts
            if val == nPts:
                val = nPts - 1
            resData[i] = val
        # everything that was never drawn goes to the hold-out set; a set
        # keeps the membership test O(1) instead of scanning the list
        picked = set(resData)
        resTest = [i for i in range(nPts) if i not in picked]
    elif legacy:
        resData = []
        resTest = []
        for i in range(nPts):
            # one random draw per point, in index order, so both result
            # lists come back sorted
            if RDRandom.random() < frac:
                resData.append(i)
            else:
                resTest.append(i)
    else:
        perm = RandomArray.permutation(nPts)
        nTrain = int(nPts * frac)

        resData = list(perm[:nTrain])
        resTest = list(perm[nTrain:])

    if not silent:
        # single parenthesised argument: prints identically whether print
        # is a statement or a function
        print('Training with %d (of %d) points.' % (len(resData), nPts))
        print('\t%d points are in the hold-out set.' % (len(resTest)))
    return resData, resTest
126 127
def SplitDataSet(data, frac, silent=0):
    """ splits a data set into two pieces

    **Arguments**

      - data: a list of examples to be split

      - frac: the fraction of the data to be put in the first data set

      - silent: controls the amount of visual noise produced.

    **Returns**

      a 2-tuple containing the two new data sets.

    **Raises**

      ValueError if _frac_ is outside the range [0.0, 1.0]

    """
    # reject only fractions outside [0, 1]; the previous test
    # (frac>0. or frac<1.) was true for every number
    if frac < 0. or frac > 1.:
        raise ValueError('frac must be between 0.0 and 1.0')

    nOrig = len(data)
    # split positions rather than the examples themselves, then map back
    train, test = SplitIndices(nOrig, frac, silent=1)
    resData = [data[x] for x in train]
    resTest = [data[x] for x in test]

    if not silent:
        # single parenthesised argument: prints identically whether print
        # is a statement or a function
        print('Training with %d (of %d) points.' % (len(resData), nOrig))
        print('\t%d points are in the hold-out set.' % (len(resTest)))
    return resData, resTest
156 157
def SplitDbData(conn, fracs, table='', fields='*', where='', join='',
                labelCol='',
                useActs=0, nActs=2, actCol='', actBounds=[],
                silent=0):
    """ "splits" a data set held in a DB by returning lists of ids

    **Arguments**:

      - conn: a DbConnect object

      - fracs: the split fraction. This can optionally be specified as a
        sequence with a different fraction for each activity value.
        If a sequence is provided and _useActs_ is zero, only the first
        element is used.

      - table,fields,where,join: (optional) SQL query parameters.
        _table_ defaults to the connection's table name; the first column
        matched by _fields_ is used as the id column.

      - labelCol: (optional) accepted but not used by this function

      - useActs: (optional) toggles splitting based on activities
        (ensuring that a given fraction of each activity class ends
        up in the hold-out set)
        Defaults to 0

      - nActs: (optional) number of possible activity values, only
        used if _useActs_ is nonzero
        Defaults to 2

      - actCol: (optional) name of the activity column
        Defaults to use the last column returned by the query

      - actBounds: (optional) sequence of activity bounds
        (for cases where the activity isn't quantized in the db)
        Defaults to an empty sequence

      - silent: controls the amount of visual noise produced
        (not used by this function, which prints nothing).

    **Returns**

      a 2-tuple containing the list of ids in the first (training) set
      and the list of ids in the second (hold-out) set.

    **Raises**

      ValueError if _actBounds_ is provided but does not have
      _nActs_-1 elements, or if a fraction is outside [0.0, 1.0]

    **Usage**:

    Set up the db connection, the simple tables we're using have actives with even
    ids and inactives with odd ids:
    >>> from ML.Data import DataUtils
    >>> from Dbase.DbConnection import DbConnect
    >>> conn = DbConnect(RDConfig.RDTestDatabase)

    Pull a set of points from a simple table... take 33% of all points:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,1./3.,'basic_2class')
    >>> [str(x) for x in train]
    ['id-10', 'id-5', 'id-4', 'id-9']

    ...take 50% of actives and 50% of inactives:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,.5,'basic_2class',useActs=1)
    >>> [str(x) for x in train]
    ['id-9', 'id-7', 'id-5', 'id-8', 'id-6', 'id-4']

    Notice how the results came out sorted by activity

    We can be asymmetrical: take 33% of actives and 50% of inactives:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,[.5,1./3.],'basic_2class',useActs=1)
    >>> [str(x) for x in train]
    ['id-9', 'id-7', 'id-5', 'id-8', 'id-6']

    And we can pull from tables with non-quantized activities by providing
    activity quantization bounds:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,.5,'float_2class',useActs=1,actBounds=[1.0])
    >>> [str(x) for x in train]
    ['id-9', 'id-7', 'id-5', 'id-8', 'id-6', 'id-4']

    """
    # NOTE: the actBounds default is a shared list object; it is only ever
    # read here, never mutated.
    if not table:
        table = conn.tableName
    if actBounds and len(actBounds) != nActs - 1:
        raise ValueError('activity bounds list length incorrect')

    # validate every fraction up front, before any query is issued
    if useActs:
        if not isinstance(fracs, SeqTypes):
            # a single fraction applies to every activity class
            fracs = tuple([fracs] * nActs)
        for frac in fracs:
            if frac < 0.0 or frac > 1.0:
                raise ValueError('fractions must be between 0.0 and 1.0')
    else:
        if isinstance(fracs, SeqTypes):
            frac = fracs[0]
        else:
            frac = fracs
        if frac < 0.0 or frac > 1.0:
            raise ValueError('fractions must be between 0.0 and 1.0')

    # normalise the caller's where clause so that, when present, it starts
    # with the "where" keyword (matched case-insensitively)
    whereBase = where.strip()
    if whereBase and not whereBase.lower().startswith('where'):
        whereBase = 'where ' + whereBase

    # start by getting the name of the ID column:
    colNames = conn.GetColumnNames(table=table, what=fields, join=join)
    idCol = colNames[0]

    if not useActs:
        # get the IDS, honouring the caller's where clause if there is one:
        if whereBase:
            d = conn.GetData(table=table, fields=idCol, join=join,
                             where=whereBase)
        else:
            d = conn.GetData(table=table, fields=idCol, join=join)
        ids = [x[0] for x in d]
        nRes = len(ids)
        train, test = SplitIndices(nRes, frac, silent=1)
        trainPts = [ids[x] for x in train]
        testPts = [ids[x] for x in test]
    else:
        trainPts = []
        testPts = []
        if not actCol:
            actCol = colNames[-1]
        # prefix shared by every per-activity query
        if whereBase:
            actPrefix = whereBase + ' and '
        else:
            actPrefix = 'where '
        for act in range(nActs):
            frac = fracs[act]
            if not actBounds:
                # activities are stored quantized: select the class directly
                whereTxt = actPrefix + '%s=%d' % (actCol, act)
            else:
                # select the half-open interval
                # [actBounds[act-1], actBounds[act]); the first and last
                # classes are unbounded below and above respectively
                whereTxt = actPrefix
                if act != 0:
                    whereTxt += '%s>=%f ' % (actCol, actBounds[act - 1])
                if act < nActs - 1:
                    if act != 0:
                        whereTxt += 'and '
                    whereTxt += '%s<%f' % (actCol, actBounds[act])
            d = conn.GetData(table=table, fields=idCol, join=join,
                             where=whereTxt)
            ids = [x[0] for x in d]
            nRes = len(ids)
            train, test = SplitIndices(nRes, frac, silent=1)
            trainPts.extend([ids[x] for x in train])
            testPts.extend([ids[x] for x in test])

    return trainPts, testPts
286
287 -def _test():
288 import doctest,sys 289 return doctest.testmod(sys.modules["__main__"])
if __name__ == '__main__':
    import sys
    # run the doctests; the number of failures becomes the exit status
    results = _test()
    failed, tried = results
    sys.exit(failed)