from typing import List

import gradio as gr
import numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel
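
# Demo configuration: the class-name file and the bundled example images are
# assumed to sit next to this script (one ImageNet class name per line).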
IMAGENET_CLASSES_FILE = "imagenet-classes.txt"
EXAMPLES = ["dog.jpeg", "car.png"]
MARKDOWN = """
# Zero-Shot Image Classification with MetaCLIP

This demo runs zero-shot image classification with a model based on
[MetaCLIP](https://github.com/facebookresearch/MetaCLIP), described in the paper
[Demystifying CLIP Data](https://arxiv.org/abs/2309.16671), which formalizes CLIP
data curation as a simple algorithm.
"""


def load_text_lines(file_path: str) -> List[str]:
    """Read a text file and return its lines with trailing whitespace stripped."""
    with open(file_path, 'r') as file:
        return [line.rstrip() for line in file]
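

# Select the GPU when available, then load the MetaCLIP ViT-B/32 checkpoint
# (trained on 400M curated image-text pairs) from the Hugging Face Hub.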
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(device)
processor = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")
imagenet_classes = load_text_lines(IMAGENET_CLASSES_FILE)
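

# Zero-shot classification: score the input image against the text of every
# ImageNet class name and return the single best-matching class.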
def classify_image(input_image) -> str:
    # Tokenize all class names and preprocess the image in one batch.
    inputs = processor(
        text=imagenet_classes,
        images=input_image,
        return_tensors="pt",
        padding=True).to(device)
    # Inference only, so skip gradient tracking to save memory.
    with torch.no_grad():
        outputs = model(**inputs)
    # logits_per_image has shape (1, num_classes); softmax turns it into probabilities.
    probs = outputs.logits_per_image.softmax(dim=1)
    class_index = np.argmax(probs.cpu().numpy())
    return imagenet_classes[class_index]
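

# A possible extension (a sketch, not part of the original demo): return the
# top-5 classes as a {label: probability} dict, which `gr.Label` renders as a
# ranked list. To use it, swap `output_text` below for
# `gr.Label(num_top_classes=5)` and point the click handler at this function.
def classify_image_top5(input_image) -> dict:
    inputs = processor(
        text=imagenet_classes,
        images=input_image,
        return_tensors="pt",
        padding=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)[0]
    top = probs.topk(5)
    return {
        imagenet_classes[int(i)]: float(p)
        for p, i in zip(top.values.cpu(), top.indices.cpu())
    }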


with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        image = gr.Image(image_mode='RGB', type='pil')
        output_text = gr.Textbox(label="Output")
    submit_button = gr.Button("Submit")
    submit_button.click(classify_image, inputs=[image], outputs=output_text)
    # Clickable examples; cache_examples=True precomputes their outputs at startup.
    gr.Examples(
        examples=EXAMPLES,
        fn=classify_image,
        inputs=[image],
        outputs=[output_text],
        cache_examples=True,
        run_on_click=True
    )
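
# demo.launch() serves the app locally (port 7860 by default); pass share=True
# for a temporary public URL.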
demo.launch(debug=False)