import os
import gzip
import numpy as np
import pandas as pd
import tensorflow as tf
from Bio import SeqIO

GUIDE_LEN = 23
CONTEXT_5P = 3
CONTEXT_3P = 0
TARGET_LEN = CONTEXT_5P + GUIDE_LEN + CONTEXT_3P
NUCLEOTIDE_TOKENS = dict(zip(['A', 'C', 'G', 'T', 'N'], [0, 1, 2, 3, 255]))
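# the 'N' token (255) lies outside the one-hot depth of 4, so unknown bases and padding encode to all-zero vectors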
NUCLEOTIDE_COMPLEMENT = dict(zip(['A', 'C', 'G', 'T'], ['T', 'G', 'C', 'A']))
NUM_TOP_GUIDES = 10
NUM_MISMATCHES = 3
REFERENCE_TRANSCRIPTS = ('gencode.v19.pc_transcripts.fa.gz', 'gencode.v19.lncRNA_transcripts.fa.gz')
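# reference FASTA files are expected in the local 'transcripts' directory (see find_off_targets)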

# configure GPUs
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, enable=True)
if len(tf.config.list_physical_devices('GPU')) > 0:
    tf.config.experimental.set_visible_devices(tf.config.list_physical_devices('GPU')[0], 'GPU')


def load_transcripts(fasta_files):

    # load all transcripts from fasta files into a DataFrame
    transcripts = pd.DataFrame()
    for file in fasta_files:
        try:
            if os.path.splitext(file)[1] == '.gz':
                with gzip.open(file, 'rt') as f:
                    df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(f, 'fasta')], columns=['id', 'seq'])
            else:
                df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(file, 'fasta')], columns=['id', 'seq'])
        except Exception as e:
            print(e, 'while loading', file)
            continue
        transcripts = pd.concat([transcripts, df])

    # set index
    transcripts['id'] = transcripts['id'].apply(lambda s: s.split('|')[0])
    transcripts.set_index('id', inplace=True)
    assert not transcripts.index.has_duplicates

    return transcripts


def sequence_complement(sequence: list):
    return [''.join([NUCLEOTIDE_COMPLEMENT[nt] for nt in list(seq)]) for seq in sequence]


def one_hot_encode_sequence(sequence: list, add_context_padding: bool = False):

    # stack list of sequences into a tensor
    sequence = tf.ragged.stack([tf.constant(list(seq)) for seq in sequence], axis=0)

    # tokenize sequence
    nucleotide_table = tf.lookup.StaticVocabularyTable(
        initializer=tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(list(NUCLEOTIDE_TOKENS.keys()), dtype=tf.string),
            values=tf.constant(list(NUCLEOTIDE_TOKENS.values()), dtype=tf.int64)),
        num_oov_buckets=1)
    sequence = tf.RaggedTensor.from_row_splits(values=nucleotide_table.lookup(sequence.values),
                                               row_splits=sequence.row_splits).to_tensor(255)

    # add context padding if requested
    if add_context_padding:
        pad_5p = 255 * tf.ones([sequence.shape[0], CONTEXT_5P], dtype=sequence.dtype)
        pad_3p = 255 * tf.ones([sequence.shape[0], CONTEXT_3P], dtype=sequence.dtype)
        sequence = tf.concat([pad_5p, sequence, pad_3p], axis=1)

    # one-hot encode
    sequence = tf.one_hot(sequence, depth=4)

    return sequence


def process_data(transcript_seq: str):

    # convert to upper case
    transcript_seq = transcript_seq.upper()

    # get all target sites
    target_seq = [transcript_seq[i: i + TARGET_LEN] for i in range(len(transcript_seq) - TARGET_LEN + 1)]

    # prepare guide sequences
    guide_seq = sequence_complement([seq[CONTEXT_5P:len(seq) - CONTEXT_3P] for seq in target_seq])

    # model inputs
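    # flatten the one-hot target (TARGET_LEN x 4) and the context-padded guide (also TARGET_LEN x 4) and concatenate per site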
    model_inputs = tf.concat([
        tf.reshape(one_hot_encode_sequence(target_seq, add_context_padding=False), [len(target_seq), -1]),
        tf.reshape(one_hot_encode_sequence(guide_seq, add_context_padding=True), [len(guide_seq), -1]),
        ], axis=-1)

    return target_seq, guide_seq, model_inputs


def predict_on_target(transcript_seq: str, model: tf.keras.Model):

    # parse transcript sequence
    target_seq, guide_seq, model_inputs = process_data(transcript_seq)

    # get predictions
    normalized_lfc = model.predict_step(model_inputs)
    predictions = pd.DataFrame({'Guide': guide_seq, 'Normalized LFC': tf.squeeze(normalized_lfc).numpy()})
    predictions = predictions.set_index('Guide').sort_values('Normalized LFC')

    return predictions


def find_off_targets(guides, batch_size=500):

    # load reference transcripts
    reference_transcripts = load_transcripts([os.path.join('transcripts', f) for f in REFERENCE_TRANSCRIPTS])

    # one-hot encode guides to form a filter
    guide_filter = one_hot_encode_sequence(sequence_complement(guides), add_context_padding=False)
    guide_filter = tf.transpose(guide_filter, [1, 2, 0])
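    # filter shape is [GUIDE_LEN, 4, num_guides]: one conv1d filter per guide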
    guide_filter = tf.cast(guide_filter, tf.float16)

    # loop over transcripts in batches
    i = 0
    print('Scanning for off-targets')
    df_off_targets = pd.DataFrame()
    while i < len(reference_transcripts):
        # select batch
        df_batch = reference_transcripts.iloc[i:min(i + batch_size, len(reference_transcripts))]
        i += batch_size

        # find and log off-targets
        transcripts = one_hot_encode_sequence(df_batch['seq'].values.tolist(), add_context_padding=False)
        transcripts = tf.cast(transcripts, guide_filter.dtype)
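        # convolving one-hot transcripts with one-hot guide filters counts matching positions at each alignment,
        # so subtracting the result from GUIDE_LEN gives mismatch counts per (transcript, position, guide)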
        num_mismatches = GUIDE_LEN - tf.nn.conv1d(transcripts, guide_filter, stride=1, padding='SAME')
        loc_off_targets = tf.where(tf.round(num_mismatches) <= NUM_MISMATCHES).numpy()
        df_off_targets = pd.concat([df_off_targets, pd.DataFrame({
            'Guide': np.array(guides)[loc_off_targets[:, 2]],
            'Isoform': df_batch.index.values[loc_off_targets[:, 0]],
            'Mismatches': tf.gather_nd(num_mismatches, loc_off_targets).numpy().astype(int),
            'Midpoint': loc_off_targets[:, 1],
            'Target': df_batch['seq'].values[loc_off_targets[:, 0]],
        })])

        # progress update
        print('\rPercent complete: {:.2f}%'.format(100 * min(i / len(reference_transcripts), 1)), end='')
    print('')

    # trim transcripts to targets
    dict_off_targets = df_off_targets.to_dict('records')
    for row in dict_off_targets:
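        # 'SAME' padding centers each filter on its output index, so the match starts GUIDE_LEN // 2 upstream of the midpoint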
        start_location = row['Midpoint'] - (GUIDE_LEN // 2)
        if start_location < CONTEXT_5P:
            row['Target'] = row['Target'][0:start_location + GUIDE_LEN + CONTEXT_3P]
            row['Target'] = 'N' * (TARGET_LEN - len(row['Target'])) + row['Target']
        elif start_location + GUIDE_LEN + CONTEXT_3P > len(row['Target']):
            row['Target'] = row['Target'][start_location - CONTEXT_5P:]
            row['Target'] = row['Target'] + 'N' * (TARGET_LEN - len(row['Target']))
        else:
            row['Target'] = row['Target'][start_location - CONTEXT_5P:start_location + GUIDE_LEN + CONTEXT_3P]
        if row['Mismatches'] == 0 and 'N' not in row['Target']:
            assert row['Guide'] == sequence_complement([row['Target'][CONTEXT_5P:TARGET_LEN-CONTEXT_3P]])[0]
    df_off_targets = pd.DataFrame(dict_off_targets)

    return df_off_targets


def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
    if len(off_targets) == 0:
        return pd.DataFrame()

    # assemble model inputs and append off-target predictions
    model_inputs = tf.concat([
        tf.reshape(one_hot_encode_sequence(off_targets['Target'], add_context_padding=False), [len(off_targets), -1]),
        tf.reshape(one_hot_encode_sequence(off_targets['Guide'], add_context_padding=True), [len(off_targets), -1]),
        ], axis=-1)
    off_targets['Normalized LFC'] = tf.squeeze(model.predict_step(model_inputs)).numpy()

    return off_targets.set_index('Guide').sort_values('Normalized LFC')


def tiger_exhibit(transcript):

    # load model
    if os.path.exists('model'):
        tiger = tf.keras.models.load_model('model')
    else:
        raise FileNotFoundError('no saved model found in the local model directory')

    # on-target predictions
    on_target_predictions = predict_on_target(transcript, model=tiger)

    # keep only top guides
    on_target_predictions = on_target_predictions.iloc[:NUM_TOP_GUIDES]

    # predict off-target effects for top guides
    off_targets = find_off_targets(on_target_predictions.index.values.tolist())
    off_target_predictions = predict_off_target(off_targets, model=tiger)

    return on_target_predictions, off_target_predictions


if __name__ == '__main__':

    # simple test case
    print(tiger_exhibit('ATGCAGGACGCGGAGAACGTGGCGGTGCCCGAGGCGGCCGAGGAGCGCGC'.lower()))  # first 50 from EIF3B-003's CDS