import argparse
import os
import gzip
import pandas as pd
import tensorflow as tf
from Bio import SeqIO

GUIDE_LEN = 23
CONTEXT_5P = 3
CONTEXT_3P = 0
TARGET_LEN = CONTEXT_5P + GUIDE_LEN + CONTEXT_3P
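# token 255 marks unknown bases ('N') and padding; it falls outside the one-hot depth
# of 4 used below, so it encodes to an all-zero vector rather than a real nucleotide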
NUCLEOTIDE_TOKENS = dict(zip(['A', 'C', 'G', 'T', 'N'], [0, 1, 2, 3, 255]))
NUCLEOTIDE_COMPLEMENT = dict(zip(['A', 'C', 'G', 'T'], ['T', 'G', 'C', 'A']))
NUM_TOP_GUIDES = 10
NUM_MISMATCHES = 3
REFERENCE_TRANSCRIPTS = ('gencode.v19.pc_transcripts.fa.gz', 'gencode.v19.lncRNA_transcripts.fa.gz')
BATCH_SIZE_COMPUTE = 500
BATCH_SIZE_SCAN = 20
BATCH_SIZE_TRANSCRIPTS = 50

# configure GPUs
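# (grow GPU memory on demand and restrict TensorFlow to the first visible GPU)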
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, enable=True)
if len(tf.config.list_physical_devices('GPU')) > 0:
    tf.config.experimental.set_visible_devices(tf.config.list_physical_devices('GPU')[0], 'GPU')


def load_transcripts(fasta_files):

    # load all transcripts from fasta files into a DataFrame
    transcripts = pd.DataFrame()
    for file in fasta_files:
        try:
            if os.path.splitext(file)[1] == '.gz':
                with gzip.open(file, 'rt') as f:
                    df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(f, 'fasta')], columns=['id', 'seq'])
            else:
                df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(file, 'fasta')], columns=['id', 'seq'])
        except Exception as e:
            print(e, 'while loading', file)
            continue
        transcripts = pd.concat([transcripts, df])

    # set index
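    # (GENCODE-style FASTA headers are '|'-delimited; keep only the leading transcript ID)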
    transcripts['id'] = transcripts['id'].apply(lambda s: s.split('|')[0])
    transcripts.set_index('id', inplace=True)
    assert not transcripts.index.has_duplicates, "duplicate transcript IDs detected"

    return transcripts


def sequence_complement(sequence: list):
    return [''.join([NUCLEOTIDE_COMPLEMENT[nt] for nt in list(seq)]) for seq in sequence]


def one_hot_encode_sequence(sequence: list, add_context_padding: bool = False):

    # stack list of sequences into a tensor
    sequence = tf.ragged.stack([tf.constant(list(seq)) for seq in sequence], axis=0)

    # tokenize sequence
    nucleotide_table = tf.lookup.StaticVocabularyTable(
        initializer=tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(list(NUCLEOTIDE_TOKENS.keys()), dtype=tf.string),
            values=tf.constant(list(NUCLEOTIDE_TOKENS.values()), dtype=tf.int64)),
        num_oov_buckets=1)
    sequence = tf.RaggedTensor.from_row_splits(values=nucleotide_table.lookup(sequence.values),
                                               row_splits=sequence.row_splits).to_tensor(255)

    # add context padding if requested
    if add_context_padding:
        pad_5p = 255 * tf.ones([sequence.shape[0], CONTEXT_5P], dtype=sequence.dtype)
        pad_3p = 255 * tf.ones([sequence.shape[0], CONTEXT_3P], dtype=sequence.dtype)
        sequence = tf.concat([pad_5p, sequence, pad_3p], axis=1)

    # one-hot encode
    sequence = tf.one_hot(sequence, depth=4, dtype=tf.float16)

    return sequence


def process_data(transcript_seq: str):

    # convert to upper case
    transcript_seq = transcript_seq.upper()

    # get all target sites
    target_seq = [transcript_seq[i: i + TARGET_LEN] for i in range(len(transcript_seq) - TARGET_LEN + 1)]

    # prepare guide sequences
    guide_seq = sequence_complement([seq[CONTEXT_5P:len(seq) - CONTEXT_3P] for seq in target_seq])

    # model inputs
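    # targets carry real 5' transcript context, while guides are padded with the 255
    # token up to the same TARGET_LEN width, so both one-hot blocks flatten to equal size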
    model_inputs = tf.concat([
        tf.reshape(one_hot_encode_sequence(target_seq, add_context_padding=False), [len(target_seq), -1]),
        tf.reshape(one_hot_encode_sequence(guide_seq, add_context_padding=True), [len(guide_seq), -1]),
        ], axis=-1)
    return target_seq, guide_seq, model_inputs


def predict_on_target(transcript_seq: str, model: tf.keras.Model):

    # parse transcript sequence
    target_seq, guide_seq, model_inputs = process_data(transcript_seq)

    # get predictions
    normalized_lfc = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)
    predictions = pd.DataFrame({'Guide': guide_seq, 'Normalized LFC': tf.squeeze(normalized_lfc).numpy()})
    predictions = predictions.sort_values('Normalized LFC')
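    # predictions are sorted ascending, so the lowest normalized LFC values (strongest
    # predicted knockdown) come first; tiger_exhibit keeps the first NUM_TOP_GUIDES rows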

    return predictions


def find_off_targets(top_guides: pd.DataFrame, status_bar, status_text):

    # load reference transcripts
    reference_transcripts = load_transcripts([os.path.join('transcripts', f) for f in REFERENCE_TRANSCRIPTS])

    # one-hot encode guides to form a filter
    guide_filter = one_hot_encode_sequence(sequence_complement(top_guides['Guide']), add_context_padding=False)
    guide_filter = tf.transpose(guide_filter, [1, 2, 0])

    # loop over transcripts in batches
    i = 0
    print('Scanning for off-targets')
    off_targets = pd.DataFrame()
    while i < len(reference_transcripts):
        # select batch
        df_batch = reference_transcripts.iloc[i:min(i + BATCH_SIZE_SCAN, len(reference_transcripts))]
        i += BATCH_SIZE_SCAN

        # find locations of off-targets
        transcripts = one_hot_encode_sequence(df_batch['seq'].values.tolist(), add_context_padding=False)
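        # conv1d of one-hot transcripts with one-hot guide filters counts matching
        # positions for every alignment, so GUIDE_LEN minus that count gives the number
        # of mismatches between each guide and each candidate site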
        num_mismatches = GUIDE_LEN - tf.nn.conv1d(transcripts, guide_filter, stride=1, padding='SAME')
        loc_off_targets = tf.where(tf.round(num_mismatches) <= NUM_MISMATCHES).numpy()

        # off-targets discovered
        if len(loc_off_targets) > 0:

            # log off-targets
            dict_off_targets = pd.DataFrame({
                'On-target ID': top_guides.iloc[loc_off_targets[:, 2]]['On-target ID'],
                'Guide': top_guides.iloc[loc_off_targets[:, 2]]['Guide'],
                'Off-target ID': df_batch.index.values[loc_off_targets[:, 0]],
                'Target': df_batch['seq'].values[loc_off_targets[:, 0]],
                'Mismatches': tf.gather_nd(num_mismatches, loc_off_targets).numpy().astype(int),
                'Midpoint': loc_off_targets[:, 1],
            }).to_dict('records')

            # trim transcripts to targets
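            # (slice each hit down to TARGET_LEN around its midpoint, padding with 'N'
            # when the window runs off either end of the transcript)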
            for row in dict_off_targets:
                start_location = row['Midpoint'] - (GUIDE_LEN // 2)
                if start_location < CONTEXT_5P:
                    row['Target'] = row['Target'][0:GUIDE_LEN + CONTEXT_3P]
                    row['Target'] = 'N' * (TARGET_LEN - len(row['Target'])) + row['Target']
                elif start_location + GUIDE_LEN + CONTEXT_3P > len(row['Target']):
                    row['Target'] = row['Target'][start_location - CONTEXT_5P:]
                    row['Target'] = row['Target'] + 'N' * (TARGET_LEN - len(row['Target']))
                else:
                    row['Target'] = row['Target'][start_location - CONTEXT_5P:start_location + GUIDE_LEN + CONTEXT_3P]
                if row['Mismatches'] == 0 and 'N' not in row['Target']:
                    assert row['Guide'] == sequence_complement([row['Target'][CONTEXT_5P:TARGET_LEN - CONTEXT_3P]])[0]

            # append new off-targets
            off_targets = pd.concat([off_targets, pd.DataFrame(dict_off_targets)])

        # progress update
        if status_bar:
            status_text.text("Scanning for off-targets Percent complete: {:.2f}%".format(int(100 * min(i / len(reference_transcripts), 1))))
            status_bar.progress(int(100 * min(i / len(reference_transcripts), 1)))
        print('\rPercent complete: {:.2f}%'.format(100 * min(i / len(reference_transcripts), 1)), end='')
    print('')

    return off_targets


def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
    if len(off_targets) == 0:
        return pd.DataFrame()

    # append off-target predictions
    model_inputs = tf.concat([
        tf.reshape(one_hot_encode_sequence(off_targets['Target'], add_context_padding=False), [len(off_targets), -1]),
        tf.reshape(one_hot_encode_sequence(off_targets['Guide'], add_context_padding=True), [len(off_targets), -1]),
        ], axis=-1)
    off_targets['Normalized LFC'] = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)

    return off_targets.sort_values('Normalized LFC')


def tiger_exhibit(transcripts: pd.DataFrame, status_bar=None, status_text=None, check_off_targets=False):

    # load model
    if os.path.exists('model'):
        tiger = tf.keras.models.load_model('model')
    else:
        print('no saved model!')
        exit()

    # find top guides for each transcript
    print('Finding top guides for each transcript')
    on_target_predictions = pd.DataFrame(columns=['On-target ID', 'Guide', 'Normalized LFC'])
    for i, (index, row) in enumerate(transcripts.iterrows()):
        df = predict_on_target(row['seq'], model=tiger)
        df['On-target ID'] = index
        on_target_predictions = pd.concat([on_target_predictions, df.iloc[:NUM_TOP_GUIDES]])

        # progress update
        if status_bar:
            status_text.text("Scanning for on-targets Percent complete: {:.2f}%".format(100 * min((i + 1) / len(transcripts), 1)))
            status_bar.progress(int(100 * min((i + 1) / len(transcripts), 1)))
        print('\rPercent complete: {:.2f}%'.format(100 * min((i + 1) / len(transcripts), 1)), end='')
    print('')

    # predict off-target effects for top guides
    off_target_predictions = pd.DataFrame()
    if check_off_targets:
        off_targets = find_off_targets(on_target_predictions, status_bar, status_text)
        off_target_predictions = predict_off_target(off_targets, model=tiger)

    # reverse guide sequences
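    # (guides are stored as the position-wise complement of the target, i.e. 3'->5';
    # reversing writes them in the conventional 5'->3' orientation for output)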
    on_target_predictions['Guide'] = on_target_predictions['Guide'].apply(lambda s: s[::-1])
    if check_off_targets and len(off_target_predictions) > 0:
        off_target_predictions['Guide'] = off_target_predictions['Guide'].apply(lambda s: s[::-1])

    return on_target_predictions.reset_index(drop=True), off_target_predictions.reset_index(drop=True)


if __name__ == '__main__':

    # common arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--check_off_targets', action='store_true', default=False)
    parser.add_argument('--fasta_path', type=str, default=None)
    parser.add_argument('--simple_test', action='store_true', default=False)
    args = parser.parse_args()

    # simple test case
    if args.simple_test:
        # first 50 from EIF3B-003's CDS
        simple_test = pd.DataFrame(dict(id=['ManualEntry'], seq=['ATGCAGGACGCGGAGAACGTGGCGGTGCCCGAGGCGGCCGAGGAGCGCGC']))
        simple_test.set_index('id', inplace=True)
        df_on_target, df_off_target = tiger_exhibit(simple_test, check_off_targets=args.check_off_targets)
        df_on_target.to_csv('on_target.csv')
        if args.check_off_targets:
            df_off_target.to_csv('off_target.csv')

    # directory of fasta files
    elif args.fasta_path is not None and os.path.exists(args.fasta_path):

        # check for any existing results
        if os.path.exists('on_target.csv') or os.path.exists('off_target.csv'):
            raise FileExistsError('please rename or delete existing results')

        # load transcripts
        df_transcripts = load_transcripts([os.path.join(args.fasta_path, f) for f in os.listdir(args.fasta_path)])

        # process in batches
        batch = 0
        num_batches = len(df_transcripts) // BATCH_SIZE_TRANSCRIPTS
        num_batches += (len(df_transcripts) % BATCH_SIZE_TRANSCRIPTS > 0)
        for idx in range(0, len(df_transcripts), BATCH_SIZE_TRANSCRIPTS):
            batch += 1
            print('Batch {:d} of {:d}'.format(batch, num_batches))

            # run batch
            idx_stop = min(idx + BATCH_SIZE_TRANSCRIPTS, len(df_transcripts))
            df_on_target, df_off_target = tiger_exhibit(df_transcripts[idx:idx_stop], check_off_targets=args.check_off_targets)

            # save batch results
            df_on_target.to_csv('on_target.csv', header=batch == 1, index=False, mode='a')
            if args.check_off_targets:
                df_off_target.to_csv('off_target.csv', header=batch == 1, index=False, mode='a')

            # clear session to prevent memory blow up
            tf.keras.backend.clear_session()
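
# Example usage (script name assumed; adjust to this file's actual name):
#   python tiger.py --simple_test
#   python tiger.py --fasta_path path/to/fasta_dir --check_off_targets
# Requires a saved Keras model in ./model and, for off-target scanning, the GENCODE
# reference fasta files in ./transcripts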