Andrew Stirn commited on
Commit
814d067
·
1 Parent(s): cb8c873

cleanup and test cases

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. run.py → tiger.py +17 -15
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from run import tiger_predict, GUIDE_LEN, NUCLEOTIDE_TOKENS
3
 
4
 
5
  @st.cache
 
1
  import streamlit as st
2
+ from tiger import tiger_predict, GUIDE_LEN, NUCLEOTIDE_TOKENS
3
 
4
 
5
  @st.cache
run.py → tiger.py RENAMED
@@ -28,18 +28,20 @@ def process_data(x):
28
  return input_gens, model_input_x
29
 
30
 
31
- def gen_report_table(input_gens, res):
32
- res = res.numpy().flatten().tolist()
33
- # print("ftaltten res: ", res)
34
- data = {"Gene": input_gens, "res": res}
35
- tbl = pd.DataFrame.from_dict(data)
36
- return tbl
37
-
38
-
39
- def tiger_predict(x):
40
- input_gens, model_input_x = process_data(x)
41
- # print("input gene: ", input_gens)
42
- # print("model_input: ", model_input_x)
43
- res = tiger.predict_step(model_input_x)
44
- # print("res: ", res)
45
- return gen_report_table(input_gens, res)
 
 
 
28
  return input_gens, model_input_x
29
 
30
 
31
+ def tiger_predict(transcript_seq: str):
32
+ # parse transcript sequence into 23-nt target sequences and their one-hot encodings
33
+ target_seq, target_seq_one_hot = process_data(transcript_seq)
34
+
35
+ # get predictions
36
+ normalized_lfc = tiger.predict_step(target_seq_one_hot)
37
+ predictions = pd.DataFrame({'Target site': target_seq, 'Normalized LFC': tf.squeeze(normalized_lfc).numpy()})
38
+
39
+ return predictions
40
+
41
+
42
+ if __name__ == '__main__':
43
+
44
+ # simple test case
45
+ transcript_sequence = 'ACGTACGTACGTACGTACGTACGTACGTACGT'
46
+ df = tiger_predict(transcript_sequence)
47
+ print(df)