"""
File: model_translation.py
Description:
    Loading models for text translations
Author: Didier Guillevic
Date: 2024-03-16
"""
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
# Bilingual Helsinki-NLP OPUS-MT checkpoint to use for each supported
# source language (keyed by ISO-639-1 code).
model_names = {
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "en": "Helsinki-NLP/opus-mt-en-fr",
    "fa": "Helsinki-NLP/opus-mt-tc-big-fa-itc",
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "he": "Helsinki-NLP/opus-mt-tc-big-he-en",
    "ja": "Helsinki-NLP/opus-mt-jap-en",
    "zh": "Helsinki-NLP/opus-mt-zh-en",
}

# Supported source languages. Derived from the model table so the two can
# never drift apart (originally a hand-maintained duplicate list).
src_langs = set(model_names)

# Registry (cache) of already-loaded (tokenizer, model) pairs,
# keyed by source-language code.
tokenizer_model_registry = {}

# All bilingual models are kept on the CPU (for now).
device = 'cpu'
def get_tokenizer_model_for_src_lang(src_lang: str) -> tuple[AutoTokenizer, AutoModelForSeq2SeqLM]:
    """Return the (tokenizer, model) pair for a given source language.

    Pairs are loaded lazily on first use and cached in
    ``tokenizer_model_registry`` for subsequent calls.

    Args:
        src_lang: ISO-639-1 code of the source language (case-insensitive),
            e.g. "fr". Must be one of the keys of ``model_names``.

    Returns:
        The (tokenizer, model) pair for that language.

    Raises:
        ValueError: if no model is defined for ``src_lang``.
    """
    src_lang = src_lang.lower()

    # Already loaded? Serve from the cache.
    if src_lang in tokenizer_model_registry:
        return tokenizer_model_registry[src_lang]

    model_name = model_names.get(src_lang)
    if not model_name:
        # ValueError is a subclass of Exception, so existing callers that
        # caught the previous bare Exception still work.
        raise ValueError(f"No model defined for language: {src_lang}")

    # Load tokenizer and model; we leave the models on the CPU (for now).
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    if model.config.torch_dtype != torch.float16:
        # Halve the memory footprint when the checkpoint is not already fp16.
        # NOTE(review): fp16 inference on CPU can be slow or unsupported for
        # some ops — confirm this is intended.
        model = model.half()
    model.to(device)

    tokenizer_model_registry[src_lang] = (tokenizer, model)
    return (tokenizer, model)
# Max number of words for a given input text.
# - The models typically cap sequences at 512 tokens (max position
#   embeddings as well as max length).
# - Chunk inputs at a word count safely below that threshold, e.g. 200 words.
max_words_per_chunk = 200


#
# Multilingual (many-to-many) language pairs: facebook/m2m100
#
# NOTE: the `from transformers import M2M100Tokenizer,
# M2M100ForConditionalGeneration` that used to sit here was hoisted to the
# top-of-file import block (PEP 8).
model_name_m2m100 = "facebook/m2m100_418M"
tokenizer_m2m100 = M2M100Tokenizer.from_pretrained(model_name_m2m100)
model_m2m100 = M2M100ForConditionalGeneration.from_pretrained(
    model_name_m2m100,
    device_map="auto",          # let accelerate choose device placement
    torch_dtype=torch.float16,  # load weights in fp16 to halve memory
    low_cpu_mem_usage=True,
)
#
# Multilingual translation model: Google MADLAD-400
#
model_MADLAD_name = "google/madlad400-3b-mt"
#model_MADLAD_name = "google/madlad400-7b-mt-bt"

# Shared tokenizer/model used for all MADLAD-supported language pairs.
tokenizer_multilingual = AutoTokenizer.from_pretrained(model_MADLAD_name, use_fast=True)
model_multilingual = AutoModelForSeq2SeqLM.from_pretrained(
    model_MADLAD_name,
    device_map="auto",          # let accelerate choose device placement
    torch_dtype=torch.float16,  # load weights in fp16 to halve memory
    low_cpu_mem_usage=True,
)