import argparse
import os
import gzip
import pandas as pd
import tensorflow as tf
from Bio import SeqIO

GUIDE_LEN = 23
CONTEXT_5P = 3
CONTEXT_3P = 0
TARGET_LEN = CONTEXT_5P + GUIDE_LEN + CONTEXT_3P
NUCLEOTIDE_TOKENS = dict(zip(['A', 'C', 'G', 'T', 'N'], [0, 1, 2, 3, 255]))
NUCLEOTIDE_COMPLEMENT = dict(zip(['A', 'C', 'G', 'T'], ['T', 'G', 'C', 'A']))
NUM_TOP_GUIDES = 10
NUM_MISMATCHES = 3
REFERENCE_TRANSCRIPTS = ('gencode.v19.pc_transcripts.fa.gz', 'gencode.v19.lncRNA_transcripts.fa.gz')
BATCH_SIZE_COMPUTE = 500
BATCH_SIZE_SCAN = 20
BATCH_SIZE_TRANSCRIPTS = 50

# configure GPUs
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, enable=True)
if len(tf.config.list_physical_devices('GPU')) > 0:
    tf.config.experimental.set_visible_devices(tf.config.list_physical_devices('GPU')[0], 'GPU')


def load_transcripts(fasta_files):
    # load all transcripts from fasta files into a DataFrame
    transcripts = pd.DataFrame()
    for file in fasta_files:
        try:
            if os.path.splitext(file)[1] == '.gz':
                with gzip.open(file, 'rt') as f:
                    df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(f, 'fasta')], columns=['id', 'seq'])
            else:
                df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(file, 'fasta')], columns=['id', 'seq'])
        except Exception as e:
            print(e, 'while loading', file)
            continue
        transcripts = pd.concat([transcripts, df])

    # use the transcript ID (text before the first '|') as the index and require uniqueness
    transcripts['id'] = transcripts['id'].apply(lambda s: s.split('|')[0])
    transcripts.set_index('id', inplace=True)
    assert not transcripts.index.has_duplicates, 'duplicate transcript IDs detected'

    return transcripts


def sequence_complement(sequence: list):
    # complement each sequence position-wise (no reversal)
    return [''.join([NUCLEOTIDE_COMPLEMENT[nt] for nt in list(seq)]) for seq in sequence]
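
# Worked example (values follow directly from NUCLEOTIDE_COMPLEMENT above). The result is
# complemented but not reversed; the s[::-1] applied at the end of tiger_exhibit is what turns
# guides into reverse complements of their target sites.
#   sequence_complement(['AACG']) -> ['TTGC']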


def one_hot_encode_sequence(sequence: list, add_context_padding: bool = False):
    # stack list of sequences into a tensor
    sequence = tf.ragged.stack([tf.constant(list(seq)) for seq in sequence], axis=0)

    # tokenize sequence
    nucleotide_table = tf.lookup.StaticVocabularyTable(
        initializer=tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(list(NUCLEOTIDE_TOKENS.keys()), dtype=tf.string),
            values=tf.constant(list(NUCLEOTIDE_TOKENS.values()), dtype=tf.int64)),
        num_oov_buckets=1)
    sequence = tf.RaggedTensor.from_row_splits(values=nucleotide_table.lookup(sequence.values),
                                               row_splits=sequence.row_splits).to_tensor(255)

    # add context padding if requested
    if add_context_padding:
        pad_5p = 255 * tf.ones([sequence.shape[0], CONTEXT_5P], dtype=sequence.dtype)
        pad_3p = 255 * tf.ones([sequence.shape[0], CONTEXT_3P], dtype=sequence.dtype)
        sequence = tf.concat([pad_5p, sequence, pad_3p], axis=1)

    # one-hot encode
    sequence = tf.one_hot(sequence, depth=4, dtype=tf.float16)

    return sequence
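
# Shape sketch (assuming the constants above: GUIDE_LEN = 23, CONTEXT_5P = 3, CONTEXT_3P = 0):
#   one_hot_encode_sequence([one 26-nt target])                            -> (1, 26, 4)
#   one_hot_encode_sequence([one 23-nt guide], add_context_padding=True)   -> (1, 26, 4)
# Padded positions and 'N' bases carry token 255, which lies outside depth=4 and therefore
# one-hot encodes to an all-zero row.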


def process_data(transcript_seq: str):
    # convert to upper case
    transcript_seq = transcript_seq.upper()

    # get all target sites
    target_seq = [transcript_seq[i: i + TARGET_LEN] for i in range(len(transcript_seq) - TARGET_LEN + 1)]

    # prepare guide sequences
    guide_seq = sequence_complement([seq[CONTEXT_5P:len(seq) - CONTEXT_3P] for seq in target_seq])

    # model inputs
    model_inputs = tf.concat([
        tf.reshape(one_hot_encode_sequence(target_seq, add_context_padding=False), [len(target_seq), -1]),
        tf.reshape(one_hot_encode_sequence(guide_seq, add_context_padding=True), [len(guide_seq), -1]),
    ], axis=-1)

    return target_seq, guide_seq, model_inputs
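
# Tiling sketch: for a transcript of length L, process_data emits L - TARGET_LEN + 1 overlapping
# windows of TARGET_LEN nucleotides, one per start position; each guide is the complement of the
# GUIDE_LEN bases between the 5' and 3' context of its window.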


def predict_on_target(transcript_seq: str, model: tf.keras.Model):
    # parse transcript sequence
    target_seq, guide_seq, model_inputs = process_data(transcript_seq)

    # get predictions
    normalized_lfc = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)
    predictions = pd.DataFrame({'Guide': guide_seq, 'Normalized LFC': tf.squeeze(normalized_lfc).numpy()})
    predictions = predictions.sort_values('Normalized LFC')

    return predictions


def find_off_targets(top_guides: pd.DataFrame, status_bar, status_text):
    # load reference transcripts
    reference_transcripts = load_transcripts([os.path.join('transcripts', f) for f in REFERENCE_TRANSCRIPTS])

    # one-hot encode guides to form a filter
    guide_filter = one_hot_encode_sequence(sequence_complement(top_guides['Guide']), add_context_padding=False)
    guide_filter = tf.transpose(guide_filter, [1, 2, 0])

    # loop over transcripts in batches
    i = 0
    print('Scanning for off-targets')
    off_targets = pd.DataFrame()
    while i < len(reference_transcripts):
        # select batch
        df_batch = reference_transcripts.iloc[i:min(i + BATCH_SIZE_SCAN, len(reference_transcripts))]
        i += BATCH_SIZE_SCAN

        # find locations of off-targets: convolving the one-hot transcripts with the one-hot guide
        # filter counts matching positions, so GUIDE_LEN minus that count is the mismatch count
        transcripts = one_hot_encode_sequence(df_batch['seq'].values.tolist(), add_context_padding=False)
        num_mismatches = GUIDE_LEN - tf.nn.conv1d(transcripts, guide_filter, stride=1, padding='SAME')
        loc_off_targets = tf.where(tf.round(num_mismatches) <= NUM_MISMATCHES).numpy()

        # off-targets discovered
        if len(loc_off_targets) > 0:
            # log off-targets
            dict_off_targets = pd.DataFrame({
                'On-target ID': top_guides.iloc[loc_off_targets[:, 2]]['On-target ID'],
                'Guide': top_guides.iloc[loc_off_targets[:, 2]]['Guide'],
                'Off-target ID': df_batch.index.values[loc_off_targets[:, 0]],
                'Target': df_batch['seq'].values[loc_off_targets[:, 0]],
                'Mismatches': tf.gather_nd(num_mismatches, loc_off_targets).numpy().astype(int),
                'Midpoint': loc_off_targets[:, 1],
            }).to_dict('records')

            # trim transcripts to target sites, padding with 'N' where context runs off either end
            for row in dict_off_targets:
                start_location = row['Midpoint'] - (GUIDE_LEN // 2)
                if start_location < CONTEXT_5P:
                    row['Target'] = row['Target'][0:GUIDE_LEN + CONTEXT_3P]
                    row['Target'] = 'N' * (TARGET_LEN - len(row['Target'])) + row['Target']
                elif start_location + GUIDE_LEN + CONTEXT_3P > len(row['Target']):
                    row['Target'] = row['Target'][start_location - CONTEXT_5P:]
                    row['Target'] = row['Target'] + 'N' * (TARGET_LEN - len(row['Target']))
                else:
                    row['Target'] = row['Target'][start_location - CONTEXT_5P:start_location + GUIDE_LEN + CONTEXT_3P]
                if row['Mismatches'] == 0 and 'N' not in row['Target']:
                    assert row['Guide'] == sequence_complement([row['Target'][CONTEXT_5P:TARGET_LEN - CONTEXT_3P]])[0]

            # append new off-targets
            off_targets = pd.concat([off_targets, pd.DataFrame(dict_off_targets)])

        # progress update
        if status_bar:
            status_text.text('Scanning for off-targets. Percent complete: {:.2f}%'.format(100 * min(i / len(reference_transcripts), 1)))
            status_bar.progress(int(100 * min(i / len(reference_transcripts), 1)))
        print('\rPercent complete: {:.2f}%'.format(100 * min(i / len(reference_transcripts), 1)), end='')
    print('')

    return off_targets


def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
    if len(off_targets) == 0:
        return pd.DataFrame()

    # append off-target predictions
    model_inputs = tf.concat([
        tf.reshape(one_hot_encode_sequence(off_targets['Target'], add_context_padding=False), [len(off_targets), -1]),
        tf.reshape(one_hot_encode_sequence(off_targets['Guide'], add_context_padding=True), [len(off_targets), -1]),
    ], axis=-1)
    off_targets['Normalized LFC'] = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)

    return off_targets.sort_values('Normalized LFC')


def tiger_exhibit(transcripts: pd.DataFrame, status_bar=None, status_text=None, check_off_targets=False):
    # load model
    if os.path.exists('model'):
        tiger = tf.keras.models.load_model('model')
    else:
        print('no saved model!')
        exit()

    # find top guides for each transcript
    print('Finding top guides for each transcript')
    on_target_predictions = pd.DataFrame(columns=['On-target ID', 'Guide', 'Normalized LFC'])
    for i, (index, row) in enumerate(transcripts.iterrows()):
        df = predict_on_target(row['seq'], model=tiger)
        df['On-target ID'] = index
        on_target_predictions = pd.concat([on_target_predictions, df.iloc[:NUM_TOP_GUIDES]])

        # progress update
        if status_bar:
            status_text.text('Scanning for on-targets. Percent complete: {:.2f}%'.format(100 * min((i + 1) / len(transcripts), 1)))
            status_bar.progress(int(100 * min((i + 1) / len(transcripts), 1)))
        print('\rPercent complete: {:.2f}%'.format(100 * min((i + 1) / len(transcripts), 1)), end='')
    print('')

    # predict off-target effects for top guides
    off_target_predictions = pd.DataFrame()
    if check_off_targets:
        off_targets = find_off_targets(on_target_predictions, status_bar, status_text)
        off_target_predictions = predict_off_target(off_targets, model=tiger)

    # reverse guide sequences
    on_target_predictions['Guide'] = on_target_predictions['Guide'].apply(lambda s: s[::-1])
    if check_off_targets and len(off_target_predictions) > 0:
        off_target_predictions['Guide'] = off_target_predictions['Guide'].apply(lambda s: s[::-1])

    return on_target_predictions.reset_index(drop=True), off_target_predictions.reset_index(drop=True)


if __name__ == '__main__':
    # command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--check_off_targets', action='store_true', default=False)
    parser.add_argument('--fasta_path', type=str, default=None)
    parser.add_argument('--simple_test', action='store_true', default=False)
    args = parser.parse_args()

    # simple test case
    if args.simple_test:
        # first 50 nucleotides of the EIF3B-003 CDS
        simple_test = pd.DataFrame(dict(id=['ManualEntry'], seq=['ATGCAGGACGCGGAGAACGTGGCGGTGCCCGAGGCGGCCGAGGAGCGCGC']))
        simple_test.set_index('id', inplace=True)
        df_on_target, df_off_target = tiger_exhibit(simple_test, check_off_targets=args.check_off_targets)
        df_on_target.to_csv('on_target.csv')
        if args.check_off_targets:
            df_off_target.to_csv('off_target.csv')

    # directory of fasta files
    elif args.fasta_path is not None and os.path.exists(args.fasta_path):
        # check for any existing results
        if os.path.exists('on_target.csv') or os.path.exists('off_target.csv'):
            raise FileExistsError('please rename or delete existing results')

        # load transcripts
        df_transcripts = load_transcripts([os.path.join(args.fasta_path, f) for f in os.listdir(args.fasta_path)])

        # process in batches
        batch = 0
        num_batches = len(df_transcripts) // BATCH_SIZE_TRANSCRIPTS
        num_batches += (len(df_transcripts) % BATCH_SIZE_TRANSCRIPTS > 0)
        for idx in range(0, len(df_transcripts), BATCH_SIZE_TRANSCRIPTS):
            batch += 1
            print('Batch {:d} of {:d}'.format(batch, num_batches))

            # run batch
            idx_stop = min(idx + BATCH_SIZE_TRANSCRIPTS, len(df_transcripts))
            df_on_target, df_off_target = tiger_exhibit(df_transcripts[idx:idx_stop], check_off_targets=args.check_off_targets)

            # save batch results
            df_on_target.to_csv('on_target.csv', header=batch == 1, index=False, mode='a')
            if args.check_off_targets:
                df_off_target.to_csv('off_target.csv', header=batch == 1, index=False, mode='a')

            # clear session to prevent memory blow up
            tf.keras.backend.clear_session()
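
# Example invocations (a sketch; the script filename tiger.py is an assumption, while the flags,
# the 'model' SavedModel directory, and the 'transcripts' folder of GENCODE fastas used for
# off-target scanning come from the code above):
#   python tiger.py --simple_test
#   python tiger.py --fasta_path path/to/fastas --check_off_targets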