# (Hugging Face Spaces page header — scrape residue: "Spaces: Sleeping")
| import re | |
| import sys | |
| import typing as tp | |
| import unicodedata | |
| import torch | |
| from sacremoses import MosesPunctNormalizer | |
| from sentence_splitter import SentenceSplitter | |
| from transformers import AutoModelForSeq2SeqLM, NllbTokenizer | |
# Hugging Face model repo of the fine-tuned NLLB checkpoint to load.
MODEL_URL = "slone/nllb-210-v1"
# UI display label -> NLLB/FLORES-200 language code (ISO 639-3 + script suffix).
LANGUAGES = {
    "Русский | Russian": "rus_Cyrl",
    "English | Английский": "eng_Latn",
    "Azərbaycan | Azerbaijani | Азербайджанский": "azj_Latn",
    "Башҡорт | Bashkir | Башкирский": "bak_Cyrl",
    "Буряад | Buryat | Бурятский": "bxr_Cyrl",
    "Чӑваш | Chuvash | Чувашский": "chv_Cyrl",
    "Хакас | Khakas | Хакасский": "kjh_Cyrl",
    "Къарачай-малкъар | Karachay-Balkar | Карачаево-балкарский": "krc_Cyrl",
    "Марий | Meadow Mari | Марийский": "mhr_Cyrl",
    "Эрзянь | Erzya | Эрзянский": "myv_Cyrl",
    "Татар | Tatar | Татарский": "tat_Cyrl",
    # NOTE(review): trailing space inside the "Tuvan " label below — confirm intended.
    "Тыва | Тувинский | Tuvan ": "tyv_Cyrl",
}
# Default source / target language codes used by Translator.translate.
L1 = "rus_Cyrl"
L2 = "eng_Latn"
def get_non_printing_char_replacer(replace_by: str = " ") -> tp.Callable[[str], str]:
    """Build a function that substitutes every non-printing character.

    "Non-printing" means any Unicode general category starting with "C"
    (control, format, surrogate, private-use, unassigned) — the same set as
    Perl's \\p{C}; see
    https://www.unicode.org/reports/tr44/#General_Category_Values

    Returns a callable mapping ``str -> str`` that replaces each such
    character with *replace_by*.
    """
    # Scan the whole Unicode range once and build a str.translate table.
    translation_table = {}
    for codepoint in range(sys.maxunicode + 1):
        if unicodedata.category(chr(codepoint)) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}:
            translation_table[codepoint] = replace_by

    def _replace(line: str) -> str:
        return line.translate(translation_table)

    return _replace
class TextPreprocessor:
    """
    Mimic the text preprocessing made for the NLLB model.
    This code is adapted from the Stopes repo of the NLLB team:
    https://github.com/facebookresearch/stopes/blob/main/stopes/pipelines/monolingual/monolingual_line_processor.py#L214
    """

    def __init__(self, lang="en"):
        self.mpn = MosesPunctNormalizer(lang=lang)
        # Pre-compile the Moses substitution patterns once so that
        # `normalize` does not re-process them on every call.
        self.mpn.substitutions = [
            (re.compile(pattern), replacement)
            for pattern, replacement in self.mpn.substitutions
        ]
        self.replace_nonprint = get_non_printing_char_replacer(" ")

    def __call__(self, text: str) -> str:
        """Return *text* after Moses punctuation normalization, blanking of
        non-printing characters, and Unicode NFKC normalization."""
        cleaned = self.mpn.normalize(text)
        cleaned = self.replace_nonprint(cleaned)
        # NFKC folds compatibility characters, e.g. 𝓕𝔯𝔞𝔫𝔠𝔢𝔰𝔠𝔞 -> Francesca
        return unicodedata.normalize("NFKC", cleaned)
def sentenize_with_fillers(text, splitter, fix_double_space=True, ignore_errors=False):
    """Split *text* into sentences and keep the separators around them.

    Args:
        text: The input string.
        splitter: An object with a ``split(text) -> list[str]`` method
            (e.g. ``sentence_splitter.SentenceSplitter``).
        fix_double_space: Collapse runs of spaces first, so the splitter's
            sentences can be found verbatim in *text*.
        ignore_errors: If a sentence cannot be located in *text*, skip one
            character ahead (best effort) instead of raising.

    Returns:
        A pair ``(sentences, fillers)`` with ``len(fillers) == len(sentences) + 1``;
        interleaving filler, sentence, filler, ... reconstructs the
        (possibly space-collapsed) input.

    Raises:
        ValueError: If a sentence is not found and ``ignore_errors`` is False.
    """
    if fix_double_space:
        text = re.sub(" +", " ", text)
    sentences = splitter.split(text)
    fillers = []
    i = 0
    for sentence in sentences:
        start_idx = text.find(sentence, i)
        if start_idx == -1:
            if not ignore_errors:
                # Explicit raise instead of `assert`, which is stripped under -O.
                raise ValueError(f"sent not found after {i}: `{sentence}`")
            # Best effort: advance one character and treat it as filler.
            start_idx = i + 1
        fillers.append(text[i:start_idx])
        i = start_idx + len(sentence)
    fillers.append(text[i:])
    return sentences, fillers
class Translator:
    """Wrap the NLLB seq2seq model with sentence-level translation helpers."""

    def __init__(self):
        self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_URL, low_cpu_mem_usage=True)
        if torch.cuda.is_available():
            self.model.cuda()
        self.tokenizer = NllbTokenizer.from_pretrained(MODEL_URL)
        # Russian splitting rules; presumably adequate for the mostly-Cyrillic
        # languages this model supports — TODO confirm.
        self.splitter = SentenceSplitter("ru")
        self.preprocessor = TextPreprocessor()
        self.languages = LANGUAGES

    def translate(
        self,
        text,
        src_lang=L1,
        tgt_lang=L2,
        max_length="auto",
        num_beams=4,
        by_sentence=True,
        preprocess=True,
        **kwargs,
    ):
        """Translate a text sentence by sentence, preserving the fillers around the sentences.

        Args:
            text: Source text.
            src_lang / tgt_lang: NLLB language codes (e.g. "rus_Cyrl").
            max_length: Generation cap; "auto" scales with the input length.
            num_beams: Beam-search width.
            by_sentence: Split into sentences and translate them one by one.
            preprocess: Apply NLLB-style text normalization first.
            **kwargs: Forwarded to ``model.generate``.

        Returns:
            The translated text with the original inter-sentence separators.
        """
        if by_sentence:
            sents, fillers = sentenize_with_fillers(
                text, splitter=self.splitter, ignore_errors=True
            )
        else:
            sents = [text]
            fillers = ["", ""]
        if preprocess:
            sents = [self.preprocessor(sent) for sent in sents]
        results = []
        # `fillers` has one more element than `sents`; zip consumes all but
        # the trailing filler, which is appended after the loop.
        for sent, sep in zip(sents, fillers):
            results.append(sep)
            results.append(
                self.translate_single(
                    sent,
                    src_lang=src_lang,
                    tgt_lang=tgt_lang,
                    max_length=max_length,
                    num_beams=num_beams,
                    **kwargs,
                )
            )
        results.append(fillers[-1])
        return "".join(results)

    def translate_single(
        self,
        text,
        src_lang=L1,
        tgt_lang=L2,
        max_length="auto",
        num_beams=4,
        n_out=None,
        **kwargs,
    ):
        """Translate a single string (or a batch of strings).

        Returns a plain string when *text* is a ``str`` and ``n_out`` is None;
        otherwise a list of decoded outputs.
        """
        self.tokenizer.src_lang = src_lang
        encoded = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        )
        if max_length == "auto":
            # Give generation some headroom over the source length.
            max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
        # `convert_tokens_to_ids` is stable across transformers versions;
        # the `lang_code_to_id` mapping was removed from newer tokenizers.
        with torch.inference_mode():
            generated_tokens = self.model.generate(
                **encoded.to(self.model.device),
                forced_bos_token_id=self.tokenizer.convert_tokens_to_ids(tgt_lang),
                max_length=max_length,
                num_beams=num_beams,
                num_return_sequences=n_out or 1,
                **kwargs,
            )
        out = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        if isinstance(text, str) and n_out is None:
            return out[0]
        return out
if __name__ == "__main__":
    # Instantiating Translator triggers the model/tokenizer downloads so that
    # later runs start from a warm local cache.
    print("Initializing a translator to pre-download models...")
    translator = Translator()
    print("Initialization successful!")