| """ | |
| File: model_translation.py | |
| Description: | |
| Loading models for text translations | |
| Author: Didier Guillevic | |
| Date: 2024-03-16 | |
| """ | |
import spaces
import logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
from transformers import BitsAndBytesConfig

from model_spacy import nlp_xx as model_spacy
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=200.0  # https://discuss.huggingface.co/t/correct-usage-of-bitsandbytesconfig/33809/5
)
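# Note: llm_int8_threshold is the outlier threshold for the int8 mixed-precision
# matmul; the bitsandbytes default is 6.0, so 200.0 effectively keeps almost all
# activations on the int8 path (see the linked discussion above).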

# The 100 languages supported by the facebook/m2m100_418M model
# https://huggingface.co/facebook/m2m100_418M
# plus the 'AUTOMATIC' option, where we will use a language detector.
language_codes = {
    'AUTOMATIC': 'auto',
    'Afrikaans (af)': 'af',
    'Albanian (sq)': 'sq',
    'Amharic (am)': 'am',
    'Arabic (ar)': 'ar',
    'Armenian (hy)': 'hy',
    'Asturian (ast)': 'ast',
    'Azerbaijani (az)': 'az',
    'Bashkir (ba)': 'ba',
    'Belarusian (be)': 'be',
    'Bengali (bn)': 'bn',
    'Bosnian (bs)': 'bs',
    'Breton (br)': 'br',
    'Bulgarian (bg)': 'bg',
    'Burmese (my)': 'my',
    'Catalan; Valencian (ca)': 'ca',
    'Cebuano (ceb)': 'ceb',
    'Central Khmer (km)': 'km',
    'Chinese (zh)': 'zh',
    'Croatian (hr)': 'hr',
    'Czech (cs)': 'cs',
    'Danish (da)': 'da',
    'Dutch; Flemish (nl)': 'nl',
    'English (en)': 'en',
    'Estonian (et)': 'et',
    'Finnish (fi)': 'fi',
    'French (fr)': 'fr',
    'Fulah (ff)': 'ff',
    'Gaelic; Scottish Gaelic (gd)': 'gd',
    'Galician (gl)': 'gl',
    'Ganda (lg)': 'lg',
    'Georgian (ka)': 'ka',
    'German (de)': 'de',
    'Greek (el)': 'el',
    'Gujarati (gu)': 'gu',
    'Haitian; Haitian Creole (ht)': 'ht',
    'Hausa (ha)': 'ha',
    'Hebrew (he)': 'he',
    'Hindi (hi)': 'hi',
    'Hungarian (hu)': 'hu',
    'Icelandic (is)': 'is',
    'Igbo (ig)': 'ig',
    'Iloko (ilo)': 'ilo',
    'Indonesian (id)': 'id',
    'Irish (ga)': 'ga',
    'Italian (it)': 'it',
    'Japanese (ja)': 'ja',
    'Javanese (jv)': 'jv',
    'Kannada (kn)': 'kn',
    'Kazakh (kk)': 'kk',
    'Korean (ko)': 'ko',
    'Lao (lo)': 'lo',
    'Latvian (lv)': 'lv',
    'Lingala (ln)': 'ln',
    'Lithuanian (lt)': 'lt',
    'Luxembourgish; Letzeburgesch (lb)': 'lb',
    'Macedonian (mk)': 'mk',
    'Malagasy (mg)': 'mg',
    'Malay (ms)': 'ms',
    'Malayalam (ml)': 'ml',
    'Marathi (mr)': 'mr',
    'Mongolian (mn)': 'mn',
    'Nepali (ne)': 'ne',
    'Northern Sotho (ns)': 'ns',
    'Norwegian (no)': 'no',
    'Occitan (post 1500) (oc)': 'oc',
    'Oriya (or)': 'or',
    'Panjabi; Punjabi (pa)': 'pa',
    'Persian (fa)': 'fa',
    'Polish (pl)': 'pl',
    'Portuguese (pt)': 'pt',
    'Pushto; Pashto (ps)': 'ps',
    'Romanian; Moldavian; Moldovan (ro)': 'ro',
    'Russian (ru)': 'ru',
    'Serbian (sr)': 'sr',
    'Sindhi (sd)': 'sd',
    'Sinhala; Sinhalese (si)': 'si',
    'Slovak (sk)': 'sk',
    'Slovenian (sl)': 'sl',
    'Somali (so)': 'so',
    'Spanish (es)': 'es',
    'Sundanese (su)': 'su',
    'Swahili (sw)': 'sw',
    'Swati (ss)': 'ss',
    'Swedish (sv)': 'sv',
    'Tagalog (tl)': 'tl',
    'Tamil (ta)': 'ta',
    'Thai (th)': 'th',
    'Tswana (tn)': 'tn',
    'Turkish (tr)': 'tr',
    'Ukrainian (uk)': 'uk',
    'Urdu (ur)': 'ur',
    'Uzbek (uz)': 'uz',
    'Vietnamese (vi)': 'vi',
    'Welsh (cy)': 'cy',
    'Western Frisian (fy)': 'fy',
    'Wolof (wo)': 'wo',
    'Xhosa (xh)': 'xh',
    'Yiddish (yi)': 'yi',
    'Yoruba (yo)': 'yo',
    'Zulu (zu)': 'zu'
}

tgt_language_codes = {
    'English (en)': 'en',
    'French (fr)': 'fr'
}


def build_text_chunks(
    text: str,
    sents_per_chunk: int = 5,
    words_per_chunk: int = 200
) -> list[str]:
    """Split a given text into chunks of at most sents_per_chunk sentences
    and words_per_chunk words.

    Given a text:
    - Split the text into sentences.
    - Build text chunks:
        - Include up to sents_per_chunk sentences per chunk.
        - Ensure that a chunk does not exceed words_per_chunk words.
    """
    # Split text into sentences
    sentences = [
        sent.text.strip() for sent in model_spacy(text).sents if sent.text.strip()
    ]
    logger.info(f"TEXT: {text[:25]}, NB_SENTS: {len(sentences)}")

    # Create text chunks of up to sents_per_chunk sentences
    chunks = []
    chunk = ''
    chunk_nb_sentences = 0
    chunk_nb_words = 0

    for sent in sentences:
        sent_nb_words = len(sent.split())

        # If the chunk is already 'full', save it and start a new chunk
        if (
            (chunk_nb_words + sent_nb_words > words_per_chunk) or
            (chunk_nb_sentences + 1 > sents_per_chunk)
        ):
            chunks.append(chunk)
            chunk = ''
            chunk_nb_sentences = 0
            chunk_nb_words = 0

        # Append sentence to the current chunk. One sentence per line.
        chunk = (chunk + '\n' + sent) if chunk else sent
        chunk_nb_sentences += 1
        chunk_nb_words += sent_nb_words

    # Append the last chunk
    if chunk:
        chunks.append(chunk)

    return chunks
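
# Illustrative sketch of the expected behavior (the exact splits depend on the
# spaCy sentencizer, so the example strings below are assumptions):
#   build_text_chunks("One. Two. Three.", sents_per_chunk=2, words_per_chunk=200)
#   -> ["One.\nTwo.", "Three."]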


class Singleton(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]
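
# A class declared with `metaclass=Singleton` is instantiated at most once:
# repeated calls such as ModelM2M100() return the same cached instance, so the
# model weights are only loaded on the first call.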


class ModelM2M100(metaclass=Singleton):
    """Loads an instance of the M2M100 model.

    Model: https://huggingface.co/facebook/m2m100_418M
    """
    def __init__(self):
        self._model_name = "facebook/m2m100_418M"
        self._tokenizer = M2M100Tokenizer.from_pretrained(self._model_name)
        self._model = M2M100ForConditionalGeneration.from_pretrained(
            self._model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True
            # quantization_config=quantization_config
        )
        self._model = torch.compile(self._model)
    def translate(
        self,
        text: str,
        src_lang: str,
        tgt_lang: str,
        chunk_text: bool = True,
        sents_per_chunk: int = 5,
        words_per_chunk: int = 200
    ) -> str:
        """Translate the given text from src_lang to tgt_lang.

        The text will be split into chunks to ensure the chunks fit into the
        model input_max_length (usually 512 tokens).
        """
        chunks = [text]
        if chunk_text:
            chunks = build_text_chunks(text, sents_per_chunk, words_per_chunk)
        self._tokenizer.src_lang = src_lang
        translated_chunks = []
        for chunk in chunks:
            input_ids = self._tokenizer(
                chunk,
                return_tensors="pt").input_ids.to(self._model.device)
            outputs = self._model.generate(
                input_ids=input_ids,
                forced_bos_token_id=self._tokenizer.get_lang_id(tgt_lang))
            translated_chunk = self._tokenizer.batch_decode(
                outputs,
                skip_special_tokens=True)[0]
            translated_chunks.append(translated_chunk)
        return '\n'.join(translated_chunks)
    def model_name(self):
        return self._model_name

    def tokenizer(self):
        return self._tokenizer

    def model(self):
        return self._model

    def device(self):
        return self._model.device
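
# Minimal usage sketch (illustrative strings; not executed on import):
#   m2m = ModelM2M100()
#   m2m.translate("Bonjour le monde.", src_lang="fr", tgt_lang="en")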


class ModelMADLAD(metaclass=Singleton):
    """Loads an instance of the Google MADLAD model (3B).

    Model: https://huggingface.co/google/madlad400-3b-mt
    """
    def __init__(self):
        self._model_name = "google/madlad400-3b-mt"
        self._input_max_length = 512   # config.json n_positions
        self._output_max_length = 512  # config.json n_positions
        self._tokenizer = AutoTokenizer.from_pretrained(
            self._model_name, use_fast=True
        )
        self._model = AutoModelForSeq2SeqLM.from_pretrained(
            self._model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            quantization_config=quantization_config
        )
        self._model = torch.compile(self._model)
    def translate(
        self,
        text: str,
        tgt_lang: str,
        chunk_text: bool = True,
        sents_per_chunk: int = 5,
        words_per_chunk: int = 200
    ) -> str:
| """Translate given text into the target language. | |
| The text will be split into chunks to ensure the chunks fit into the | |
| model input_max_length (usually 512 tokens). | |
| """ | |
| chunks = [text,] | |
| if chunk_text: | |
| chunks = build_text_chunks(text, sents_per_chunk, words_per_chunk) | |
| translated_chunks = [] | |
| for chunk in chunks: | |
| input_text = f"<2{tgt_lang}> {chunk}" | |
| logger.info(f" Translating: {input_text[:50]}") | |
| input_ids = self._tokenizer( | |
| input_text, | |
| return_tensors="pt", | |
| max_length=self._input_max_length, | |
| truncation=True, | |
| padding="longest").input_ids.to(self._model.device) | |
| outputs = self._model.generate( | |
| input_ids=input_ids, | |
| max_length=self._output_max_length) | |
| translated_chunk = self._tokenizer.decode( | |
| outputs[0], | |
| skip_special_tokens=True) | |
| translated_chunks.append(translated_chunk) | |
| return '\n'.join(translated_chunks) | |
    def model_name(self):
        return self._model_name

    def tokenizer(self):
        return self._tokenizer

    def model(self):
        return self._model

    def device(self):
        return self._model.device
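
# Minimal usage sketch (illustrative; MADLAD infers the source language from
# the input text, so only the target language is given):
#   madlad = ModelMADLAD()
#   madlad.translate("Bonjour le monde.", tgt_lang="en")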


# Bilingual individual models
src_langs = set(["ar", "en", "fa", "fr", "he", "ja", "zh"])

model_names = {
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "en": "Helsinki-NLP/opus-mt-en-fr",
    "fa": "Helsinki-NLP/opus-mt-tc-big-fa-itc",
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "he": "Helsinki-NLP/opus-mt-tc-big-he-en",
    "ja": "Helsinki-NLP/opus-mt-ja-en",  # assumed mapping: "ja" is in src_langs but had no entry
    "zh": "Helsinki-NLP/opus-mt-zh-en",
}

# Registry for all loaded bilingual models
tokenizer_model_registry = {}

device = 'cpu'


def get_tokenizer_model_for_src_lang(src_lang: str) -> tuple[AutoTokenizer, AutoModelForSeq2SeqLM]:
    """Return the (tokenizer, model) pair for a given source language."""
    src_lang = src_lang.lower()

    # Already loaded?
    if src_lang in tokenizer_model_registry:
        return tokenizer_model_registry[src_lang]

    # Load tokenizer and model
    model_name = model_names.get(src_lang)
    if not model_name:
        raise ValueError(f"No model defined for language: {src_lang}")

    # We will leave the models on the CPU (for now)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    if model.config.torch_dtype != torch.float16:
        model = model.half()
    model.to(device)

    tokenizer_model_registry[src_lang] = (tokenizer, model)
    return tokenizer, model

# Max number of words for a given input text
# - Usually 512 tokens (max position encodings, as well as max length)
# - Set to a number of words somewhat lower than that threshold,
#   e.g. 200 words
max_words_per_chunk = 200
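

# Minimal local smoke test (a sketch, not part of the Spaces app). It assumes
# the Helsinki-NLP checkpoint can be downloaded and that your torch build
# supports fp16 inference on CPU; the example sentence is illustrative.
if __name__ == "__main__":
    tokenizer, model = get_tokenizer_model_for_src_lang("fr")
    inputs = tokenizer("Bonjour le monde.", return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_length=64)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])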