Spaces:
Running
Running
| # coding=utf-8 | |
| # Copyright 2023 The HuggingFace Inc. team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # Source: https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/agents/translation.py | |
| from transformers.models.auto import AutoModelForSeq2SeqLM, AutoTokenizer | |
| from smolagents.tools import PipelineTool, Tool | |
# Maps plain-English language names to the language codes that `TranslationTool`
# hands to the tokenizer as `src_lang` / `tgt_lang`.
# NOTE(review): the codes look like `<language>_<script>` identifiers as used by the
# NLLB-200 checkpoints -- confirm against the checkpoint's tokenizer before adding entries.
# Keys must not carry leading/trailing whitespace: lookups are exact string matches.
LANGUAGE_CODES = {
    "Acehnese Arabic": "ace_Arab",
    "Acehnese Latin": "ace_Latn",
    "Mesopotamian Arabic": "acm_Arab",
    "Ta'izzi-Adeni Arabic": "acq_Arab",
    "Tunisian Arabic": "aeb_Arab",
    "Afrikaans": "afr_Latn",
    "South Levantine Arabic": "ajp_Arab",
    "Akan": "aka_Latn",
    "Amharic": "amh_Ethi",
    "North Levantine Arabic": "apc_Arab",
    "Modern Standard Arabic": "arb_Arab",
    "Modern Standard Arabic Romanized": "arb_Latn",
    "Najdi Arabic": "ars_Arab",
    "Moroccan Arabic": "ary_Arab",
    "Egyptian Arabic": "arz_Arab",
    "Assamese": "asm_Beng",
    "Asturian": "ast_Latn",
    "Awadhi": "awa_Deva",
    "Central Aymara": "ayr_Latn",
    "South Azerbaijani": "azb_Arab",
    "North Azerbaijani": "azj_Latn",
    "Bashkir": "bak_Cyrl",
    "Bambara": "bam_Latn",
    "Balinese": "ban_Latn",
    "Belarusian": "bel_Cyrl",
    "Bemba": "bem_Latn",
    "Bengali": "ben_Beng",
    "Bhojpuri": "bho_Deva",
    "Banjar Arabic": "bjn_Arab",
    "Banjar Latin": "bjn_Latn",
    "Standard Tibetan": "bod_Tibt",
    "Bosnian": "bos_Latn",
    "Buginese": "bug_Latn",
    "Bulgarian": "bul_Cyrl",
    "Catalan": "cat_Latn",
    "Cebuano": "ceb_Latn",
    "Czech": "ces_Latn",
    "Chokwe": "cjk_Latn",
    "Central Kurdish": "ckb_Arab",
    "Crimean Tatar": "crh_Latn",
    "Welsh": "cym_Latn",
    "Danish": "dan_Latn",
    "German": "deu_Latn",
    "Southwestern Dinka": "dik_Latn",
    "Dyula": "dyu_Latn",
    "Dzongkha": "dzo_Tibt",
    "Greek": "ell_Grek",
    "English": "eng_Latn",
    "Esperanto": "epo_Latn",
    "Estonian": "est_Latn",
    "Basque": "eus_Latn",
    "Ewe": "ewe_Latn",
    "Faroese": "fao_Latn",
    "Fijian": "fij_Latn",
    "Finnish": "fin_Latn",
    "Fon": "fon_Latn",
    "French": "fra_Latn",
    "Friulian": "fur_Latn",
    "Nigerian Fulfulde": "fuv_Latn",
    "Scottish Gaelic": "gla_Latn",
    "Irish": "gle_Latn",
    "Galician": "glg_Latn",
    "Guarani": "grn_Latn",
    "Gujarati": "guj_Gujr",
    "Haitian Creole": "hat_Latn",
    "Hausa": "hau_Latn",
    "Hebrew": "heb_Hebr",
    "Hindi": "hin_Deva",
    "Chhattisgarhi": "hne_Deva",
    "Croatian": "hrv_Latn",
    "Hungarian": "hun_Latn",
    "Armenian": "hye_Armn",
    "Igbo": "ibo_Latn",
    "Ilocano": "ilo_Latn",
    "Indonesian": "ind_Latn",
    "Icelandic": "isl_Latn",
    "Italian": "ita_Latn",
    "Javanese": "jav_Latn",
    "Japanese": "jpn_Jpan",
    "Kabyle": "kab_Latn",
    "Jingpho": "kac_Latn",
    "Kamba": "kam_Latn",
    "Kannada": "kan_Knda",
    "Kashmiri Arabic": "kas_Arab",
    "Kashmiri Devanagari": "kas_Deva",
    "Georgian": "kat_Geor",
    "Central Kanuri Arabic": "knc_Arab",
    "Central Kanuri Latin": "knc_Latn",
    "Kazakh": "kaz_Cyrl",
    "Kabiyè": "kbp_Latn",
    "Kabuverdianu": "kea_Latn",
    "Khmer": "khm_Khmr",
    "Kikuyu": "kik_Latn",
    "Kinyarwanda": "kin_Latn",
    "Kyrgyz": "kir_Cyrl",
    "Kimbundu": "kmb_Latn",
    "Northern Kurdish": "kmr_Latn",
    "Kikongo": "kon_Latn",
    "Korean": "kor_Hang",
    "Lao": "lao_Laoo",
    "Ligurian": "lij_Latn",
    "Limburgish": "lim_Latn",
    "Lingala": "lin_Latn",
    "Lithuanian": "lit_Latn",
    "Lombard": "lmo_Latn",
    "Latgalian": "ltg_Latn",
    "Luxembourgish": "ltz_Latn",
    "Luba-Kasai": "lua_Latn",
    "Ganda": "lug_Latn",
    "Luo": "luo_Latn",
    "Mizo": "lus_Latn",
    "Standard Latvian": "lvs_Latn",
    "Magahi": "mag_Deva",
    "Maithili": "mai_Deva",
    "Malayalam": "mal_Mlym",
    "Marathi": "mar_Deva",
    # Key previously carried a trailing space, which made the natural spelling unmatchable.
    "Minangkabau Arabic": "min_Arab",
    "Minangkabau Latin": "min_Latn",
    "Macedonian": "mkd_Cyrl",
    "Plateau Malagasy": "plt_Latn",
    "Maltese": "mlt_Latn",
    "Meitei Bengali": "mni_Beng",
    "Halh Mongolian": "khk_Cyrl",
    "Mossi": "mos_Latn",
    "Maori": "mri_Latn",
    "Burmese": "mya_Mymr",
    "Dutch": "nld_Latn",
    "Norwegian Nynorsk": "nno_Latn",
    "Norwegian Bokmål": "nob_Latn",
    "Nepali": "npi_Deva",
    "Northern Sotho": "nso_Latn",
    "Nuer": "nus_Latn",
    "Nyanja": "nya_Latn",
    "Occitan": "oci_Latn",
    "West Central Oromo": "gaz_Latn",
    "Odia": "ory_Orya",
    "Pangasinan": "pag_Latn",
    "Eastern Panjabi": "pan_Guru",
    "Papiamento": "pap_Latn",
    "Western Persian": "pes_Arab",
    "Polish": "pol_Latn",
    "Portuguese": "por_Latn",
    "Dari": "prs_Arab",
    "Southern Pashto": "pbt_Arab",
    "Ayacucho Quechua": "quy_Latn",
    "Romanian": "ron_Latn",
    "Rundi": "run_Latn",
    "Russian": "rus_Cyrl",
    "Sango": "sag_Latn",
    "Sanskrit": "san_Deva",
    "Santali": "sat_Olck",
    "Sicilian": "scn_Latn",
    "Shan": "shn_Mymr",
    "Sinhala": "sin_Sinh",
    "Slovak": "slk_Latn",
    "Slovenian": "slv_Latn",
    "Samoan": "smo_Latn",
    "Shona": "sna_Latn",
    "Sindhi": "snd_Arab",
    "Somali": "som_Latn",
    "Southern Sotho": "sot_Latn",
    "Spanish": "spa_Latn",
    "Tosk Albanian": "als_Latn",
    "Sardinian": "srd_Latn",
    "Serbian": "srp_Cyrl",
    "Swati": "ssw_Latn",
    "Sundanese": "sun_Latn",
    "Swedish": "swe_Latn",
    "Swahili": "swh_Latn",
    "Silesian": "szl_Latn",
    "Tamil": "tam_Taml",
    "Tatar": "tat_Cyrl",
    "Telugu": "tel_Telu",
    "Tajik": "tgk_Cyrl",
    "Tagalog": "tgl_Latn",
    "Thai": "tha_Thai",
    "Tigrinya": "tir_Ethi",
    "Tamasheq Latin": "taq_Latn",
    "Tamasheq Tifinagh": "taq_Tfng",
    "Tok Pisin": "tpi_Latn",
    "Tswana": "tsn_Latn",
    "Tsonga": "tso_Latn",
    "Turkmen": "tuk_Latn",
    "Tumbuka": "tum_Latn",
    "Turkish": "tur_Latn",
    "Twi": "twi_Latn",
    "Central Atlas Tamazight": "tzm_Tfng",
    "Uyghur": "uig_Arab",
    "Ukrainian": "ukr_Cyrl",
    "Umbundu": "umb_Latn",
    "Urdu": "urd_Arab",
    "Northern Uzbek": "uzn_Latn",
    "Venetian": "vec_Latn",
    "Vietnamese": "vie_Latn",
    "Waray": "war_Latn",
    "Wolof": "wol_Latn",
    "Xhosa": "xho_Latn",
    "Eastern Yiddish": "ydd_Hebr",
    "Yoruba": "yor_Latn",
    "Yue Chinese": "yue_Hant",
    "Chinese Simplified": "zho_Hans",
    "Chinese Traditional": "zho_Hant",
    "Standard Malay": "zsm_Latn",
    "Zulu": "zul_Latn",
}
class TranslationTool(PipelineTool):
    """
    Tool that translates a piece of text from one language into another.

    Languages are passed by their plain-English names (the keys of `lang_to_code`); `encode` converts them to
    the matching language codes before handing the text to the tokenizer.

    Example:

    ```py
    translator = TranslationTool()
    translator("This is a super nice API!", src_lang="English", tgt_lang="French")
    ```
    """

    # Lookup table from plain-English language name to language code.
    lang_to_code = LANGUAGE_CODES

    name = "translator"
    description = (
        "This is a tool that translates text from a language to another.\n\n"
        f"Both `src_lang` and `tgt_lang` should belong to this list of languages: {list(lang_to_code)}."
    )
    inputs = {
        "text": {"type": "string", "description": "The text to translate"},
        "src_lang": {
            "type": "string",
            "description": "The language of the text to translate. Written in plain English, such as 'Romanian', or 'Albanian'",
        },
        "tgt_lang": {
            "type": "string",
            "description": "The language for the desired output language. Written in plain English, such as 'Romanian', or 'Albanian'",
        },
    }
    output_type = "string"

    default_checkpoint = "facebook/nllb-200-distilled-600M"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM

    def encode(self, text, src_lang, tgt_lang):
        """
        Validate both language names and build the tokenized translation inputs.

        Raises:
            ValueError: if `src_lang` or `tgt_lang` (checked in that order) is not a key of `lang_to_code`.
        """
        # Source is checked before target so the first offending name is the one reported.
        for language in (src_lang, tgt_lang):
            if language not in self.lang_to_code:
                raise ValueError(f"{language} is not a supported language.")

        source_code = self.lang_to_code[src_lang]
        target_code = self.lang_to_code[tgt_lang]
        return self.pre_processor._build_translation_inputs(
            text,
            return_tensors="pt",
            src_lang=source_code,
            tgt_lang=target_code,
        )

    def forward(self, inputs):
        """Run generation on the model with the encoded inputs and return its raw output."""
        generated = self.model.generate(**inputs)
        return generated

    def decode(self, outputs):
        """Turn the first generated sequence back into text, dropping special tokens."""
        first_sequence = outputs[0]
        return self.post_processor.decode(first_sequence.tolist(), skip_special_tokens=True)