Spaces:
Runtime error
Runtime error
from typing import Optional

import gradio as gr
from PIL import Image

from off_topic import OffTopicDetector, Translator

# Hugging Face model IDs loaded once at import time.
# Translation model; presumably used to translate item titles to English (the UI
# offers a "Use translated item title" option) — TODO confirm in off_topic.
TRANSLATION_MODEL_ID = "Helsinki-NLP/opus-mt-roa-en"
# CLIP checkpoint handed to the off-topic detector.
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"

translator = Translator(TRANSLATION_MODEL_ID)
# NOTE(review): the meaning of image_size="V" is defined inside off_topic — confirm there.
detector = OffTopicDetector(CLIP_MODEL_ID, image_size="V", translator=translator)
def partition_by_threshold(images, valid_probas, threshold: float):
    """Split ``images`` into ``(valid_images, invalid_images)`` by probability.

    ``valid_probas[i]`` is the detector's "valid" probability for ``images[i]``.
    ``.squeeze()`` is called on each entry before comparing, so entries are
    presumably single-element arrays/tensors (TODO: confirm against off_topic).

    An image is valid when its probability is ``>= threshold`` and invalid
    otherwise. A single if/else is used on purpose: two independent ``>=`` /
    ``<`` filters would silently drop an image whose probability is NaN,
    because both comparisons are false for NaN.

    Args:
        images: Images in the same order as ``valid_probas``.
        valid_probas: Per-image probabilities, indexed in parallel with ``images``.
        threshold: Minimum probability for an image to count as valid.

    Returns:
        A ``(valid_images, invalid_images)`` tuple of lists. Each list keeps
        the input order, and every image appears in exactly one of them.
    """
    valid_images, invalid_images = [], []
    for i, image in enumerate(images):
        # Indexing (rather than zip) keeps the original failure mode when
        # valid_probas is shorter than images, instead of silently truncating.
        if valid_probas[i].squeeze() >= threshold:
            valid_images.append(image)
        else:
            invalid_images.append(image)
    return valid_images, invalid_images


def validate_item(item_id: str, use_title: bool, threshold: float):
    """Classify the pictures of a listed item as valid or invalid for its domain.

    Args:
        item_id: Item ID typed into the UI. Surrounding whitespace is ignored.
        use_title: Forwarded to ``detector.predict_probas_item``. The UI labels
            it "Use translated item title".
        threshold: Minimum "valid" probability for a picture to count as valid.

    Returns:
        ``(domain_markdown, valid_images, invalid_images)``, matching the
        ``[domain, valid, invalid]`` outputs wired in the UI.
    """
    images, domain, _probas, valid_probas, _invalid_probas = detector.predict_probas_item(
        item_id.strip(), use_title=use_title
    )
    valid_images, invalid_images = partition_by_threshold(images, valid_probas, threshold)
    return f"## Domain: {domain}", valid_images, invalid_images


def validate_images(img_url_1, img_url_2, img_url_3, domain: str, title: str, threshold: float):
    """Classify up to three picture URLs as valid or invalid for a domain ID.

    Args:
        img_url_1, img_url_2, img_url_3: Picture URLs. Blank or whitespace-only
            entries are ignored.
        domain: Domain ID of the form ``"<SITE>-<DOMAIN_NAME>"``. The part before
            the first ``-`` is passed to the detector as the site. The rest, with
            ``_`` turned into spaces and lowercased, is passed as the domain text.
        title: Optional item title. Blank or whitespace-only means "no title"
            (``None`` is passed to the detector).
        threshold: Minimum "valid" probability for a picture to count as valid.

    Returns:
        ``(domain_markdown, valid_images, invalid_images)``, matching the
        ``[domain_output, valid, invalid]`` outputs wired in the UI.

    Raises:
        ValueError: If no non-blank URL is given, or ``domain`` is not of the
            form ``"<SITE>-<DOMAIN_NAME>"``.
    """
    img_urls = [url.strip() for url in (img_url_1, img_url_2, img_url_3) if url and url.strip()]
    if not img_urls:
        raise ValueError("Provide at least one picture URL.")

    # partition() splits on the first hyphen only, so a missing hyphen is
    # reported clearly instead of failing inside a tuple unpack.
    site, sep, domain_name = (domain or "").strip().partition("-")
    if not (sep and site and domain_name):
        raise ValueError(f"Domain ID must have the form '<SITE>-<DOMAIN_NAME>', got {domain!r}.")
    domain_text = domain_name.replace("_", " ").lower()

    title = (title or "").strip() or None

    images, (_probas, valid_probas, _invalid_probas) = detector.predict_probas_url(
        img_urls, domain_text, site, title
    )
    valid_images, invalid_images = partition_by_threshold(images, valid_probas, threshold)
    return f"## Domain: {domain_name}", valid_images, invalid_images
# Two-tab Gradio UI. Tab 1 classifies an item's pictures from its item ID.
# Tab 2 classifies up to three picture URLs against a domain ID. Tab-2
# components have their own names so they do not rebind tab-1 variables.
with gr.Blocks() as demo:
    gr.Markdown("""
# Off topic image detector
### This app takes an item ID and classifies its pictures as valid/invalid depending on whether they relate to the domain in which it's been listed.
Input an item ID or select one of the preloaded examples below.""")

    with gr.Tab("From item_id"):
        with gr.Row():
            item_id = gr.Textbox(label="Item ID")
            with gr.Column():
                use_title = gr.Checkbox(label="Use translated item title", value=True)
                threshold = gr.Number(label="Threshold", value=0.25, precision=2)
        submit = gr.Button("Submit")
        gr.HTML("<hr>")
        domain = gr.Markdown()
        # NOTE(review): `.style()` is the Gradio 3.x styling API. Gradio 4 removed it
        # (pass `columns=` / `height=` to gr.Gallery instead). Confirm the pinned version.
        valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
        gr.HTML("<hr>")
        invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")
        submit.click(inputs=[item_id, use_title, threshold], outputs=[domain, valid, invalid], fn=validate_item)
        gr.HTML("<hr>")
        # NOTE(review): with cache_examples=True Gradio precomputes validate_item for
        # every example before serving, so a failure for any example ID presumably
        # surfaces at startup — confirm.
        gr.Examples(
            examples=[["MLC572974424", True, 0.25], ["MLU449951849", True, 0.25], ["MLA1293465558", True, 0.25],
                      ["MLB3184663685", True, 0.25], ["MLC1392230619", True, 0.25], ["MCO546152796", True, 0.25]],
            inputs=[item_id, use_title, threshold],
            outputs=[domain, valid, invalid],
            fn=validate_item,
            cache_examples=True,
        )

    with gr.Tab("From image urls"):
        with gr.Row():
            with gr.Column():
                img_url_1 = gr.Textbox(label="Picture URL")
                img_url_2 = gr.Textbox(label="Picture URL")
                img_url_3 = gr.Textbox(label="Picture URL")
            with gr.Column():
                domain_id = gr.Textbox(label="Domain ID", placeholder="Required")
                item_title = gr.Textbox(label="Item title", placeholder="Optional")
                url_threshold = gr.Number(label="Threshold", value=0.25, precision=2)
        url_submit = gr.Button("Submit")
        gr.HTML("<hr>")
        url_domain_output = gr.Markdown()
        url_valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
        gr.HTML("<hr>")
        url_invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")
        url_submit.click(
            inputs=[img_url_1, img_url_2, img_url_3, domain_id, item_title, url_threshold],
            outputs=[url_domain_output, url_valid, url_invalid],
            fn=validate_images,
        )
        gr.HTML("<hr>")

if __name__ == "__main__":
    demo.launch()