| """MobileCLIP2 Zero-Shot Classification Demo""" | |
| import torch | |
| import open_clip | |
| import gradio as gr | |
| from mobileclip.modules.common.mobileone import reparameterize_model | |
| import spaces | |
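
# NOTE: the `mobileclip` import above is assumed to be provided by Apple's
# ml-mobileclip package (https://github.com/apple/ml-mobileclip); this file
# does not pin its install source, so treat that as an assumption about the
# Space's requirements rather than something stated here.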
################################################################################
# Model Configuration
################################################################################

AVAILABLE_MODELS = {
    "MobileCLIP2-B": ("MobileCLIP2-B", "dfndr2b"),
    "MobileCLIP2-S0": ("MobileCLIP2-S0", "dfndr2b"),
    "MobileCLIP2-S2": ("MobileCLIP2-S2", "dfndr2b"),
    "MobileCLIP2-S3": ("MobileCLIP2-S3", "dfndr2b"),
    "MobileCLIP2-S4": ("MobileCLIP2-S4", "dfndr2b"),
    "MobileCLIP2-L-14": ("MobileCLIP2-L-14", "dfndr2b"),
}

# Cache for loaded models
model_cache = {}
################################################################################
# Model Loading
################################################################################

def load_model(model_name):
    """Load and cache a MobileCLIP2 model."""
    if model_name in model_cache:
        return model_cache[model_name]

    model_id, pretrained = AVAILABLE_MODELS[model_name]

    # Create model and preprocessing transforms
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_id,
        pretrained=pretrained
    )
    tokenizer = open_clip.get_tokenizer(model_id)

    # Reparameterize model for inference (fuses the MobileOne training-time
    # branches into single layers for faster inference without changing outputs)
    model = reparameterize_model(model.eval())

    # Cache the model components
    model_cache[model_name] = {
        "model": model,
        "preprocess": preprocess,
        "tokenizer": tokenizer
    }
    return model_cache[model_name]
################################################################################
# Inference
################################################################################

# Run inference on a ZeroGPU-allocated GPU; this decorator is what the
# `spaces` import above is for.
@spaces.GPU
def classify_image(image, candidate_labels, model_name):
    """
    Classify an image using the selected MobileCLIP2 model.

    Args:
        image: PIL Image
        candidate_labels: comma-separated string of labels
        model_name: selected model from dropdown

    Returns:
        Dictionary of label probabilities
    """
    if image is None:
        return {}

    # Parse labels
    labels = [label.strip() for label in candidate_labels.split(",") if label.strip()]
    if not labels:
        return {}

    # Load model components
    model_components = load_model(model_name)
    model = model_components["model"]
    preprocess = model_components["preprocess"]
    tokenizer = model_components["tokenizer"]

    # Preprocess image
    image_tensor = preprocess(image.convert('RGB')).unsqueeze(0)

    # Tokenize text
    text_tokens = tokenizer(labels)

    # Run inference
    with torch.no_grad(), torch.cuda.amp.autocast():
        image_features = model.encode_image(image_tensor)
        text_features = model.encode_text(text_tokens)

        # Normalize features
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        # Compute similarity and probabilities
        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    # Format output as dictionary
    output = {labels[i]: float(text_probs[0][i]) for i in range(len(labels))}
    return output
################################################################################
# Gradio Interface
################################################################################

with gr.Blocks() as demo:
    gr.Markdown("# MobileCLIP2 Zero-Shot Image Classification")
    gr.Markdown(
        "Classify images using MobileCLIP2 models. Select a model, upload an image, "
        "and provide comma-separated class labels to get predictions."
    )
    gr.Markdown("See the [MobileCLIP2 model collection](https://huggingface.co/collections/apple/mobileclip2-68ac947dcb035c54bcd20c47).")

    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value="MobileCLIP2-S2",
                label="Select MobileCLIP2 Model",
                info="Choose which model to use for classification"
            )
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(
                label="Class Labels (comma separated)",
                placeholder="e.g., a cat, a dog, a bird"
            )
            run_button = gr.Button("Classify", variant="primary")

        with gr.Column():
            output_label = gr.Label(
                label="Classification Results",
                num_top_classes=5
            )

    # Examples (ordered as image, labels, model to match classify_image's signature)
    examples = [
        ["./baklava.jpg", "dessert on a plate, a serving of baklava, a plate and spoon", "MobileCLIP2-S2"],
        ["./cat.jpg", "a cat, two cats, three cats", "MobileCLIP2-S2"],
        ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down", "MobileCLIP2-S2"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input, model_dropdown],
        outputs=[output_label],
        fn=classify_image,
        cache_examples=False
    )

    # Connect button
    run_button.click(
        fn=classify_image,
        inputs=[image_input, text_input, model_dropdown],
        outputs=[output_label]
    )
if __name__ == "__main__":
    demo.launch()
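
# A minimal sketch of calling the classifier outside the Gradio UI, assuming the
# example image files referenced above are present locally (kept commented out
# so it never runs as part of the Space):
#
#     from PIL import Image
#     probs = classify_image(Image.open("./cat.jpg"), "a cat, a dog, a bird", "MobileCLIP2-S0")
#     print(probs)  # {label: probability} for each comma-separated label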