Andrew Stirn committed on
Commit
42a3866
·
1 Parent(s): ccfd7e1

titration mode!

Browse files
Files changed (1) hide show
  1. tiger.py +46 -13
tiger.py CHANGED
@@ -38,7 +38,7 @@ NUM_MISMATCHES = 3
38
  RUN_MODES = dict(
39
  all='All on-target guides per transcript',
40
  top_guides='Top {:d} guides per transcript'.format(NUM_TOP_GUIDES),
41
- # titration='Top {:d} guides per transcript & their titration candidates'.format(NUM_TOP_GUIDES) # TODO: do this!
42
  )
43
 
44
 
@@ -175,7 +175,7 @@ def get_on_target_predictions(transcripts: pd.DataFrame, model: tf.keras.Model,
175
  scores = prediction_transform(tf.squeeze(lfc_estimate).numpy())
176
  predictions = pd.concat([predictions, pd.DataFrame({
177
  ID_COL: [index] * len(scores),
178
- TARGET_COL: [seq[CONTEXT_5P:len(seq) - CONTEXT_3P] for seq in target_seq],
179
  GUIDE_COL: guide_seq,
180
  SCORE_COL: scores})])
181
 
@@ -202,6 +202,27 @@ def top_guides_per_transcript(predictions: pd.DataFrame):
202
  return top_guides.reset_index(drop=True)
203
 
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  def find_off_targets(top_guides: pd.DataFrame, status_update_fn=None):
206
 
207
  # load reference transcripts
@@ -284,7 +305,7 @@ def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
284
  # trim context sequence
285
  off_targets[TARGET_COL] = off_targets[TARGET_COL].apply(lambda seq: seq[CONTEXT_5P:len(seq) - CONTEXT_3P])
286
 
287
- return off_targets.sort_values(SCORE_COL, ascending=False).reset_index(drop=True)
288
 
289
 
290
  def tiger_exhibit(transcripts: pd.DataFrame, mode: str, check_off_targets: bool, status_update_fn=None):
@@ -300,30 +321,36 @@ def tiger_exhibit(transcripts: pd.DataFrame, mode: str, check_off_targets: bool,
300
  on_target_predictions = get_on_target_predictions(transcripts, tiger, status_update_fn)
301
 
302
  # initialize other outputs
303
- off_target_predictions = None
304
 
305
  if mode == 'all' and not check_off_targets:
306
- pass # nothing to do!
307
 
308
  elif mode == 'top_guides':
309
  on_target_predictions = top_guides_per_transcript(on_target_predictions)
 
310
 
311
- # TODO: add titration candidates
 
 
 
 
312
 
313
  else:
314
  raise NotImplementedError
315
 
316
  # check off-target effects for top guides
317
- if check_off_targets:
318
- off_targets = find_off_targets(on_target_predictions, status_update_fn)
319
- off_target_predictions = predict_off_target(off_targets, model=tiger)
 
320
 
321
  # reverse guide sequences
322
  on_target_predictions[GUIDE_COL] = on_target_predictions[GUIDE_COL].apply(lambda s: s[::-1])
323
  if check_off_targets and len(off_target_predictions) > 0:
324
  off_target_predictions[GUIDE_COL] = off_target_predictions[GUIDE_COL].apply(lambda s: s[::-1])
325
 
326
- return on_target_predictions, off_target_predictions
327
 
328
 
329
  if __name__ == '__main__':
@@ -336,7 +363,7 @@ if __name__ == '__main__':
336
  args = parser.parse_args()
337
 
338
  # check for any existing results
339
- if os.path.exists('on_target.csv') or os.path.exists('off_target.csv'):
340
  raise FileExistsError('please rename or delete existing results')
341
 
342
  # load transcripts from a directory of fasta files
@@ -360,11 +387,17 @@ if __name__ == '__main__':
360
 
361
  # run batch
362
  idx_stop = min(idx + BATCH_SIZE_TRANSCRIPTS, len(df_transcripts))
363
- df_on_target, df_off_target = tiger_exhibit(df_transcripts[idx:idx_stop], args.mode, args.check_off_targets)
 
 
 
 
364
 
365
  # save batch results
366
  df_on_target.to_csv('on_target.csv', header=batch == 1, index=False, mode='a')
367
- if args.check_off_targets:
 
 
368
  df_off_target.to_csv('off_target.csv', header=batch == 1, index=False, mode='a')
369
 
370
  # clear session to prevent memory blow up
 
38
  RUN_MODES = dict(
39
  all='All on-target guides per transcript',
40
  top_guides='Top {:d} guides per transcript'.format(NUM_TOP_GUIDES),
41
+ titration='Top {:d} guides per transcript & their titration candidates'.format(NUM_TOP_GUIDES)
42
  )
43
 
44
 
 
175
  scores = prediction_transform(tf.squeeze(lfc_estimate).numpy())
176
  predictions = pd.concat([predictions, pd.DataFrame({
177
  ID_COL: [index] * len(scores),
178
+ TARGET_COL: target_seq,
179
  GUIDE_COL: guide_seq,
180
  SCORE_COL: scores})])
181
 
 
202
  return top_guides.reset_index(drop=True)
203
 
204
 
205
+ def get_titration_candidates(top_guide_predictions: pd.DataFrame):
206
+
207
+ # generate a table of all titration candidates
208
+ titration_candidates = pd.DataFrame()
209
+ for _, row in top_guide_predictions.iterrows():
210
+ for i in range(len(row[GUIDE_COL])):
211
+ nt = row[GUIDE_COL][i]
212
+ for mutation in set(NUCLEOTIDE_TOKENS.keys()) - {nt, 'N'}:
213
+ sm_guide = list(row[GUIDE_COL])
214
+ sm_guide[i] = mutation
215
+ sm_guide = ''.join(sm_guide)
216
+ assert row[GUIDE_COL] != sm_guide
217
+ titration_candidates = pd.concat([titration_candidates, pd.DataFrame({
218
+ ID_COL: [row[ID_COL]],
219
+ TARGET_COL: [row[TARGET_COL]],
220
+ GUIDE_COL: [sm_guide]
221
+ })])
222
+
223
+ return titration_candidates
224
+
225
+
226
  def find_off_targets(top_guides: pd.DataFrame, status_update_fn=None):
227
 
228
  # load reference transcripts
 
305
  # trim context sequence
306
  off_targets[TARGET_COL] = off_targets[TARGET_COL].apply(lambda seq: seq[CONTEXT_5P:len(seq) - CONTEXT_3P])
307
 
308
+ return off_targets.reset_index(drop=True)
309
 
310
 
311
  def tiger_exhibit(transcripts: pd.DataFrame, mode: str, check_off_targets: bool, status_update_fn=None):
 
321
  on_target_predictions = get_on_target_predictions(transcripts, tiger, status_update_fn)
322
 
323
  # initialize other outputs
324
+ titration_predictions = off_target_predictions = None
325
 
326
  if mode == 'all' and not check_off_targets:
327
+ off_target_candidates = None
328
 
329
  elif mode == 'top_guides':
330
  on_target_predictions = top_guides_per_transcript(on_target_predictions)
331
+ off_target_candidates = on_target_predictions
332
 
333
+ elif mode == 'titration':
334
+ on_target_predictions = top_guides_per_transcript(on_target_predictions)
335
+ titration_candidates = get_titration_candidates(on_target_predictions)
336
+ titration_predictions = predict_off_target(titration_candidates, model=tiger)
337
+ off_target_candidates = pd.concat([on_target_predictions, titration_predictions])
338
 
339
  else:
340
  raise NotImplementedError
341
 
342
  # check off-target effects for top guides
343
+ if check_off_targets and off_target_candidates is not None:
344
+ off_target_candidates = find_off_targets(off_target_candidates, status_update_fn)
345
+ off_target_predictions = predict_off_target(off_target_candidates, model=tiger)
346
+ off_target_predictions = off_target_predictions.sort_values(SCORE_COL, ascending=False)
347
 
348
  # reverse guide sequences
349
  on_target_predictions[GUIDE_COL] = on_target_predictions[GUIDE_COL].apply(lambda s: s[::-1])
350
  if check_off_targets and len(off_target_predictions) > 0:
351
  off_target_predictions[GUIDE_COL] = off_target_predictions[GUIDE_COL].apply(lambda s: s[::-1])
352
 
353
+ return on_target_predictions, titration_predictions, off_target_predictions
354
 
355
 
356
  if __name__ == '__main__':
 
363
  args = parser.parse_args()
364
 
365
  # check for any existing results
366
+ if os.path.exists('on_target.csv') or os.path.exists('titration.csv') or os.path.exists('off_target.csv'):
367
  raise FileExistsError('please rename or delete existing results')
368
 
369
  # load transcripts from a directory of fasta files
 
387
 
388
  # run batch
389
  idx_stop = min(idx + BATCH_SIZE_TRANSCRIPTS, len(df_transcripts))
390
+ df_on_target, df_titration, df_off_target = tiger_exhibit(
391
+ transcripts=df_transcripts[idx:idx_stop],
392
+ mode=args.mode,
393
+ check_off_targets=args.check_off_targets
394
+ )
395
 
396
  # save batch results
397
  df_on_target.to_csv('on_target.csv', header=batch == 1, index=False, mode='a')
398
+ if df_titration is not None:
399
+ df_titration.to_csv('titration.csv', header=batch == 1, index=False, mode='a')
400
+ if df_off_target is not None:
401
  df_off_target.to_csv('off_target.csv', header=batch == 1, index=False, mode='a')
402
 
403
  # clear session to prevent memory blow up