Andrew Stirn committed
Commit 9ccfeb4 · 1 Parent(s): 79470c2

massive cleanup with better table columns

Files changed (1)
  1. tiger.py +78 -40
tiger.py CHANGED
@@ -1,6 +1,7 @@
 import argparse
 import os
 import gzip
+import pickle
 import numpy as np
 import pandas as pd
 import tensorflow as tf
@@ -14,6 +15,13 @@ NUCLEOTIDE_TOKENS = dict(zip(['A', 'C', 'G', 'T', 'N'], [0, 1, 2, 3, 255]))
 NUCLEOTIDE_COMPLEMENT = dict(zip(['A', 'C', 'G', 'T'], ['T', 'G', 'C', 'A']))
 NUM_TOP_GUIDES = 10
 NUM_MISMATCHES = 3
+ID_COL = 'Transcript ID'
+SEQ_COL = 'Sequence'
+TARGET_COL = 'Target Sequence'
+GUIDE_COL = 'Guide Sequence'
+SCORE_COL = 'Guide Score'
+RUN_MODE_ALL_PM = 'All on-target guides per transcript'
+RUN_MODE_TITRATION = 'Top guides per transcript'
 REFERENCE_TRANSCRIPTS = ('gencode.v19.pc_transcripts.fa.gz', 'gencode.v19.lncRNA_transcripts.fa.gz')
 BATCH_SIZE_COMPUTE = 500
 BATCH_SIZE_SCAN = 20
@@ -35,18 +43,18 @@ def load_transcripts(fasta_files):
         try:
             if os.path.splitext(file)[1] == '.gz':
                 with gzip.open(file, 'rt') as f:
-                    df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(f, 'fasta')], columns=['id', 'seq'])
+                    df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(f, 'fasta')], columns=[ID_COL, SEQ_COL])
             else:
-                df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(file, 'fasta')], columns=['id', 'seq'])
+                df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(file, 'fasta')], columns=[ID_COL, SEQ_COL])
         except Exception as e:
             print(e, 'while loading', file)
             continue
         transcripts = pd.concat([transcripts, df])

     # set index
-    transcripts['id'] = transcripts['id'].apply(lambda s: s.split('|')[0])
-    transcripts.set_index('id', inplace=True)
-    assert not transcripts.index.has_duplicates, "duplicate transcript ID's detected"
+    transcripts[ID_COL] = transcripts[ID_COL].apply(lambda s: s.split('|')[0])
+    transcripts.set_index(ID_COL, inplace=True)
+    assert not transcripts.index.has_duplicates, "duplicate transcript ID's detected in fasta file"

     return transcripts

@@ -101,6 +109,9 @@ def process_data(transcript_seq: str):


 def prediction_transform(predictions: np.array, **params):
+    if len(params) == 0:
+        with open('transform_params.pkl', 'rb') as f:
+            params = pickle.load(f)

     if UNIT_INTERVAL_MAP == 'sigmoid':
         return 1 - 1 / (1 + np.exp(params['a'] * predictions + params['b']))
@@ -135,23 +146,47 @@ def prediction_transform(predictions: np.array, **params):
     raise NotImplementedError


-def titration_ratio(guide: np.array, parent: np.array):
-    return 1 - np.clip(parent - guide, a_min=0.0, a_max=1.0)
-
-
-def predict_on_target(transcript_seq: str, model: tf.keras.Model):
-
-    # parse transcript sequence
-    target_seq, guide_seq, model_inputs = process_data(transcript_seq)
-
-    # get predictions
-    normalized_lfc = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)
-    predictions = pd.DataFrame({'Guide': guide_seq, 'Normalized LFC': tf.squeeze(normalized_lfc).numpy()})
-    predictions = predictions.sort_values('Normalized LFC')
+def get_on_target_predictions(transcripts: pd.DataFrame, model: tf.keras.Model, status_bar=None, status_text=None):
+
+    # loop over transcripts
+    predictions = pd.DataFrame()
+    for i, (index, row) in enumerate(transcripts.iterrows()):
+
+        # parse transcript sequence
+        target_seq, guide_seq, model_inputs = process_data(row[SEQ_COL])
+
+        # get predictions
+        lfc_estimate = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)
+        scores = prediction_transform(tf.squeeze(lfc_estimate).numpy())
+        predictions = pd.concat([predictions, pd.DataFrame({
+            ID_COL: [index] * len(scores),
+            TARGET_COL: [seq[CONTEXT_5P:len(seq) - CONTEXT_3P] for seq in target_seq],
+            GUIDE_COL: guide_seq,
+            SCORE_COL: scores})])
+
+        # progress update
+        percent_complete = 100 * min((i + 1) / len(transcripts), 1)
+        update_text = 'Evaluating on-target guides for each transcript: {:.2f}%'.format(percent_complete)
+        if status_bar:
+            status_text.text(update_text)
+            status_bar.progress(int(percent_complete))
+        print('\r' + update_text, end='')
+    print('')

     return predictions


+def top_guides_per_transcript(predictions: pd.DataFrame):
+
+    top_guides = pd.DataFrame()
+    for transcript in predictions[ID_COL].unique():
+        df = predictions.loc[predictions[ID_COL] == transcript]
+        df = df.sort_values(SCORE_COL, ascending=False).reset_index(drop=True).iloc[:NUM_TOP_GUIDES]
+        top_guides = pd.concat([top_guides, df])
+
+    return top_guides.reset_index(drop=True)
+
+
 def find_off_targets(top_guides: pd.DataFrame, status_bar, status_text):

     # load reference transcripts
@@ -171,7 +206,7 @@ def find_off_targets(top_guides: pd.DataFrame, status_bar, status_text):
         i += BATCH_SIZE_SCAN

         # find locations of off-targets
-        transcripts = one_hot_encode_sequence(df_batch['seq'].values.tolist(), add_context_padding=False)
+        transcripts = one_hot_encode_sequence(df_batch[SEQ_COL].values.tolist(), add_context_padding=False)
         num_mismatches = GUIDE_LEN - tf.nn.conv1d(transcripts, guide_filter, stride=1, padding='SAME')
         loc_off_targets = tf.where(tf.round(num_mismatches) <= NUM_MISMATCHES).numpy()

@@ -183,7 +218,7 @@ def find_off_targets(top_guides: pd.DataFrame, status_bar, status_text):
             'On-target ID': top_guides.iloc[loc_off_targets[:, 2]]['On-target ID'],
             'Guide': top_guides.iloc[loc_off_targets[:, 2]]['Guide'],
             'Off-target ID': df_batch.index.values[loc_off_targets[:, 0]],
-            'Target': df_batch['seq'].values[loc_off_targets[:, 0]],
+            'Target': df_batch[SEQ_COL].values[loc_off_targets[:, 0]],
             'Mismatches': tf.gather_nd(num_mismatches, loc_off_targets).numpy().astype(int),
             'Midpoint': loc_off_targets[:, 1],
         }).to_dict('records')
@@ -224,12 +259,12 @@ def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
         tf.reshape(one_hot_encode_sequence(off_targets['Target'], add_context_padding=False), [len(off_targets), -1]),
         tf.reshape(one_hot_encode_sequence(off_targets['Guide'], add_context_padding=True), [len(off_targets), -1]),
     ], axis=-1)
-    off_targets['Normalized LFC'] = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)
+    off_targets[SCORE_COL] = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)

-    return off_targets.sort_values('Normalized LFC')
+    return off_targets.sort_values(SCORE_COL)


-def tiger_exhibit(transcripts: pd.DataFrame, status_bar=None, status_text=None, check_off_targets=False):
+def tiger_exhibit(transcripts: pd.DataFrame, run_mode: str, check_off_targets: bool, status_bar=None, status_text=None):

     # load model
     if os.path.exists('model'):
@@ -238,31 +273,30 @@ def tiger_exhibit(transcripts: pd.DataFrame, status_bar=None, status_text=None, check_off_targets=False):
         print('no saved model!')
         exit()

-    # find top guides for each transcript
-    print('Finding top guides for each transcript')
-    on_target_predictions = pd.DataFrame(columns=['On-target ID', 'Guide', 'Normalized LFC'])
-    for i, (index, row) in enumerate(transcripts.iterrows()):
-        df = predict_on_target(row['seq'], model=tiger)
-        df['On-target ID'] = index
-        on_target_predictions = pd.concat([on_target_predictions, df.iloc[:NUM_TOP_GUIDES]])
-
-        # progress update
-        if status_bar:
-            status_text.text("Scanning for on-targets Percent complete: {:.2f}%".format(100 * min((i + 1) / len(transcripts), 1)))
-            status_bar.progress(int(100 * min((i + 1) / len(transcripts), 1)))
-        print('\rPercent complete: {:.2f}%'.format(100 * min((i + 1) / len(transcripts), 1)), end='')
-    print('')
-
-    # predict off-target effects for top guides
+    # evaluate all on-target guides per transcript
+    on_target_predictions = get_on_target_predictions(transcripts, tiger, status_bar, status_text)
+
+    # initialize other outputs
     off_target_predictions = pd.DataFrame()
+
+    if run_mode == RUN_MODE_ALL_PM:
+        return on_target_predictions, off_target_predictions
+
+    elif run_mode == RUN_MODE_TITRATION:  # TODO: and titration candidates
+        on_target_predictions = top_guides_per_transcript(on_target_predictions)
+
+    else:
+        raise NotImplementedError
+
+    # check off-target effects for top guides
     if check_off_targets:
         off_targets = find_off_targets(on_target_predictions, status_bar, status_text)
         off_target_predictions = predict_off_target(off_targets, model=tiger)

     # reverse guide sequences
-    on_target_predictions['Guide'] = on_target_predictions['Guide'].apply(lambda s: s[::-1])
+    on_target_predictions[GUIDE_COL] = on_target_predictions[GUIDE_COL].apply(lambda s: s[::-1])
     if check_off_targets and len(off_target_predictions) > 0:
-        off_target_predictions['Guide'] = off_target_predictions['Guide'].apply(lambda s: s[::-1])
+        off_target_predictions[GUIDE_COL] = off_target_predictions[GUIDE_COL].apply(lambda s: s[::-1])

     return on_target_predictions.reset_index(drop=True), off_target_predictions.reset_index(drop=True)
@@ -279,9 +313,11 @@ if __name__ == '__main__':
     # simple test case
     if args.simple_test:
         # first 50 from EIF3B-003's CDS
-        simple_test = pd.DataFrame(dict(id=['ManualEntry'], seq=['ATGCAGGACGCGGAGAACGTGGCGGTGCCCGAGGCGGCCGAGGAGCGCGC']))
-        simple_test.set_index('id', inplace=True)
-        df_on_target, df_off_target = tiger_exhibit(simple_test, check_off_targets=args.check_off_targets)
+        simple_test = pd.DataFrame({
+            ID_COL: ['ManualEntry'],
+            SEQ_COL: ['ATGCAGGACGCGGAGAACGTGGCGGTGCCCGAGGCGGCCGAGGAGCGCGC']})
+        simple_test.set_index(ID_COL, inplace=True)
+        df_on_target, df_off_target = tiger_exhibit(simple_test, run_mode=RUN_MODE_TITRATION, check_off_targets=args.check_off_targets)
         df_on_target.to_csv('on_target.csv')
         if args.check_off_targets:
             df_off_target.to_csv('off_target.csv')
@@ -306,7 +342,9 @@ if __name__ == '__main__':

         # run batch
         idx_stop = min(idx + BATCH_SIZE_TRANSCRIPTS, len(df_transcripts))
-        df_on_target, df_off_target = tiger_exhibit(df_transcripts[idx:idx_stop], check_off_targets=args.check_off_targets)
+        df_on_target, df_off_target = tiger_exhibit(df_transcripts[idx:idx_stop],
+                                                    run_mode=RUN_MODE_TITRATION,
+                                                    check_off_targets=args.check_off_targets)

         # save batch results
         df_on_target.to_csv('on_target.csv', header=batch == 1, index=False, mode='a')
 
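A note on the new scoring path: prediction_transform now falls back to calibration parameters pickled in transform_params.pkl, and its sigmoid branch squashes the model's normalized-LFC estimate onto the unit interval as the new Guide Score. A minimal sketch of that mapping, with made-up values for a and b (the real values ship in the pickle):

    import numpy as np

    def unit_interval_score(lfc: np.ndarray, a: float = -2.0, b: float = 0.0) -> np.ndarray:
        # same form as the sigmoid branch above: 1 - 1 / (1 + exp(a * lfc + b));
        # with a < 0, a more negative LFC (stronger knockdown) maps closer to 1
        return 1.0 - 1.0 / (1.0 + np.exp(a * lfc + b))

    print(unit_interval_score(np.array([-2.0, -1.0, 0.0])))  # ~[0.98, 0.88, 0.50]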
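The off-target scan's mismatch count rests on a convolution identity: with one-hot encoded sequences, the dot product between guide and target at a given offset equals the number of matching bases, so GUIDE_LEN - tf.nn.conv1d(...) yields the mismatch count at every alignment. A self-contained sketch with a toy encoder (the repo's one_hot_encode_sequence also handles context padding and 'N' tokens):

    import numpy as np
    import tensorflow as tf

    TOKENS = {'A': 0, 'C': 1, 'G': 2, 'T': 3}

    def one_hot(seq: str) -> np.ndarray:
        x = np.zeros((len(seq), 4), dtype=np.float32)
        x[np.arange(len(seq)), [TOKENS[c] for c in seq]] = 1.0
        return x

    target = one_hot('ACGTACGTACGT')[None, ...]  # [batch=1, length, channels=4]
    guide = one_hot('ACGTACGA')                  # one mismatch vs. the target's first 8 nt
    guide_filter = guide[:, :, None]             # [guide_len, in_channels=4, out_channels=1]

    # matching bases per alignment = sum of elementwise products of one-hot vectors
    matches = tf.nn.conv1d(target, guide_filter, stride=1, padding='SAME')
    mismatches = guide.shape[0] - matches        # GUIDE_LEN - matches
    print(tf.squeeze(mismatches).numpy())        # dips to 1 at the best alignment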
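top_guides_per_transcript keeps the NUM_TOP_GUIDES highest-scoring guides by looping over unique transcript IDs and concatenating slices. An equivalent single-pass pandas idiom, sketched here on synthetic data rather than proposed as a change to the commit:

    import pandas as pd

    ID_COL, SCORE_COL = 'Transcript ID', 'Guide Score'
    NUM_TOP_GUIDES = 2

    predictions = pd.DataFrame({
        ID_COL: ['T1', 'T1', 'T1', 'T2', 'T2'],
        SCORE_COL: [0.9, 0.4, 0.7, 0.8, 0.6],
    })

    # sort once, then keep the top rows within each transcript group
    top_guides = (predictions
                  .sort_values(SCORE_COL, ascending=False)
                  .groupby(ID_COL, sort=False)
                  .head(NUM_TOP_GUIDES)
                  .reset_index(drop=True))
    print(top_guides)  # T1 keeps 0.9 and 0.7; T2 keeps 0.8 and 0.6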
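With this commit, tiger_exhibit requires a run_mode and returns DataFrames keyed by the new column constants. A minimal usage sketch, assuming tiger.py is importable from the working directory and that the saved model/ directory and transform_params.pkl sit alongside it:

    import pandas as pd
    from tiger import tiger_exhibit, RUN_MODE_TITRATION, ID_COL, SEQ_COL

    # first 50 nt of EIF3B-003's CDS, as in the simple test above
    transcripts = pd.DataFrame({
        ID_COL: ['ManualEntry'],
        SEQ_COL: ['ATGCAGGACGCGGAGAACGTGGCGGTGCCCGAGGCGGCCGAGGAGCGCGC'],
    }).set_index(ID_COL)

    # top guides per transcript, without the slow reference-transcriptome scan
    df_on_target, df_off_target = tiger_exhibit(
        transcripts, run_mode=RUN_MODE_TITRATION, check_off_targets=False)
    print(df_on_target.head())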