# """
# Author: Amir Hossein Kargaran
# Date: August, 2023
# Description: This code applies LIME (Local Interpretable Model-Agnostic Explanations) on language identification models.
# MIT License
# Some parts of the code are adapted from here: https://gist.github.com/ageitgey/60a8b556a9047a4ca91d6034376e5980
# """
import os
import re
from io import BytesIO

import gradio as gr
import lime.lime_text
import numpy as np
from fasttext.FastText import _FastText
from huggingface_hub import hf_hub_download
from PIL import Image
from selenium import webdriver
from selenium.common.exceptions import WebDriverException
| # Define a dictionary to map model choices to their respective paths | |
| model_paths = { | |
| "OpenLID": ["laurievb/OpenLID", 'model.bin'], | |
| "GlotLID": ["cis-lmu/glotlid", 'model.bin'], | |
| "NLLB": ["facebook/fasttext-language-identification", 'model.bin'] | |
| } | |
| # Create a dictionary to cache classifiers | |
| cached_classifiers = {} | |
def load_classifier(model_choice):
    """Return the fastText classifier for *model_choice*, loading it on first use.

    Loaded models are memoized in the module-level ``cached_classifiers``
    dict, so repeated calls for the same choice are cheap.
    """
    try:
        return cached_classifiers[model_choice]
    except KeyError:
        pass
    repo_id, filename = model_paths[model_choice]
    # Fetch the model file from the Hugging Face Hub (hf_hub caches on disk).
    local_path = hf_hub_download(repo_id=repo_id, filename=filename)
    classifier = _FastText(local_path)
    cached_classifiers[model_choice] = classifier
    return classifier
| # cache all models | |
| for model_choice in model_paths.keys(): | |
| load_classifier(model_choice) | |
def remove_label_prefix(item):
    """Return *item* with every fastText '__label__' marker stripped out."""
    return item.replace('__label__', '')


def remove_label_prefix_list(input_list):
    """Strip '__label__' markers from a list — or list of lists — of labels.

    The nested-list form mirrors fastText's batch ``predict`` output, which
    returns one label list per input text.
    """
    if isinstance(input_list[0], list):
        return [[remove_label_prefix(label) for label in inner] for inner in input_list]
    return [remove_label_prefix(label) for label in input_list]
def tokenize_string(sentence, n=None):
    """Tokenize *sentence* for LIME perturbation.

    With ``n is None`` the sentence is split on whitespace; otherwise it is
    broken into overlapping character n-grams of length *n* (empty list when
    the sentence is shorter than *n*).
    """
    if n is None:
        return sentence.split()
    return [sentence[start:start + n] for start in range(len(sentence) - n + 1)]
def fasttext_prediction_in_sklearn_format(classifier, texts, num_class):
    """Run fastText prediction and return a scikit-learn-style probability matrix.

    LIME's ``classifier_fn`` contract requires an (n_samples, n_classes)
    array whose columns match the (alphabetically sorted) class names, so
    each row of probabilities is permuted into alphabetical label order.

    Args:
        classifier: fastText model exposing ``predict(texts, k)``.
        texts: list of input strings (LIME passes perturbed samples here).
        num_class: total number of classes; retained for interface
            compatibility with existing callers — not used directly.

    Returns:
        np.ndarray of shape (len(texts), n_labels), probabilities aligned
        to alphabetically sorted label names.
    """
    # k=-1 asks fastText to return a probability for every known label.
    labels, probabilities = classifier.predict(texts, -1)
    labels = remove_label_prefix_list(labels)
    rows = []
    for row_labels, row_probs in zip(labels, probabilities):
        # argsort over the label strings yields the permutation that puts
        # this row's probabilities into alphabetical-label order.
        order = np.argsort(np.array(row_labels))
        rows.append(row_probs[order])
    return np.array(rows)
def generate_explanation_html(input_sentence, explainer, classifier, num_class):
    """Explain the classifier's prediction for *input_sentence* with LIME.

    The explanation is rendered to ``explanation.html`` in the working
    directory; the filename is returned for downstream use.
    """
    def predict_fn(samples):
        # Adapter so LIME receives sklearn-shaped probabilities.
        return fasttext_prediction_in_sklearn_format(classifier, samples, num_class)

    explanation = explainer.explain_instance(
        input_sentence,
        classifier_fn=predict_fn,
        top_labels=2,
        num_features=20,
    )
    output_html_filename = "explanation.html"
    explanation.save_to_file(output_html_filename)
    return output_html_filename
def take_screenshot(local_html_path):
    """Render a local HTML file in headless Chrome and return it as a PIL image.

    Returns a 1x1 placeholder image when the browser fails to start or the
    page cannot be loaded (best-effort: the UI keeps working without the
    screenshot).
    """
    options = webdriver.ChromeOptions()
    options.add_argument('--headless')
    options.add_argument('--no-sandbox')
    options.add_argument('--disable-dev-shm-usage')
    # BUG FIX: wd must exist before the try block — if webdriver.Chrome()
    # raises, the original code hit a NameError on `if wd:` in finally,
    # masking the real WebDriverException.
    wd = None
    try:
        local_html_path = os.path.abspath(local_html_path)
        wd = webdriver.Chrome(options=options)
        wd.set_window_size(1366, 728)
        wd.get('file://' + local_html_path)
        wd.implicitly_wait(10)
        screenshot = wd.get_screenshot_as_png()
    except WebDriverException:
        return Image.new('RGB', (1, 1))
    finally:
        if wd:
            wd.quit()
    return Image.open(BytesIO(screenshot))
# Gradio callback tying model loading, LIME, and rendering together.
def merge_function(input_sentence, selected_model):
    """Explain the language-ID prediction for *input_sentence*.

    Returns a (PIL image, HTML file path) pair matching the Gradio outputs.
    """
    flattened = input_sentence.replace('\n', ' ')
    classifier = load_classifier(selected_model)
    # LIME's class names must be sorted the same way the prediction
    # probabilities are re-ordered in fasttext_prediction_in_sklearn_format.
    class_names = np.sort(remove_label_prefix_list(classifier.labels))
    num_class = len(class_names)
    explainer = lime.lime_text.LimeTextExplainer(
        split_expression=tokenize_string,
        bow=False,
        class_names=class_names,
    )
    html_path = generate_explanation_html(flattened, explainer, classifier, num_class)
    screenshot = take_screenshot(html_path)
    return screenshot, html_path
| # Define the Gradio interface | |
| input_text = gr.Textbox(label="Input Text", value="J'ai visited la beautiful beach avec mes amis for a relaxing journée under the sun.") | |
| model_choice = gr.Radio(choices=["GlotLID", "OpenLID", "NLLB"], label="Select Model", value='GlotLID') | |
| output_explanation = gr.File(label="Explanation HTML") | |
| iface = gr.Interface( | |
| fn=merge_function, | |
| inputs=[input_text, model_choice], | |
| outputs=[gr.Image(type="pil", height=364, width=683, label="Explanation Image"), output_explanation], | |
| title="LIME LID", | |
| description="This code applies LIME (Local Interpretable Model-Agnostic Explanations) on fasttext language identification.", | |
| allow_flagging='never', | |
| theme=gr.themes.Soft() | |
| ) | |
| iface.launch() | |