tiger / run.py
Andrew Stirn
cleanup
cb8c873
raw
history blame
1.29 kB
import os
import tensorflow as tf
import pandas as pd
# Number of consecutive nucleotides in each window extracted by process_data.
GUIDE_LEN = 23
# Integer token for each nucleotide letter (A=0, C=1, G=2, T=3).
NUCLEOTIDE_TOKENS = {base: token for token, base in enumerate('ACGT')}
# Load the saved Keras model from the local 'model' path at import time.
if os.path.exists('model'):
    tiger = tf.keras.models.load_model('model')
else:
    print('no saved model!')
    # Abort with a non-zero status so the failure is visible to the caller.
    # The bare `exit()` helper is injected by the `site` module for interactive
    # sessions (it can be missing, e.g. under `python -S`) and, called without
    # an argument, terminates with status 0 as if everything had succeeded.
    raise SystemExit(1)
def process_data(x):
    """Split a nucleotide sequence into one-hot encoded guide windows.

    The sequence is upper-cased and every contiguous window of ``GUIDE_LEN``
    positions (stride 1) is extracted.

    Args:
        x: Iterable of single-character strings (presumably a plain ``str``;
            confirm against the caller). After upper-casing, every item must
            be a key of ``NUCLEOTIDE_TOKENS``; anything else raises
            ``KeyError``.

    Returns:
        A ``(guides, encoded)`` tuple. ``guides`` lists the window strings in
        order of start position; ``encoded`` is the one-hot tensor reshaped
        to one row of ``GUIDE_LEN * 4`` values per window.
    """
    bases = [base.upper() for base in x]
    window_count = len(bases) - GUIDE_LEN + 1
    starts = range(window_count)

    guides = ["".join(bases[start:start + GUIDE_LEN]) for start in starts]

    tokens = [NUCLEOTIDE_TOKENS[base] for base in bases]
    # Flat token list, window after window; reshaped into rows below.
    flat_tokens = [
        token
        for start in starts
        for token in tokens[start:start + GUIDE_LEN]
    ]

    one_hot = tf.one_hot(flat_tokens, depth=4)
    return guides, tf.reshape(one_hot, [-1, GUIDE_LEN * 4])
def gen_report_table(input_gens, res):
    """Build a two-column table pairing each window with its model output.

    Args:
        input_gens: Window strings, one per table row.
        res: Object exposing ``.numpy()`` that returns an array; the array is
            flattened, so it must hold exactly one value per entry of
            ``input_gens``.

    Returns:
        A ``pandas.DataFrame`` with columns ``"Gene"`` (the window strings)
        and ``"res"`` (the flattened values as Python numbers).
    """
    flat_values = res.numpy().flatten().tolist()
    return pd.DataFrame({"Gene": input_gens, "res": flat_values})
def tiger_predict(x):
    """Run the loaded model on every guide-length window of ``x``.

    Args:
        x: Nucleotide sequence, passed unchanged to ``process_data``.

    Returns:
        The table built by ``gen_report_table`` from the window strings and
        the model output for them.
    """
    guides, encoded = process_data(x)
    # All windows go through the model in a single ``predict_step`` call; the
    # result is handed on as-is because ``gen_report_table`` calls ``.numpy()``
    # on it.
    outputs = tiger.predict_step(encoded)
    return gen_report_table(guides, outputs)