Andrew Stirn commited on
Commit
34274e5
·
1 Parent(s): 610b0ca

tensorflow tokenizer

Browse files
Files changed (1) hide show
  1. tiger.py +21 -14
tiger.py CHANGED
@@ -13,22 +13,29 @@ else:
13
  exit()
14
 
15
 
16
- def process_data(x):
17
- x = [item.upper() for item in x]
18
- number_of_input = len(x) - GUIDE_LEN + 1
19
- input_gens = []
20
- for i in range(number_of_input):
21
- input_gens.append("".join(x[i:i + GUIDE_LEN]))
22
- merged_token = []
23
- token_x = [NUCLEOTIDE_TOKENS[item] for item in x]
24
- for i in range(number_of_input):
25
- merged_token.extend(token_x[i:i + GUIDE_LEN])
26
- one_hot_x = tf.one_hot(merged_token, depth=4)
27
- model_input_x = tf.reshape(one_hot_x, [-1, GUIDE_LEN * 4])
28
- return input_gens, model_input_x
 
 
 
 
 
 
29
 
30
 
31
  def tiger_predict(transcript_seq: str):
 
32
  # parse transcript sequence into 23-nt target sequences and their one-hot encodings
33
  target_seq, target_seq_one_hot = process_data(transcript_seq)
34
 
@@ -42,6 +49,6 @@ def tiger_predict(transcript_seq: str):
42
  if __name__ == '__main__':
43
 
44
  # simple test case
45
- transcript_sequence = 'ACGTACGTACGTACGTACGTACGTACGTACGT'
46
  df = tiger_predict(transcript_sequence)
47
  print(df)
 
13
  exit()
14
 
15
 
16
+ def process_data(transcript_seq: str):
17
+
18
+ # convert to upper case
19
+ transcript_seq = transcript_seq.upper()
20
+
21
+ # get all target sites
22
+ num_target_sites = len(transcript_seq) - GUIDE_LEN + 1
23
+ target_seq = [transcript_seq[i:i + GUIDE_LEN] for i in range(num_target_sites)]
24
+
25
+ # get one-hot encodings
26
+ nucleotide_table = tf.lookup.StaticVocabularyTable(
27
+ initializer=tf.lookup.KeyValueTensorInitializer(
28
+ keys=tf.constant(list(NUCLEOTIDE_TOKENS.keys()), dtype=tf.string),
29
+ values=tf.constant(list(NUCLEOTIDE_TOKENS.values()), dtype=tf.int64)),
30
+ num_oov_buckets=1)
31
+ target_tokens = nucleotide_table.lookup(tf.stack([list(t) for t in target_seq], axis=0))
32
+ target_one_hot = tf.reshape(tf.one_hot(target_tokens, depth=4), [num_target_sites, -1])
33
+
34
+ return target_seq, target_one_hot
35
 
36
 
37
  def tiger_predict(transcript_seq: str):
38
+
39
  # parse transcript sequence into 23-nt target sequences and their one-hot encodings
40
  target_seq, target_seq_one_hot = process_data(transcript_seq)
41
 
 
49
  if __name__ == '__main__':
50
 
51
  # simple test case
52
+ transcript_sequence = 'ACGTACGTACGTACGTACGTACGTACGTACGT'.lower()
53
  df = tiger_predict(transcript_sequence)
54
  print(df)