import gradio as gr
import json
import os
import string
import re
import torch
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
import fasttext
from huggingface_hub import hf_hub_download
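# Summarization checkpoints offered in the model dropdown of the UI.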
summarization_model_names = [
    "google/bigbird-pegasus-large-arxiv",
    "facebook/bart-large-cnn",
    "google/t5-v1_1-large",
    "sshleifer/distilbart-cnn-12-6",
    "allenai/led-base-16384",
    "google/pegasus-xsum",
    "togethercomputer/LLaMA-2-7B-32K"
]
# Placeholder for the summarizer pipeline, tokenizer, and maximum tokens
summarizer = None
tokenizer = None
max_tokens = None
# Function to load the selected model
def load_summarization_model(model_name):
    global summarizer, tokenizer, max_tokens
    try:
        summarizer = pipeline("summarization", model=model_name, torch_dtype=torch.bfloat16)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        config = AutoConfig.from_pretrained(model_name)
        if hasattr(config, 'max_position_embeddings'):
            max_tokens = config.max_position_embeddings
        elif hasattr(config, 'n_positions'):
            max_tokens = config.n_positions
        elif hasattr(config, 'd_model'):
            max_tokens = config.d_model  # for T5 models, d_model is a rough proxy
        else:
            max_tokens = "Unknown"
        return f"Model {model_name} loaded successfully! Max tokens: {max_tokens}"
    except Exception as e:
        return f"Failed to load model {model_name}. Error: {str(e)}"
def summarize_text(input_text, min_length, max_length):
    if summarizer is None:
        return "No model loaded!"
    input_tokens = tokenizer.encode(input_text, return_tensors="pt")
    num_tokens = input_tokens.shape[1]
    if isinstance(max_tokens, int) and num_tokens > max_tokens:
        return f"Error: The input text has {num_tokens} tokens, which exceeds the maximum allowed {max_tokens} tokens. Please enter shorter text."
    min_summary_length = int(num_tokens * (min_length / 100))
    max_summary_length = int(num_tokens * (max_length / 100))
    output = summarizer(input_text, min_length=min_summary_length, max_length=max_summary_length)
    return output[0]['summary_text']
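# Language identification: GlotLID fastText model downloaded from the Hugging Face Hub.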
model_path = hf_hub_download(repo_id="cis-lmu/glotlid", filename="model.bin")
identification_model = fasttext.load_model(model_path)
def lang_ident(text):
    label, array = identification_model.predict(text)
    label = get_name(label[0].split('__')[-1].replace('_Hans', '_Hani').replace('_Hant', '_Hani'))
    return {"language": label, "score": array[0]}
pretrained_model: str = "facebook/m2m100_1.2B"
cache_dir: str = "models/"
translation_tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
translation_model = M2M100ForConditionalGeneration.from_pretrained(
    pretrained_model, cache_dir=cache_dir)
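# Speech pipelines: Whisper for transcription and an XLS-R model fine-tuned on
# the MINDS-14 dataset for audio classification.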
transcription = pipeline("automatic-speech-recognition", model="openai/whisper-base")
classification = pipeline(
    "audio-classification",
    model="anton-l/xtreme_s_xlsr_300m_minds14",
)
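# Mapping from ISO 639-3 codes to human-readable language names, loaded from a bundled JSON file.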
def language_names(json_path):
    with open(json_path, 'r') as json_file:
        data = json.load(json_file)
    return data

label2name = language_names("assets/language_names.json")
def get_name(label):
    """Get the name of a language from its label."""
    iso_3 = label.split('_')[0]
    name = label2name[iso_3]
    return name
def audio_a_text(audio):
    text = transcription(audio)["text"]
    return text

def text_to_sentiment(audio):
    #text = transcription(audio)["text"]
    return classification(audio)
lang_id = {
    "Afrikaans": "af",
    "Amharic": "am",
    "Arabic": "ar",
    "Asturian": "ast",
    "Azerbaijani": "az",
    "Bashkir": "ba",
    "Belarusian": "be",
    "Bulgarian": "bg",
    "Bengali": "bn",
    "Breton": "br",
    "Bosnian": "bs",
    "Catalan": "ca",
    "Cebuano": "ceb",
    "Czech": "cs",
    "Welsh": "cy",
    "Danish": "da",
    "German": "de",
    "Greek": "el",
    "English": "en",
    "Spanish": "es",
    "Estonian": "et",
    "Persian": "fa",
    "Fulah": "ff",
    "Finnish": "fi",
    "French": "fr",
    "Western Frisian": "fy",
    "Irish": "ga",
    "Gaelic": "gd",
    "Galician": "gl",
    "Gujarati": "gu",
    "Hausa": "ha",
    "Hebrew": "he",
    "Hindi": "hi",
    "Croatian": "hr",
    "Haitian": "ht",
    "Hungarian": "hu",
    "Armenian": "hy",
    "Indonesian": "id",
    "Igbo": "ig",
    "Iloko": "ilo",
    "Icelandic": "is",
    "Italian": "it",
    "Japanese": "ja",
    "Javanese": "jv",
    "Georgian": "ka",
    "Kazakh": "kk",
    "Central Khmer": "km",
    "Kannada": "kn",
    "Korean": "ko",
    "Luxembourgish": "lb",
    "Ganda": "lg",
    "Lingala": "ln",
    "Lao": "lo",
    "Lithuanian": "lt",
    "Latvian": "lv",
    "Malagasy": "mg",
    "Macedonian": "mk",
    "Malayalam": "ml",
    "Mongolian": "mn",
    "Marathi": "mr",
    "Malay": "ms",
    "Burmese": "my",
    "Nepali": "ne",
    "Dutch": "nl",
    "Norwegian": "no",
    "Northern Sotho": "ns",
    "Occitan": "oc",
    "Oriya": "or",
    "Panjabi": "pa",
    "Polish": "pl",
    "Pushto": "ps",
    "Portuguese": "pt",
    "Romanian": "ro",
    "Russian": "ru",
    "Sindhi": "sd",
    "Sinhala": "si",
    "Slovak": "sk",
    "Slovenian": "sl",
    "Somali": "so",
    "Albanian": "sq",
    "Serbian": "sr",
    "Swati": "ss",
    "Sundanese": "su",
    "Swedish": "sv",
    "Swahili": "sw",
    "Tamil": "ta",
    "Thai": "th",
    "Tagalog": "tl",
    "Tswana": "tn",
    "Turkish": "tr",
    "Ukrainian": "uk",
    "Urdu": "ur",
    "Uzbek": "uz",
    "Vietnamese": "vi",
    "Wolof": "wo",
    "Xhosa": "xh",
    "Yiddish": "yi",
    "Yoruba": "yo",
    "Chinese": "zh",
    "Zulu": "zu",
}
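# Translate user text between the selected source and target languages with M2M100.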
def translation_text(source_lang, target_lang, user_input):
    src_lang = lang_id[source_lang]
    trg_lang = lang_id[target_lang]
    translation_tokenizer.src_lang = src_lang
    with torch.no_grad():
        encoded_input = translation_tokenizer(user_input, return_tensors="pt")
        generated_tokens = translation_model.generate(**encoded_input, forced_bos_token_id=translation_tokenizer.get_lang_id(trg_lang))
        translated_text = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return translated_text
def print_s(source_lang, target_lang, text0):
    print(source_lang)
    return lang_id[source_lang], lang_id[target_lang], text0
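# Gradio Blocks UI: model loading, summarization, translation, and language
# identification wired to the functions above.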
demo = gr.Blocks()
with demo:
    text0 = gr.Textbox()
    text = gr.Textbox()
    #gr.Markdown("Speech analyzer")
    #audio = gr.Audio(type="filepath", label = "Upload a file")
    model_dropdown = gr.Dropdown(choices=summarization_model_names, label="Choose a model", value="sshleifer/distilbart-cnn-12-6")
    load_message = gr.Textbox(label="Load Status", interactive=False)
    b1 = gr.Button("Load Model")
    min_length_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Minimum Summary Length (%)", value=10)
    max_length_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Maximum Summary Length (%)", value=20)
    summarize_button = gr.Button("Summarize Text")
    b1.click(fn=load_summarization_model, inputs=model_dropdown, outputs=load_message)
    summarize_button.click(fn=summarize_text, inputs=[text0, min_length_slider, max_length_slider],
                           outputs=text)
    source_lang = gr.Dropdown(label="Source lang", choices=list(lang_id.keys()), value=list(lang_id.keys())[0])
    target_lang = gr.Dropdown(label="Target lang", choices=list(lang_id.keys()), value=list(lang_id.keys())[0])
    #gr.Examples(examples = list(lang_id.keys()),
    #            inputs=[
    #                source_lang])
    #b1 = gr.Button("convert to text")
    b3 = gr.Button("Translate")
    b3.click(translation_text, inputs=[source_lang, target_lang, text0], outputs=text)
    #b1.click(audio_a_text, inputs=audio, outputs=text)
    b2 = gr.Button("Identify language")
    b2.click(lang_ident, inputs=text0, outputs=text)

demo.launch()