# coding=utf-8
# Copyright 2023 The GlotLID Authors.
# Lint as: python3
# This space is built based on AMR-KELEG/ALDi space.
# GlotLID Space

import string
import constants
import pandas as pd
import streamlit as st
from huggingface_hub import hf_hub_download
from GlotScript import get_script_predictor
import matplotlib.pyplot as plt
import fasttext
import altair as alt
from altair import X, Y, Scale
import base64
import json
import os
import re

def load_sp():
    sp = get_script_predictor()
    return sp


sp = load_sp()

def get_script(text):
    """Get the writing systems of the given text.

    Args:
        text: The text to analyze.
    Returns:
        The main script and the list of all scripts.
    """
    res = sp(text)
    main_script = res[0] if res[0] else 'Zyyy'
    all_scripts_dict = res[2]['details']

    if all_scripts_dict:
        all_scripts = list(all_scripts_dict.keys())
    else:
        all_scripts = ['Zyyy']

    # Japanese mixes several writing systems, so also add 'Jpan' when any of them is present.
    if any(ws in ['Kana', 'Hrkt', 'Hani', 'Hira'] for ws in all_scripts):
        all_scripts.append('Jpan')

    all_scripts = list(set(all_scripts))

    return main_script, all_scripts

def preprocess_text(text):
    """Apply preprocessing to the given text.

    Args:
        text: The text to be preprocessed.
    Returns:
        The preprocessed text.
    """
    # Replace newlines with spaces.
    text = text.replace('\n', ' ')

    # Get rid of characters that are ubiquitous across languages.
    replace_by = " "
    replacement_map = {
        ord(c): replace_by
        for c in ':•#{|}' + string.digits
    }
    text = text.translate(replacement_map)

    # Collapse multiple whitespace characters into a single space.
    text = re.sub(r'\s+', ' ', text)

    # Strip leading and trailing whitespace.
    text = text.strip()

    return text

def language_names(json_path):
    with open(json_path, 'r') as json_file:
        data = json.load(json_file)
    return data


label2name = language_names("assets/language_names.json")


def get_name(label):
    """Get the name of the language from its label."""
    iso_3 = label.split('_')[0]
    name = label2name[iso_3]
    return name

def render_svg(svg):
    """Renders the given SVG string."""
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    html = rf'<p align="center"> <img src="data:image/svg+xml;base64,{b64}" width="40%"/> </p>'
    c = st.container()
    c.write(html, unsafe_allow_html=True)

def render_metadata():
    """Renders the metadata."""
    html = r"""<p align="center">
    <a href="https://huggingface.co/cis-lmu/glotlid"><img alt="HuggingFace Model" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-8A2BE2"></a>
    <a href="https://github.com/cisnlp/GlotLID"><img alt="GitHub" src="https://img.shields.io/badge/%F0%9F%93%A6%20GitHub-orange"></a>
    <a href="https://github.com/cisnlp/GlotLID/blob/main/LICENSE"><img alt="GitHub license" src="https://img.shields.io/github/license/cisnlp/GlotLID?logoColor=blue"></a>
    <a href="https://github.com/cisnlp/GlotLID"><img alt="GitHub stars" src="https://img.shields.io/github/stars/cisnlp/GlotLID"></a>
    <a href="https://arxiv.org/abs/2310.16248"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2310.16248-b31b1b.svg"></a>
    </p>"""
    c = st.container()
    c.write(html, unsafe_allow_html=True)

def citation():
    """Renders the citation."""
    # Raw string so the BibTeX escape sequences (\c, \") are kept verbatim.
    _CITATION = r"""
@inproceedings{
kargaran2023glotlid,
title={GlotLID: Language Identification for Low-Resource Languages},
author={Kargaran, Amir Hossein and Imani, Ayyoob and Yvon, Fran{\c{c}}ois and Sch{\"u}tze, Hinrich},
booktitle={The 2023 Conference on Empirical Methods in Natural Language Processing},
year={2023},
url={https://openreview.net/forum?id=dl4e3EBz5j}
}"""
    st.code(_CITATION, language="python", line_numbers=False)

@st.cache_data
def convert_df(df):
    # IMPORTANT: Cache the conversion to prevent computation on every rerun
    return df.to_csv(index=None).encode("utf-8")


def load_GlotLID(model_name, file_name):
    model_path = hf_hub_download(repo_id=model_name, filename=file_name)
    model = fasttext.load_model(model_path)
    return model


model_1 = load_GlotLID(constants.MODEL_NAME, "model_v1.bin")
model_2 = load_GlotLID(constants.MODEL_NAME, "model_v2.bin")
model_3 = load_GlotLID(constants.MODEL_NAME, "model_v3.bin")
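
# Note: `constants` is a small local module that is not part of this file.
# A minimal sketch of what it is assumed to contain (the repo id follows the
# model card linked in the metadata badges above):
#
#     # constants.py
#     MODEL_NAME = "cis-lmu/glotlid"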

# @st.cache_resource
def plot(label, prob):
    ORANGE_COLOR = "#FF8000"
    BLACK_COLOR = "#31333F"

    fig, ax = plt.subplots(figsize=(8, 1))
    fig.patch.set_facecolor("none")
    ax.set_facecolor("none")

    ax.spines["left"].set_color(BLACK_COLOR)
    ax.spines["bottom"].set_color(BLACK_COLOR)
    ax.tick_params(axis="x", colors=BLACK_COLOR)
    ax.spines[["right", "top"]].set_visible(False)

    ax.barh(y=[0], width=[prob], color=ORANGE_COLOR)
    ax.set_xlim(0, 1)
    ax.set_ylim(-1, 1)
    ax.set_title(f"Label: {label}, Language: {get_name(label)}", color=BLACK_COLOR)
    ax.get_yaxis().set_visible(False)
    ax.set_xlabel("Confidence", color=BLACK_COLOR)

    st.pyplot(fig)

def compute(sentences, version='v3'):
    """Computes the language probabilities and labels for the given sentences.

    Args:
        sentences: A list of sentences.
    Returns:
        Two lists: the language probabilities and the labels for the given sentences.
    """
    progress_text = "Computing Language..."
    model_choice = model_3 if version == 'v3' else (model_2 if version == 'v2' else model_1)
    my_bar = st.progress(0, text=progress_text)

    probs = []
    labels = []

    sentences = [preprocess_text(sent) for sent in sentences]

    for index, sent in enumerate(sentences):
        output = model_choice.predict(sent)
        output_label = output[0][0].split('__')[-1]
        output_prob = max(min(output[1][0], 1), 0)
        output_label_language = output_label.split('_')[0]

        # Script control: if the predicted label's script does not match any script
        # detected in the sentence, fall back to 'und' with the main detected script.
        if version in ['v2', 'v3'] and output_label_language != 'zxx':
            main_script, all_scripts = get_script(sent)
            output_label_script = output_label.split('_')[1]

            if output_label_script not in all_scripts:
                output_label_script = main_script
                output_label = f"und_{output_label_script}"
                output_prob = 1.0

        labels.append(output_label)
        probs.append(output_prob)

        my_bar.progress(
            min((index + 1) / len(sentences), 1.0),
            text=progress_text,
        )

    my_bar.empty()
    return probs, labels

st.markdown("[](https://huggingface.co/spaces/cis-lmu/glotlid-space?duplicate=true)")

render_svg(open("assets/glotlid_logo.svg").read())
render_metadata()

st.markdown("**GlotLID** is an open-source language identification model with support for more than **2000 labels (V3)**.")

tab1, tab2 = st.tabs(["Input a Sentence", "Upload a File"])

with tab1:
    version = st.radio(
        "Choose model",
        ["v1", "v2", "v3"],
        captions=["GlotLID version 1", "GlotLID version 2", "GlotLID version 3 (More languages, better quality data)"],
        index=2,
        key='version_tab1',
        horizontal=True
    )

    sent = st.text_input(
        "Sentence:", placeholder="Enter a sentence.", on_change=None
    )

    # TODO: Check if this is needed!
    clicked = st.button("Submit")

    if sent:
        probs, labels = compute([sent], version=version)
        prob = probs[0]
        label = labels[0]

        # Check if the log file exists; create it if not.
        if not os.path.exists('logs.txt'):
            with open('logs.txt', 'w') as file:
                pass

        print(f"{sent}, {label}: {prob}")
        with open("logs.txt", "a") as f:
            f.write(f"{sent}, {label}: {prob}\n")

        # plot
        plot(label, prob)

with tab2:
    version = st.radio(
        "Choose model",
        ["v1", "v2", "v3"],
        captions=["GlotLID version 1", "GlotLID version 2", "GlotLID version 3 (More languages, better quality data)"],
        index=2,
        key='version_tab2',
        horizontal=True
    )

    file = st.file_uploader("Upload a file", type=["txt"])

    if file is not None:
        # Use a separator that should never occur in the text, so that each
        # line of the uploaded file is read as a single "Sentence" column.
        df = pd.read_csv(file, sep="¦\t¦", header=None, engine='python')
        df.columns = ["Sentence"]
        df.reset_index(drop=True, inplace=True)

        # Run the model
        df["Prob"], df["Label"] = compute(df["Sentence"].tolist(), version=version)
        df["Language"] = df["Label"].apply(get_name)

        # A horizontal rule
        st.markdown("""---""")

        chart = (
            alt.Chart(df.reset_index())
            .mark_area(color="darkorange", opacity=0.5)
            .encode(
                x=X(field="index", title="Sentence Index"),
                y=Y("Prob", scale=Scale(domain=[0, 1])),
            )
        )
        st.altair_chart(chart.interactive(), use_container_width=True)

        col1, col2 = st.columns([4, 1])

        with col1:
            # Display the output
            st.table(
                df,
            )

        with col2:
            # Add a download button
            csv = convert_df(df)
            st.download_button(
                label=":file_folder: Download predictions as CSV",
                data=csv,
                file_name="GlotLID.csv",
                mime="text/csv",
            )

# citation()
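
# To try this Space locally (a sketch, assuming the `assets/` directory and
# `constants.py` are present and that this file is saved as `app.py`; the
# package names below are the usual PyPI names, not pinned by this file):
#
#     pip install streamlit pandas matplotlib altair fasttext huggingface_hub GlotScript
#     streamlit run app.py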