from sacremoses import MosesPunctNormalizer
from sacremoses import MosesTokenizer
from sacremoses import MosesDetokenizer
from subword_nmt.apply_bpe import BPE, read_vocabulary
import codecs
from tqdm import tqdm
from indicnlp.tokenize import indic_tokenize
from indicnlp.tokenize import indic_detokenize
from indicnlp.normalize import indic_normalize
from indicnlp.transliterate import unicode_transliterate
from mosestokenizer import MosesSentenceSplitter
from indicnlp.tokenize import sentence_tokenize
from inference.custom_interactive import Translator

# Indic languages supported by the model (ISO 639-1 codes)
INDIC = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"]


def split_sentences(paragraph, language):
    """Split a paragraph into sentences: Moses for English, the Indic NLP
    library's sentence splitter for the supported Indic languages."""
    if language == "en":
        with MosesSentenceSplitter(language) as splitter:
            return splitter([paragraph])
    elif language in INDIC:
        return sentence_tokenize.sentence_split(paragraph, lang=language)
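
# Illustrative example: the Indic splitter breaks on sentence-final punctuation
# such as the danda, so (assuming default delimiter handling)
# split_sentences("वह घर गया। फिर वह सो गया।", "hi") should yield
# ["वह घर गया।", "फिर वह सो गया।"]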


def add_token(sent, tag_infos):
    """Prepend the special tokens specified by tag_infos to the sentence.

    tag_infos: list of tuples (tag_type, tag)

    Each tag_info results in a token of the form: __{tag_type}__{tag}__
    """
    tokens = []
    for tag_type, tag in tag_infos:
        token = "__" + tag_type + "__" + tag + "__"
        tokens.append(token)
    return " ".join(tokens) + " " + sent


def apply_lang_tags(sents, src_lang, tgt_lang):
    tagged_sents = []
    for sent in sents:
        tagged_sent = add_token(sent.strip(), [("src", src_lang), ("tgt", tgt_lang)])
        tagged_sents.append(tagged_sent)
    return tagged_sents


def truncate_long_sentences(sents):
    MAX_SEQ_LEN = 200
    new_sents = []
    for sent in sents:
        words = sent.split()
        num_words = len(words)
        if num_words > MAX_SEQ_LEN:
            print_str = " ".join(words[:5]) + " .... " + " ".join(words[-5:])
            sent = " ".join(words[:MAX_SEQ_LEN])
            print(
                f"WARNING: Sentence {print_str} truncated to {MAX_SEQ_LEN} tokens as it exceeds the maximum length limit"
            )
        new_sents.append(sent)
    return new_sents
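
# Note: batch_translate applies this after BPE and language tagging, so the
# 200-token budget is counted in subword units and includes the two tag tokens.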


class Model:
    def __init__(self, expdir):
        self.expdir = expdir
        self.en_tok = MosesTokenizer(lang="en")
        self.en_normalizer = MosesPunctNormalizer()
        self.en_detok = MosesDetokenizer(lang="en")
        self.xliterator = unicode_transliterate.UnicodeIndicTransliterator()
        print("Initializing vocab and bpe")
        # keep only vocabulary entries whose frequency is at least 5
        self.vocabulary = read_vocabulary(
            codecs.open(f"{expdir}/vocab/vocab.SRC", encoding="utf-8"), 5
        )
        # -1 = apply all merges from the codes file; "@@" is the subword separator
        self.bpe = BPE(
            codecs.open(f"{expdir}/vocab/bpe_codes.32k.SRC", encoding="utf-8"),
            -1,
            "@@",
            self.vocabulary,
            None,
        )
        print("Initializing model for translation")
        # initialize the model
        self.translator = Translator(
            f"{expdir}/final_bin", f"{expdir}/model/checkpoint_best.pt", batch_size=100
        )
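
    # Expected layout under `expdir` (inferred from the paths used above):
    #   vocab/vocab.SRC            source-side vocabulary with frequencies
    #   vocab/bpe_codes.32k.SRC    learned BPE merge operations
    #   final_bin/                 fairseq binarized data directory
    #   model/checkpoint_best.pt   fairseq model checkpoint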

    # translate a batch of sentences from src_lang to tgt_lang
    def batch_translate(self, batch, src_lang, tgt_lang):
        assert isinstance(batch, list)
        preprocessed_sents = self.preprocess(batch, lang=src_lang)
        bpe_sents = self.apply_bpe(preprocessed_sents)
        tagged_sents = apply_lang_tags(bpe_sents, src_lang, tgt_lang)
        tagged_sents = truncate_long_sentences(tagged_sents)
        translations = self.translator.translate(tagged_sents)
        postprocessed_sents = self.postprocess(translations, tgt_lang)
        return postprocessed_sents
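
    # Stage-by-stage view of one en->hi sentence (outputs are illustrative only):
    #   "Hello, world!"                            input
    #   "Hello , world !"                          preprocess: normalize + tokenize
    #   "Hel@@ lo , world !"                       apply_bpe: subword segmentation
    #   "__src__en__ __tgt__hi__ Hel@@ lo ..."     apply_lang_tags, then truncation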

    # translate a paragraph from src_lang to tgt_lang
    def translate_paragraph(self, paragraph, src_lang, tgt_lang):
        assert isinstance(paragraph, str)
        sents = split_sentences(paragraph, src_lang)
        postprocessed_sents = self.batch_translate(sents, src_lang, tgt_lang)
        translated_paragraph = " ".join(postprocessed_sents)
        return translated_paragraph

    def preprocess_sent(self, sent, normalizer, lang):
        if lang == "en":
            return " ".join(
                self.en_tok.tokenize(
                    self.en_normalizer.normalize(sent.strip()), escape=False
                )
            )
        else:
            # normalize and tokenize, then transliterate to Devanagari so that
            # all Indic inputs share the common "hi" script used during training;
            # the final replace removes spurious spaces around the halant/virama
            return unicode_transliterate.UnicodeIndicTransliterator.transliterate(
                " ".join(
                    indic_tokenize.trivial_tokenize(
                        normalizer.normalize(sent.strip()), lang
                    )
                ),
                lang,
                "hi",
            ).replace(" ् ", "्")

    def preprocess(self, sents, lang):
        """Normalize, tokenize and, for Indic languages, convert script.

        Returns the list of preprocessed sentences.
        """
        if lang == "en":
            processed_sents = [
                self.preprocess_sent(line, None, lang) for line in tqdm(sents)
            ]
        else:
            normfactory = indic_normalize.IndicNormalizerFactory()
            normalizer = normfactory.get_normalizer(lang)
            processed_sents = [
                self.preprocess_sent(line, normalizer, lang) for line in tqdm(sents)
            ]
        return processed_sents
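
    # Illustrative example (exact tokenizer output may vary by sacremoses version):
    # model.preprocess(["Hello, world!"], "en") -> ["Hello , world !"]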

    def postprocess(self, sents, lang, common_lang="hi"):
        """Detokenize translations and, for Indic target languages, convert
        the script back from the common script (common_lang) to the native one.

        sents: list of translated sentences
        lang: target language
        """
        postprocessed_sents = []
        if lang == "en":
            for sent in sents:
                postprocessed_sents.append(self.en_detok.detokenize(sent.split(" ")))
        else:
            for sent in sents:
                outstr = indic_detokenize.trivial_detokenize(
                    self.xliterator.transliterate(sent, common_lang, lang), lang
                )
                postprocessed_sents.append(outstr)
        return postprocessed_sents

    def apply_bpe(self, sents):
        return [self.bpe.process_line(sent) for sent in sents]
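

# Minimal usage sketch; the experiment directory path is an assumption and must
# point at a checkpoint directory with the layout noted in __init__.
if __name__ == "__main__":
    model = Model(expdir="../en-indic")  # hypothetical path
    print(model.translate_paragraph("This is a book. I read it yesterday.", "en", "hi"))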