import torch
import torch.nn.functional as F
from torch import Tensor
import spaces
import numpy as np
from PIL import Image
import json, os, random
import gradio as gr
import torchvision.transforms.functional as TF
from safetensors.torch import load_file  # Import the load_file function from safetensors
from matplotlib import cm
from huggingface_hub import hf_hub_download
from typing import Tuple

from models import get_model

def resize_density_map(x: Tensor, size: Tuple[int, int]) -> Tensor:
    # Rescale after interpolation so the total count (spatial sum) is preserved.
    x_sum = torch.sum(x, dim=(-1, -2), keepdim=True)
    x = F.interpolate(x, size=size, mode="bilinear")
    scale_factor = torch.nan_to_num(x_sum / torch.sum(x, dim=(-1, -2), keepdim=True), nan=0.0, posinf=0.0, neginf=0.0)
    return x * scale_factor
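# Illustrative check: bilinearly upsampling a (1, 1, 2, 2) map that sums to 4.0 to
# size (4, 4) roughly quadruples the raw sum; the scale factor above divides that
# back out, so the returned map still sums to (approximately) 4.0.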
def init_seeds(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
alpha = 0.8

init_seeds(42)
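# `mean`/`std` are the standard ImageNet normalization statistics applied in `transform`;
# `alpha` is the blend weight used to overlay the colorized density map on the input image.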
# -----------------------------
# Define the model architecture
# -----------------------------
truncation = 4
reduction = 8
granularity = "fine"
anchor_points = "average"
input_size = 224

# Comment the lines below to test non-CLIP models.
prompt_type = "word"
num_vpt = 32
vpt_drop = 0.0
deep_vpt = True

repo_id = "Yiming-M/CLIP-EBC"
model_configs = {
    "CLIP_EBC_ViT_L_14": {
        "model_name": "clip_vit_l_14",
        "filename": "nwpu_weights/CLIP_EBC_ViT_L_14/model.safetensors",
    },
    "CLIP_EBC_ViT_B_16": {
        "model_name": "clip_vit_b_16",
        "filename": "nwpu_weights/CLIP_EBC_ViT_B_16/model.safetensors",
    },
}
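# Each entry maps a dropdown label to the backbone name expected by `get_model` and the
# path of the corresponding .safetensors checkpoint inside the `Yiming-M/CLIP-EBC` repo.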
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cuda"
if truncation is None:  # regression, no truncation.
    bins, default_anchor_points = None, None
else:
    with open(os.path.join("configs", f"reduction_{reduction}.json"), "r") as f:
        config = json.load(f)[str(truncation)]["nwpu"]
    bins = config["bins"][granularity]
    # Store the resolved values under a new name so the "average"/"middle" selector in
    # `anchor_points` is not overwritten; `load_model` below relies on that string.
    default_anchor_points = config["anchor_points"][granularity]["average"] if anchor_points == "average" else config["anchor_points"][granularity]["middle"]
    bins = [(float(b[0]), float(b[1])) for b in bins]
    default_anchor_points = [float(p) for p in default_anchor_points]
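# Background (roughly): EBC turns counting into blockwise classification. Each count
# interval ("bin") gets a representative value (its "anchor point", either the interval
# average or midpoint), and the predicted density is assembled from these anchor points.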
# Use a global reference to store the model instance and which configuration it was built from.
loaded_model = None
loaded_model_name = None
def load_model(model_choice: str):
    global loaded_model, loaded_model_name
    config = model_configs[model_choice]
    model_name = config["model_name"]
    filename = config["filename"]

    # Prepare bins and anchor_points if using classification
    if truncation is None:
        bins_, anchor_points_ = None, None
    else:
        with open(os.path.join("configs", f"reduction_{reduction}.json"), "r") as f:
            config_json = json.load(f)[str(truncation)]["nwpu"]
        bins_ = config_json["bins"][granularity]
        anchor_points_ = config_json["anchor_points"][granularity]["average"] if anchor_points == "average" else config_json["anchor_points"][granularity]["middle"]
        bins_ = [(float(b[0]), float(b[1])) for b in bins_]
        anchor_points_ = [float(p) for p in anchor_points_]

    # Build model
    model = get_model(
        backbone=model_name,
        input_size=input_size,
        reduction=reduction,
        bins=bins_,
        anchor_points=anchor_points_,
        prompt_type=prompt_type,
        num_vpt=num_vpt,
        vpt_drop=vpt_drop,
        deep_vpt=deep_vpt,
    )

    # Download the checkpoint from the Hub and strip the "model." prefix from parameter names.
    weights_path = hf_hub_download(repo_id, filename)
    state_dict = load_file(weights_path)
    new_state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict)
    model.eval()

    loaded_model = model
    loaded_model_name = model_choice  # remember which configuration is currently loaded
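# Illustrative usage: load_model("CLIP_EBC_ViT_B_16") builds the ViT-B/16 variant and
# caches it in `loaded_model`; `predict` calls this lazily on the first request or
# whenever the dropdown selection changes.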
# -----------------------------
# Preprocessing function
# -----------------------------
# Adjust the image transforms to match what your model expects.
def transform(image: Image.Image):
    assert isinstance(image, Image.Image), "Input must be a PIL Image"
    image_tensor = TF.to_tensor(image)
    image_height, image_width = image_tensor.shape[-2:]
    if image_height < input_size or image_width < input_size:
        # Find the ratio to resize the image while maintaining the aspect ratio
        ratio = max(input_size / image_height, input_size / image_width)
        new_height = int(image_height * ratio) + 1
        new_width = int(image_width * ratio) + 1
        image_tensor = TF.resize(image_tensor, (new_height, new_width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True)
    image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
    return image_tensor.unsqueeze(0)  # Add batch dimension
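# `transform` returns a (1, 3, H, W) float tensor normalized with the ImageNet statistics;
# images smaller than `input_size` on either side are upscaled so the backbone's minimum
# resolution is met, while larger images pass through at their native resolution.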
# -----------------------------
# Inference function
# -----------------------------
@spaces.GPU  # request a ZeroGPU device for the duration of this call
def predict(image: Image.Image, model_choice: str = "CLIP_EBC_ViT_B_16"):
    """
    Given an input image, preprocess it, run the model to obtain a density map,
    compute the total crowd count, and prepare the density map for display.
    """
    global loaded_model
    if loaded_model is None or loaded_model_name != model_choice:
        load_model(model_choice)
    loaded_model.to(device)

    # Preprocess the image
    input_width, input_height = image.size
    input_tensor = transform(image).to(device)  # shape: (1, 3, H, W)

    with torch.no_grad():
        density_map = loaded_model(input_tensor)  # expected shape: (1, 1, H, W)
    total_count = density_map.sum().item()
    resized_density_map = resize_density_map(density_map, (input_height, input_width)).cpu().squeeze().numpy()

    # Normalize the density map for display purposes
    eps = 1e-8
    density_map_norm = (resized_density_map - resized_density_map.min()) / (resized_density_map.max() - resized_density_map.min() + eps)

    # Apply a colormap (e.g., 'jet') to get an RGBA image
    colormap = cm.get_cmap("jet")
    # The colormap returns values in [0, 1]. Scale to [0, 255] and convert to uint8.
    density_map_color = (colormap(density_map_norm) * 255).astype(np.uint8)
    density_map_color_img = Image.fromarray(density_map_color).convert("RGBA")

    # Ensure the original image is in RGBA format.
    image_rgba = image.convert("RGBA")
    overlayed_image = Image.blend(image_rgba, density_map_color_img, alpha=alpha)

    return image, overlayed_image, f"Predicted Count: {total_count:.2f}"
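# Illustrative call (assuming example1.jpg is present in the working directory):
#   original, overlay, count_text = predict(Image.open("example1.jpg"), "CLIP_EBC_ViT_B_16")
#   overlay.save("density_overlay.png")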
# -----------------------------
# Build Gradio Interface using Blocks for a two-column layout
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Crowd Counting by CLIP-EBC (Pre-trained on NWPU-Crowd)")
    gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")

    with gr.Row():
        with gr.Column():
            model_choice = gr.Dropdown(
                choices=list(model_configs.keys()),
                value="CLIP_EBC_ViT_B_16",
                label="Select Model"
            )
            input_img = gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil")
            submit_btn = gr.Button("Predict")
        with gr.Column():
            output_img = gr.Image(label="Predicted Density Map", type="pil")
            output_text = gr.Textbox(label="Total Count")

    submit_btn.click(fn=predict, inputs=[input_img, model_choice], outputs=[input_img, output_img, output_text])
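    # `predict` returns (original image, overlay, count string); the first output is routed
    # back to input_img so the uploaded image stays visible after inference.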
    gr.Examples(
        examples=[
            ["example1.jpg"],
            ["example2.jpg"],
            ["example3.jpg"],
            ["example4.jpg"],
            ["example5.jpg"],
        ],
        inputs=input_img,
        label="Try an example"
    )

demo.launch()