Spaces: Running on CPU Upgrade
| import os | |
| import tensorflow as tf | |
| import pandas as pd | |
# Guide length in nucleotides.
GUIDE_LEN = 23
# Context nucleotides kept on either side of the guide-aligned region of each
# target window (leading = 5P, trailing = 3P).
CONTEXT_5P, CONTEXT_3P = 3, 0
# Width of one target window: leading context + guide region + trailing context.
TARGET_LEN = CONTEXT_5P + GUIDE_LEN + CONTEXT_3P
# Integer id assigned to each nucleotide by the tokenizer (A=0, C=1, G=2, T=3).
NUCLEOTIDE_TOKENS = {nt: idx for idx, nt in enumerate('ACGT')}
# Base-pairing complement of each nucleotide.
NUCLEOTIDE_COMPLEMENT = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A'}
def process_data(transcript_seq: str):
    """Split a transcript into every target window and build model inputs.

    The transcript is upper-cased, and every window of TARGET_LEN nucleotides
    (stride 1, up to and including the last full window) becomes a target
    site. For each target, the guide is the region between the CONTEXT_5P
    leading and CONTEXT_3P trailing context nucleotides, with each position
    replaced by its complement.

    NOTE(review): the complement is not reversed, so each guide string is
    aligned position-by-position with its target rather than written
    5'->3'. Confirm this is the orientation the model and callers expect.

    Args:
        transcript_seq: nucleotide sequence; case-insensitive.

    Returns:
        A tuple ``(target_seq, guide_seq, model_inputs)``:
        target_seq -- list of target-window strings, one per site;
        guide_seq -- list of complemented guide strings, same order;
        model_inputs -- float32 tensor of shape (num_sites, 8 * TARGET_LEN):
            the flattened one-hot target concatenated with the flattened
            one-hot guide. The guide is padded to TARGET_LEN with all-zero
            rows at the context positions.

    Raises:
        ValueError: if the transcript is shorter than TARGET_LEN, so no
            target window fits.
        KeyError: if a guide region contains a character other than
            A, C, G or T (raised by the complement lookup).
    """
    # convert to upper case
    transcript_seq = transcript_seq.upper()

    # Number of full windows; the +1 keeps the final window, which the
    # previous ``range(len - TARGET_LEN)`` silently dropped.
    num_sites = len(transcript_seq) - TARGET_LEN + 1
    if num_sites < 1:
        # Fail clearly instead of with an opaque error from tf.stack([]).
        raise ValueError(
            f'transcript must be at least {TARGET_LEN} nt long, '
            f'got {len(transcript_seq)}')

    # get all target sites
    target_seq = [transcript_seq[i:i + TARGET_LEN] for i in range(num_sites)]

    # prepare guide sequences: strip context, then complement each position
    guide_seq = [seq[CONTEXT_5P:len(seq) - CONTEXT_3P] for seq in target_seq]
    guide_seq = [''.join(NUCLEOTIDE_COMPLEMENT[nt] for nt in seq) for seq in guide_seq]

    # tokenize sequence: A/C/G/T -> 0..3; any other character falls into the
    # single OOV bucket (id 4), which one_hot(depth=4) encodes as all zeros
    nucleotide_table = tf.lookup.StaticVocabularyTable(
        initializer=tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(list(NUCLEOTIDE_TOKENS.keys()), dtype=tf.string),
            values=tf.constant(list(NUCLEOTIDE_TOKENS.values()), dtype=tf.int64)),
        num_oov_buckets=1)
    target_tokens = nucleotide_table.lookup(tf.stack([list(t) for t in target_seq], axis=0))
    guide_tokens = nucleotide_table.lookup(tf.stack([list(g) for g in guide_seq], axis=0))

    # Pad guides back to TARGET_LEN so they align with their targets. 255 is
    # outside one_hot's depth, so padded positions encode as all-zero rows.
    pad_5p = 255 * tf.ones([guide_tokens.shape[0], CONTEXT_5P], dtype=guide_tokens.dtype)
    pad_3p = 255 * tf.ones([guide_tokens.shape[0], CONTEXT_3P], dtype=guide_tokens.dtype)
    guide_tokens = tf.concat([pad_5p, guide_tokens, pad_3p], axis=1)

    # model inputs: flattened one-hot target followed by flattened one-hot guide
    model_inputs = tf.concat([
        tf.reshape(tf.one_hot(target_tokens, depth=4), [num_sites, -1]),
        tf.reshape(tf.one_hot(guide_tokens, depth=4), [num_sites, -1]),
    ], axis=-1)
    return target_seq, guide_seq, model_inputs
def tiger_predict(transcript_seq: str):
    """Score every guide tiling a transcript with the saved TIGER model.

    Loads the Keras model saved at the ``model`` path (relative to the
    current working directory), featurizes ``transcript_seq`` with
    ``process_data``, and runs the model on every target site.

    Args:
        transcript_seq: nucleotide sequence; case-insensitive.

    Returns:
        pandas DataFrame with one row per guide and two columns: ``'Guide'``
        (guide string from ``process_data``) and ``'Normalized LFC'`` (model
        output for that guide).

    Raises:
        SystemExit: with exit status 1, after printing ``no saved model!``,
            when the ``model`` path does not exist.
        Any exception ``process_data`` raises for unusable input.
    """
    # load model
    if not os.path.exists('model'):
        print('no saved model!')
        # Raise directly instead of calling the site-provided ``exit()`` helper,
        # which is meant for the interactive interpreter and is missing under
        # ``python -S``. Status 1 so shell callers see a failure (exit() gave 0).
        raise SystemExit(1)
    tiger = tf.keras.models.load_model('model')

    # parse transcript sequence (target strings are not needed here)
    _, guide_seq, model_inputs = process_data(transcript_seq)

    # get predictions
    normalized_lfc = tiger.predict_step(model_inputs)
    # Flatten rather than squeeze: squeeze collapses a single-guide batch to a
    # 0-d value, while reshape([-1]) always yields a 1-D vector. This presumes
    # the model emits one scalar per guide -- TODO confirm output shape.
    scores = tf.reshape(normalized_lfc, [-1]).numpy()
    predictions = pd.DataFrame({'Guide': guide_seq, 'Normalized LFC': scores})
    return predictions
if __name__ == '__main__':
    # Smoke test: a lower-case input also exercises the upper-casing step.
    transcript_sequence = 'acgt' * 8
    print(tiger_predict(transcript_sequence))