Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import spaces | |
| from transformers import AutoImageProcessor, DFineForObjectDetection | |
| from PIL import Image, ImageDraw, ImageFont | |
| import torch | |
# Checkpoint shared by the image processor and the detection model.
MODEL_ID = "ustc-community/dfine-medium-obj2coco"

# Both objects are created once at import time and reused by every request.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)

# IMPORTANT: for Zero GPU the weights start on the CPU; detect_objects moves
# them to the GPU only for the duration of a single inference call.
model = DFineForObjectDetection.from_pretrained(MODEL_ID).to("cpu")
# Conversion factor used for the millimetre read-out (pixels per millimetre).
PIXELS_PER_MM = 11.91
# Detections scoring below this confidence are discarded by post-processing.
SCORE_THRESHOLD = 0.3
# RGB colour of the box outline and of the label background.
BOX_COLOR = (45, 136, 58)
# At most this many detections are listed in the text summary.
MAX_SUMMARY_ITEMS = 10


def _load_label_font(size=24):
    """Return a TrueType font of ``size`` points, or PIL's default font."""
    candidates = (
        "DejaVuSans.ttf",
        "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf",
    )
    for font_path in candidates:
        try:
            return ImageFont.truetype(font_path, size)
        except (OSError, ImportError):
            # Font file missing/unreadable, or FreeType support unavailable:
            # try the next candidate instead of failing the whole request.
            continue
    return ImageFont.load_default()


def _find_logo_label_id():
    """Return the id whose label is "logo" (case-insensitive), else None."""
    for label_id, label_name in model.config.id2label.items():
        if label_name.lower() == "logo":
            return label_id
    return None


def _format_summary(detections):
    """Build the multi-line text summary for a list of detection dicts."""
    parts = [f"Detected {len(detections)} object(s)\n\n"]
    for det in detections[:MAX_SUMMARY_ITEMS]:
        x0, y0, x1, y1 = det["box"]
        parts.append(f"{det['label']}: {det['score']:.2%}\n")
        parts.append(
            f" Size: {det['width_px']} × {det['height_px']} px | "
            f"{det['width_mm']} × {det['height_mm']} mm\n\n"
        )
        parts.append(
            f" Bounding Box: TL({x0}, {y0}) TR({x1}, {y0}) "
            f"BR({x1}, {y1}) BL({x0}, {y1})\n\n"
        )
    return "".join(parts)


# Inference function with Zero GPU decorator.
# NOTE(review): the original comments announce this decorator but the line
# itself was absent; 60 s is the documented default duration -- adjust if the
# Space needs a different budget.
@spaces.GPU(duration=60)
def detect_objects(image):
    """Detect objects in ``image`` and return an annotated copy plus a summary.

    Args:
        image: The input PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A ``(annotated_image, summary)`` tuple. ``annotated_image`` is a copy
        of the input with one numbered box per detection (``None`` when no
        image was supplied); ``summary`` is human-readable text listing score,
        size in pixels and millimetres, and the box corners for up to
        ``MAX_SUMMARY_ITEMS`` detections.

    If the model's label map contains a "logo" class, only detections of that
    class are kept; otherwise every detection above the threshold is reported.
    """
    if image is None:
        # The button can be clicked before an image is uploaded.
        return None, "No image provided. Please upload an image first."

    device = torch.device("cuda")
    boxes, label_ids, scores = [], [], []
    try:
        # Move model to GPU only during inference.
        model.to(device)

        inputs = processor(images=image, return_tensors="pt")
        inputs = {key: value.to(device) for key, value in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # PIL reports (width, height); post-processing wants (height, width).
        results = processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([image.size[::-1]]),
            threshold=SCORE_THRESHOLD,
        )

        if results:
            box_t = results[0]["boxes"]
            label_t = results[0]["labels"]
            score_t = results[0]["scores"]

            # Filter to keep only logos, when the model knows such a class.
            logo_label_id = _find_logo_label_id()
            if logo_label_id is not None and len(box_t) > 0:
                keep = label_t == logo_label_id
                box_t, label_t, score_t = box_t[keep], label_t[keep], score_t[keep]

            # Copy to plain Python values while the tensors are still valid.
            boxes = box_t.cpu().tolist()
            label_ids = label_t.cpu().tolist()
            scores = score_t.cpu().tolist()
    finally:
        # Always release the GPU, even when inference raised.
        model.to("cpu")
        torch.cuda.empty_cache()

    # Draw bounding boxes on a copy so the caller's image is left untouched.
    image_with_boxes = image.copy()
    draw = ImageDraw.Draw(image_with_boxes)
    font = _load_label_font()

    detection_results = []
    for number, (box, label_id, score_val) in enumerate(
        zip(boxes, label_ids, scores), start=1
    ):
        # Sizes are computed from the unrounded coordinates.
        width_px = box[2] - box[0]
        height_px = box[3] - box[1]
        box = [round(coord, 2) for coord in box]

        # Detections are named generically rather than by class.
        object_name = f"Object {number}"

        draw.rectangle(box, outline=BOX_COLOR, width=4)

        # Label only (no score, no size info) on a filled background.
        text_origin = (box[0], box[1] - 2)
        text_bbox = draw.textbbox(text_origin, object_name, font=font)
        draw.rectangle(
            [text_bbox[0] - 2, text_bbox[1] - 2, text_bbox[2] + 2, text_bbox[3] + 2],
            fill=BOX_COLOR,
        )
        draw.text(text_origin, object_name, fill="white", font=font)

        detection_results.append({
            "label": object_name,
            # Actual class name kept internally in case it is needed.
            "actual_label": model.config.id2label[label_id],
            "score": score_val,
            "box": box,
            "width_px": int(width_px),
            "height_px": int(height_px),
            "width_mm": round(width_px / PIXELS_PER_MM, 2),
            "height_mm": round(height_px / PIXELS_PER_MM, 2),
        })

    return image_with_boxes, _format_summary(detection_results)
# Stylesheet that swaps Gradio's default orange accent for green.
CUSTOM_CSS = """
.green-button {
    background-color: rgb(145, 236, 158) !important;
    border-color: rgb(145, 236, 158) !important;
    color: #333 !important;
}
.green-button:hover {
    background-color: rgb(125, 216, 138) !important;
    border-color: rgb(125, 216, 138) !important;
}
/* Override Gradio's orange with green */
.gr-button-primary {
    background-color: rgb(145, 236, 158) !important;
    border-color: rgb(145, 236, 158) !important;
}
/* Progress bars */
.progress-bar {
    background-color: rgb(145, 236, 158) !important;
}
/* Input focus states */
.gr-input:focus, .gr-textarea:focus {
    border-color: rgb(145, 236, 158) !important;
    outline-color: rgb(145, 236, 158) !important;
}
/* Override orange in various Gradio elements */
.gr-check-radio:checked {
    background-color: rgb(145, 236, 158) !important;
    border-color: rgb(145, 236, 158) !important;
}
/* Links */
a {
    color: rgb(45, 136, 58) !important;
}
/* Loading spinner */
.gr-loading {
    color: rgb(145, 236, 158) !important;
}
/* Slider handles and tracks */
.gr-slider input[type="range"]::-webkit-slider-thumb {
    background-color: rgb(145, 236, 158) !important;
}
.gr-slider input[type="range"]::-moz-range-thumb {
    background-color: rgb(145, 236, 158) !important;
}
/* Any element using Gradio's primary color */
[style*="rgb(249, 115, 22)"] {
    color: rgb(145, 236, 158) !important;
}
[style*="background-color: rgb(249, 115, 22)"] {
    background-color: rgb(145, 236, 158) !important;
}
"""

# Introductory text shown above the controls. The feature list describes what
# the app actually renders: boxes carry only a generic name, while pixel and
# millimetre sizes appear in the text summary.
INTRO_MARKDOWN = """
# Logo Detection with Size Measurements

Upload an image to detect logos.
This Space uses Zero GPU for efficient inference.

**Features:**
- Logo detection only
- Sizes reported in pixels and in millimeters in the detection summary (converted using 11.91 pixels/mm)
- Objects are labeled generically as "Object 1", "Object 2", etc.
"""

# Create Gradio interface
with gr.Blocks(title="Logo Detection", css=CUSTOM_CSS) as demo:
    gr.Markdown(INTRO_MARKDOWN)

    with gr.Row():
        # Left column: image upload and the trigger button.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            detect_btn = gr.Button(
                "Detect Objects",
                variant="primary",
                elem_classes="green-button",
            )
        # Right column: annotated image and the text summary.
        with gr.Column():
            output_image = gr.Image(label="Detection Results")
            output_text = gr.Textbox(label="Detection Summary", lines=12)

    # Run detection on click; the function returns (image, summary).
    detect_btn.click(
        fn=detect_objects,
        inputs=input_image,
        outputs=[output_image, output_text],
    )

    # Add examples (comment out if you don't have example images)
    # gr.Examples(
    #     examples=[
    #         ["example1.jpg"],
    #         ["example2.jpg"],
    #     ],
    #     inputs=input_image,
    #     outputs=[output_image, output_text],
    #     fn=detect_objects,
    #     cache_examples=False  # Don't cache for Zero GPU
    # )
if __name__ == "__main__":
    # Serve only when run as a script: listen on all interfaces, port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)