"""MobileCLIP2 Zero-Shot Classification Demo"""
import torch
import open_clip
import gradio as gr
from mobileclip.modules.common.mobileone import reparameterize_model
import spaces
################################################################################
# Model Configuration
################################################################################
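# Maps a display name (shown in the dropdown) to an (open_clip model name, pretrained tag) pair.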
AVAILABLE_MODELS = {
    "MobileCLIP2-B": ("MobileCLIP2-B", "dfndr2b"),
    "MobileCLIP2-S0": ("MobileCLIP2-S0", "dfndr2b"),
    "MobileCLIP2-S2": ("MobileCLIP2-S2", "dfndr2b"),
    "MobileCLIP2-S3": ("MobileCLIP2-S3", "dfndr2b"),
    "MobileCLIP2-S4": ("MobileCLIP2-S4", "dfndr2b"),
    "MobileCLIP2-L-14": ("MobileCLIP2-L-14", "dfndr2b"),
}
# Cache for loaded models
model_cache = {}
################################################################################
# Model Loading
################################################################################
def load_model(model_name):
    """Load and cache a MobileCLIP2 model."""
    if model_name in model_cache:
        return model_cache[model_name]

    model_id, pretrained = AVAILABLE_MODELS[model_name]

    # Create model and preprocessing transforms
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_id,
        pretrained=pretrained
    )
    tokenizer = open_clip.get_tokenizer(model_id)

    # Reparameterize model for inference
    model = reparameterize_model(model.eval())

    # Cache the model components
    model_cache[model_name] = {
        "model": model,
        "preprocess": preprocess,
        "tokenizer": tokenizer
    }
    return model_cache[model_name]
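
# Optional warm-up sketch (not part of the original demo): pre-loading the default
# model once at startup hides the first-request cost of downloading and building the
# checkpoint. The default name "MobileCLIP2-S2" mirrors the dropdown default below.
def warmup(model_name="MobileCLIP2-S2"):
    """Populate model_cache ahead of the first user request."""
    load_model(model_name)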
################################################################################
# Inference
################################################################################
@spaces.GPU(duration=120)
def classify_image(image, candidate_labels, model_name):
    """
    Classify an image with the selected MobileCLIP2 model.

    Args:
        image: PIL Image
        candidate_labels: comma-separated string of labels
        model_name: selected model from dropdown

    Returns:
        Dictionary mapping each label to its probability
    """
    if image is None:
        return {}

    # Parse labels
    labels = [label.strip() for label in candidate_labels.split(",") if label.strip()]
    if not labels:
        return {}

    # Load model components
    model_components = load_model(model_name)
    model = model_components["model"]
    preprocess = model_components["preprocess"]
    tokenizer = model_components["tokenizer"]

    # Use the GPU attached inside this decorated function when available,
    # otherwise fall back to CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    # Preprocess image
    image_tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(device)

    # Tokenize text
    text_tokens = tokenizer(labels).to(device)

    # Run inference
    with torch.no_grad(), torch.autocast(device_type=device):
        image_features = model.encode_image(image_tensor)
        text_features = model.encode_text(text_tokens)

        # Normalize features
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        # Compute similarity and probabilities
        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    # Format output as dictionary
    output = {labels[i]: float(text_probs[0][i]) for i in range(len(labels))}
    return output
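
# Standalone usage sketch (illustrative only, not wired into the UI below): run a
# quick local smoke test without Gradio. The file path and labels are placeholders,
# and outside a ZeroGPU Space the @spaces.GPU decorator should have no effect.
def classify_file(path, labels_csv="a cat, a dog, a bird", model_name="MobileCLIP2-S2"):
    """Open a local image file and return its zero-shot label probabilities."""
    from PIL import Image
    return classify_image(Image.open(path), labels_csv, model_name)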
################################################################################
# Gradio Interface
################################################################################
with gr.Blocks() as demo:
    gr.Markdown("# MobileCLIP2 Zero-Shot Image Classification")
    gr.Markdown(
        "Classify images using MobileCLIP2 models. Select a model, upload an image, "
        "and provide comma-separated class labels to get predictions."
    )
    gr.Markdown("See the [MobileCLIP2 model collection](https://huggingface.co/collections/apple/mobileclip2-68ac947dcb035c54bcd20c47).")

    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value="MobileCLIP2-S2",
                label="Select MobileCLIP2 Model",
                info="Choose which model to use for classification"
            )
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(
                label="Class Labels (comma separated)",
                placeholder="e.g., a cat, a dog, a bird"
            )
            run_button = gr.Button("Classify", variant="primary")

        with gr.Column():
            output_label = gr.Label(
                label="Classification Results",
                num_top_classes=5
            )
    # Examples (ordered to match classify_image's signature: image, labels, model)
    examples = [
        ["./baklava.jpg", "dessert on a plate, a serving of baklava, a plate and spoon", "MobileCLIP2-S2"],
        ["./cat.jpg", "a cat, two cats, three cats", "MobileCLIP2-S2"],
        ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down", "MobileCLIP2-S2"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input, model_dropdown],
        outputs=[output_label],
        fn=classify_image,
        cache_examples=False
    )
    # Connect button
    run_button.click(
        fn=classify_image,
        inputs=[image_input, text_input, model_dropdown],
        outputs=[output_label]
    )

if __name__ == "__main__":
    demo.launch()