"""MobileCLIP2 Zero-Shot Classification Demo"""
import torch
import open_clip
import gradio as gr
from mobileclip.modules.common.mobileone import reparameterize_model
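# `spaces` provides the @spaces.GPU decorator used for ZeroGPU allocation
# on Hugging Face Spaces.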
import spaces
################################################################################
# Model Configuration
################################################################################
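# Display name -> (open_clip model id, pretrained tag) pairs, as expected by
# open_clip.create_model_and_transforms.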
AVAILABLE_MODELS = {
    "MobileCLIP2-B": ("MobileCLIP2-B", "dfndr2b"),
    "MobileCLIP2-S0": ("MobileCLIP2-S0", "dfndr2b"),
    "MobileCLIP2-S2": ("MobileCLIP2-S2", "dfndr2b"),
    "MobileCLIP2-S3": ("MobileCLIP2-S3", "dfndr2b"),
    "MobileCLIP2-S4": ("MobileCLIP2-S4", "dfndr2b"),
    "MobileCLIP2-L-14": ("MobileCLIP2-L-14", "dfndr2b"),
}

# Cache for loaded models
model_cache = {}
################################################################################
# Model Loading
################################################################################
def load_model(model_name):
    """Load and cache a MobileCLIP2 model."""
    if model_name in model_cache:
        return model_cache[model_name]

    model_id, pretrained = AVAILABLE_MODELS[model_name]

    # Create model and preprocessing transforms
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_id,
        pretrained=pretrained,
    )
    tokenizer = open_clip.get_tokenizer(model_id)
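
    # MobileOne-style blocks in MobileCLIP use multi-branch structures at
    # train time; reparameterize_model folds each of these into a single
    # conv layer, which speeds up inference without changing the outputs.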
    # Reparameterize model for inference
    model = reparameterize_model(model.eval())

    # Cache the model components
    model_cache[model_name] = {
        "model": model,
        "preprocess": preprocess,
        "tokenizer": tokenizer,
    }
    return model_cache[model_name]

################################################################################
# Inference
################################################################################
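# @spaces.GPU requests a ZeroGPU device for the duration of each call;
# `duration=120` caps a single invocation at 120 seconds.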
@spaces.GPU(duration=120)
def classify_image(image, candidate_labels, model_name):
    """
    Classify an image using the selected MobileCLIP2 model.

    Args:
        image: PIL Image
        candidate_labels: comma-separated string of labels
        model_name: selected model from dropdown

    Returns:
        Dictionary mapping each label to its probability
    """
    if image is None:
        return {}

    # Parse labels
    labels = [label.strip() for label in candidate_labels.split(",") if label.strip()]
    if not labels:
        return {}

    # Load model components
    model_components = load_model(model_name)
    model = model_components["model"]
    preprocess = model_components["preprocess"]
    tokenizer = model_components["tokenizer"]

    # Move to GPU when one is available (e.g. inside the ZeroGPU allocation);
    # fall back to CPU otherwise.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    # Preprocess image
    image_tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(device)

    # Tokenize text
    text_tokens = tokenizer(labels).to(device)
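
    # Both feature sets are L2-normalized below, so the dot product between
    # them is a cosine similarity; multiplying by 100 (CLIP's logit scale)
    # before the softmax sharpens the distribution over labels.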
    # Run inference
    with torch.no_grad(), torch.amp.autocast(device_type=device):
        image_features = model.encode_image(image_tensor)
        text_features = model.encode_text(text_tokens)

        # Normalize features
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        # Compute similarity and probabilities
        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    # Format output as a {label: probability} dictionary
    output = {label: float(prob) for label, prob in zip(labels, text_probs[0])}
    return output

################################################################################
# Gradio Interface
################################################################################
with gr.Blocks() as demo:
    gr.Markdown("# MobileCLIP2 Zero-Shot Image Classification")
    gr.Markdown(
        "Classify images using MobileCLIP2 models. Select a model, upload an image, "
        "and provide comma-separated class labels to get predictions."
    )
    gr.Markdown("See the [MobileCLIP2 model collection](https://huggingface.co/collections/apple/mobileclip2-68ac947dcb035c54bcd20c47).")

    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value="MobileCLIP2-S2",
                label="Select MobileCLIP2 Model",
                info="Choose which model to use for classification",
            )
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(
                label="Class Labels (comma separated)",
                placeholder="e.g., a cat, a dog, a bird",
            )
            run_button = gr.Button("Classify", variant="primary")

        with gr.Column():
            output_label = gr.Label(
                label="Classification Results",
                num_top_classes=5,
            )
    # Examples (column order matches classify_image's signature: image, labels, model)
    examples = [
        ["./baklava.jpg", "dessert on a plate, a serving of baklava, a plate and spoon", "MobileCLIP2-S2"],
        ["./cat.jpg", "a cat, two cats, three cats", "MobileCLIP2-S2"],
        ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down", "MobileCLIP2-S2"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input, model_dropdown],
        outputs=[output_label],
        fn=classify_image,
        cache_examples=False,
    )
    # Connect button
    run_button.click(
        fn=classify_image,
        inputs=[image_input, text_input, model_dropdown],
        outputs=[output_label],
    )

if __name__ == "__main__":
    demo.launch()