"""MobileCLIP2 Zero-Shot Classification Demo"""
import torch
import open_clip
import gradio as gr
from mobileclip.modules.common.mobileone import reparameterize_model

import spaces

################################################################################
# Model Configuration
################################################################################
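# Display name -> (open_clip model id, pretrained tag)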
AVAILABLE_MODELS = {
    "MobileCLIP2-B": ("MobileCLIP2-B", "dfndr2b"),
    "MobileCLIP2-S0": ("MobileCLIP2-S0", "dfndr2b"),
    "MobileCLIP2-S2": ("MobileCLIP2-S2", "dfndr2b"),
    "MobileCLIP2-S3": ("MobileCLIP2-S3", "dfndr2b"),
    "MobileCLIP2-S4": ("MobileCLIP2-S4", "dfndr2b"),
    "MobileCLIP2-L-14": ("MobileCLIP2-L-14", "dfndr2b"),
}

# Cache for loaded models
model_cache = {}

################################################################################
# Model Loading
################################################################################
def load_model(model_name):
    """Load and cache MobileCLIP2 model"""
    if model_name in model_cache:
        return model_cache[model_name]
    
    model_id, pretrained = AVAILABLE_MODELS[model_name]
    
    # Create model and preprocessing transforms
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_id, 
        pretrained=pretrained
    )
    tokenizer = open_clip.get_tokenizer(model_id)
    
    # Reparameterize model for inference
    model = reparameterize_model(model.eval())
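    # (folds MobileOne-style train-time branches into single convolutions; this speeds up
    #  inference without changing the model's outputs)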
    
    # Cache the model components
    model_cache[model_name] = {
        "model": model,
        "preprocess": preprocess,
        "tokenizer": tokenizer
    }
    
    return model_cache[model_name]

################################################################################
# Inference
################################################################################
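# On Hugging Face ZeroGPU Spaces, @spaces.GPU attaches a GPU to each call (here for up to 120 seconds)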
@spaces.GPU(duration=120)
def classify_image(image, candidate_labels, model_name):
    """
    Classify image using selected MobileCLIP2 model
    
    Args:
        image: PIL Image
        candidate_labels: comma-separated string of labels
        model_name: selected model from dropdown
    
    Returns:
        Dictionary of label probabilities
    """
    if image is None:
        return {}
    
    # Parse labels
    labels = [label.strip() for label in candidate_labels.split(",") if label.strip()]
    
    if not labels:
        return {}
    
    # Load model components
    model_components = load_model(model_name)
    model = model_components["model"]
    preprocess = model_components["preprocess"]
    tokenizer = model_components["tokenizer"]
    
    # Preprocess image
    image_tensor = preprocess(image.convert('RGB')).unsqueeze(0)
    
    # Tokenize text
    text_tokens = tokenizer(labels)
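    # (open_clip tokenizers pad/truncate each label to the text encoder's context length)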
    
    # Move the model and inputs onto the GPU when one is available
    # (a no-op on CPU-only machines; ZeroGPU attaches the device while this call runs)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    image_tensor = image_tensor.to(device)
    text_tokens = text_tokens.to(device)
    
    # Run inference
    with torch.no_grad(), torch.amp.autocast(device_type=device, enabled=(device == "cuda")):
        image_features = model.encode_image(image_tensor)
        text_features = model.encode_text(text_tokens)
        
        # Normalize features
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        
        # Compute similarity and probabilities
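        # (the 100.0 factor is the logit scale used in open_clip's zero-shot examples)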
        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    
    # Format output as dictionary
    output = {labels[i]: float(text_probs[0][i]) for i in range(len(labels))}
    
    return output

################################################################################
# Gradio Interface
################################################################################
with gr.Blocks() as demo:
    gr.Markdown("# MobileCLIP2 Zero-Shot Image Classification")
    gr.Markdown(
        "Classify images using MobileCLIP2 models. Select a model, upload an image, "
        "and provide comma-separated class labels to get predictions."
    )
    gr.Markdown("See [MobileCLIP2 model collection](https://huggingface.co/collections/apple/mobileclip2-68ac947dcb035c54bcd20c47).")
    
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value="MobileCLIP2-S2",
                label="Select MobileCLIP2 Model",
                info="Choose which model to use for classification"
            )
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(
                label="Class Labels (comma separated)",
                placeholder="e.g., a cat, a dog, a bird"
            )
            run_button = gr.Button("Classify", variant="primary")
        
        with gr.Column():
            output_label = gr.Label(
                label="Classification Results",
                num_top_classes=5
            )
    
    # Examples (ordered image, labels, model to match classify_image's signature)
    examples = [
        ["./baklava.jpg", "dessert on a plate, a serving of baklava, a plate and spoon", "MobileCLIP2-S2"],
        ["./cat.jpg", "a cat, two cats, three cats", "MobileCLIP2-S2"],
        ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down", "MobileCLIP2-S2"],
    ]
    
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input, model_dropdown],
        outputs=[output_label],
        fn=classify_image,
        cache_examples=False
    )
    
    # Connect button
    run_button.click(
        fn=classify_image,
        inputs=[image_input, text_input, model_dropdown],
        outputs=[output_label]
    )

if __name__ == "__main__":
    demo.launch()