"""Gradio Space: D-FINE object detection with pixel and millimeter size readouts.

The model is loaded on CPU at import time and moved to the GPU only inside the
``@spaces.GPU``-decorated handler, which is the pattern used for Zero GPU Spaces.
"""

import gradio as gr
import spaces
from transformers import AutoImageProcessor, DFineForObjectDetection
from PIL import Image, ImageDraw, ImageFont
import torch

MODEL_ID = "ustc-community/dfine-medium-obj2coco"

# Minimum confidence for a detection to survive post-processing.
SCORE_THRESHOLD = 0.3
# Fixed scale used to convert pixel sizes to millimeters.
PIXELS_PER_MM = 11.91
# Only this many detections are itemised in the text summary.
MAX_SUMMARY_ITEMS = 10
# Box outline / label background colour (RGB) and drawing sizes.
BOX_COLOR = (45, 136, 58)
BOX_WIDTH = 4
FONT_SIZE = 24
# Tried in order; Pillow's built-in bitmap font is the final fallback.
FONT_CANDIDATES = (
    "DejaVuSans.ttf",
    "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf",
)

# Load model and processor (keep on CPU initially for Zero GPU)
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = DFineForObjectDetection.from_pretrained(MODEL_ID)
# IMPORTANT: For Zero GPU, keep model on CPU initially
model = model.to("cpu")

# Id of the label named "logo" (case-insensitive), or None if the model has none.
# NOTE(review): the checkpoint is named "...obj2coco"; if its id2label is the COCO
# class list there is no "logo" class, the filter below is a no-op and every
# detected object is reported. Confirm against the model config.
LOGO_LABEL_ID = next(
    (
        label_id
        for label_id, label_name in model.config.id2label.items()
        if label_name.lower() == "logo"
    ),
    None,
)


def _run_inference(image):
    """Run the detector on *image* and return its detections as CPU tensors.

    Returns a dict with "boxes" (xyxy, in image pixels), "labels" and "scores",
    or None if post-processing produced no result entry. The model is moved to
    the GPU (when available) only for this call and is always moved back to CPU,
    even if inference raises.
    """
    # Fall back to CPU so the app also runs locally without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    try:
        inputs = processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        # PIL's size is (width, height); target_sizes expects (height, width).
        results = processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([image.size[::-1]]),
            threshold=SCORE_THRESHOLD,
        )
        if not results:
            return None
        first = results[0]
        # One bulk transfer instead of per-element .cpu() calls while drawing.
        return {key: first[key].cpu() for key in ("boxes", "labels", "scores")}
    finally:
        # Move model back to CPU after inference (important for Zero GPU)
        model.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear GPU cache


def _keep_only_logos(detections):
    """Restrict *detections* (mutated in place) to the "logo" class.

    Leaves the detections untouched when the model has no "logo" label or
    nothing was detected.
    """
    if LOGO_LABEL_ID is None or len(detections["boxes"]) == 0:
        return
    logo_mask = detections["labels"] == LOGO_LABEL_ID
    for key in ("boxes", "labels", "scores"):
        detections[key] = detections[key][logo_mask]


def _load_label_font():
    """Return the first loadable TrueType font, else Pillow's default font."""
    for font_path in FONT_CANDIDATES:
        try:
            return ImageFont.truetype(font_path, FONT_SIZE)
        except (OSError, ImportError):
            # OSError: font file missing or unreadable.
            # ImportError: Pillow built without FreeType support.
            continue
    return ImageFont.load_default()


def _annotate(image, detections):
    """Draw numbered boxes on a copy of *image* and collect per-box measurements.

    Returns ``(annotated_image, records)``. Each record contains the generic
    label, the model's actual label name, the score, the rounded xyxy box and
    the box width/height in pixels and millimeters.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    font = _load_label_font()
    records = []
    if detections is None:
        return annotated, records

    rows = zip(
        detections["boxes"].tolist(),
        detections["labels"].tolist(),
        detections["scores"].tolist(),
    )
    for index, (box, label_id, score) in enumerate(rows, start=1):
        width_px = box[2] - box[0]
        height_px = box[3] - box[1]
        box = [round(coord, 2) for coord in box]
        # Objects are shown generically; the real class name is kept internally.
        label_text = f"Object {index}"

        draw.rectangle(box, outline=BOX_COLOR, width=BOX_WIDTH)
        # Label only (no score, no size info), on a filled background.
        text_origin = (box[0], box[1] - 2)
        left, top, right, bottom = draw.textbbox(text_origin, label_text, font=font)
        draw.rectangle([left - 2, top - 2, right + 2, bottom + 2], fill=BOX_COLOR)
        draw.text(text_origin, label_text, fill="white", font=font)

        records.append({
            "label": label_text,
            "actual_label": model.config.id2label[label_id],
            "score": score,
            "box": box,
            "width_px": int(width_px),
            "height_px": int(height_px),
            "width_mm": round(width_px / PIXELS_PER_MM, 2),
            "height_mm": round(height_px / PIXELS_PER_MM, 2),
        })
    return annotated, records


def _format_summary(records):
    """Build the text summary: total count plus details for the first few boxes."""
    parts = [f"Detected {len(records)} object(s)\n\n"]
    for det in records[:MAX_SUMMARY_ITEMS]:
        x0, y0, x1, y1 = det["box"]
        parts.append(f"{det['label']}: {det['score']:.2%}\n")
        # Single newline keeps the Bounding Box line attached to its object.
        parts.append(
            f" Size: {det['width_px']} × {det['height_px']} px | "
            f"{det['width_mm']} × {det['height_mm']} mm\n"
        )
        parts.append(
            f" Bounding Box: TL({x0}, {y0}) TR({x1}, {y0}) "
            f"BR({x1}, {y1}) BL({x0}, {y1})\n\n"
        )
    return "".join(parts)


# Inference function with Zero GPU decorator
@spaces.GPU(duration=15)  # Specify duration for Zero GPU
def detect_objects(image):
    """Detect objects in a PIL image and return ``(annotated_image, summary_text)``.

    When *image* is None (the button was clicked without an upload), returns
    ``(None, message)`` instead of failing inside the processor.
    """
    if image is None:
        return None, "Please upload an image first."
    detections = _run_inference(image)
    if detections is not None:
        _keep_only_logos(detections)
    image_with_boxes, detection_results = _annotate(image, detections)
    return image_with_boxes, _format_summary(detection_results)


CUSTOM_CSS = """
.green-button {
    background-color: rgb(145, 236, 158) !important;
    border-color: rgb(145, 236, 158) !important;
    color: #333 !important;
}
.green-button:hover {
    background-color: rgb(125, 216, 138) !important;
    border-color: rgb(125, 216, 138) !important;
}
/* Override Gradio's orange with green */
.gr-button-primary {
    background-color: rgb(145, 236, 158) !important;
    border-color: rgb(145, 236, 158) !important;
}
/* Progress bars */
.progress-bar {
    background-color: rgb(145, 236, 158) !important;
}
/* Input focus states */
.gr-input:focus, .gr-textarea:focus {
    border-color: rgb(145, 236, 158) !important;
    outline-color: rgb(145, 236, 158) !important;
}
/* Override orange in various Gradio elements */
.gr-check-radio:checked {
    background-color: rgb(145, 236, 158) !important;
    border-color: rgb(145, 236, 158) !important;
}
/* Links */
a {
    color: rgb(45, 136, 58) !important;
}
/* Loading spinner */
.gr-loading {
    color: rgb(145, 236, 158) !important;
}
/* Slider handles and tracks */
.gr-slider input[type="range"]::-webkit-slider-thumb {
    background-color: rgb(145, 236, 158) !important;
}
.gr-slider input[type="range"]::-moz-range-thumb {
    background-color: rgb(145, 236, 158) !important;
}
/* Any element using Gradio's primary color */
[style*="rgb(249, 115, 22)"] {
    color: rgb(145, 236, 158) !important;
}
[style*="background-color: rgb(249, 115, 22)"] {
    background-color: rgb(145, 236, 158) !important;
}
"""

DESCRIPTION_MD = f"""
# Logo Detection with Size Measurements

Upload an image to detect logos. This Space uses Zero GPU for efficient inference.

**Features:**
- Logo detection only
- Size of each object in pixels and in millimeters (converted using {PIXELS_PER_MM} pixels/mm), listed in the Detection Summary
- Objects are labeled generically as "Object 1", "Object 2", etc.
"""

# Create Gradio interface
with gr.Blocks(title="Logo Detection", css=CUSTOM_CSS) as demo:
    gr.Markdown(DESCRIPTION_MD)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            detect_btn = gr.Button(
                "Detect Objects", variant="primary", elem_classes="green-button"
            )
        with gr.Column():
            output_image = gr.Image(label="Detection Results")
            output_text = gr.Textbox(label="Detection Summary", lines=12)

    # Set up event handler
    detect_btn.click(
        fn=detect_objects,
        inputs=input_image,
        outputs=[output_image, output_text],
    )

    # Add examples (uncomment if you have example images)
    # gr.Examples(
    #     examples=[
    #         ["example1.jpg"],
    #         ["example2.jpg"],
    #     ],
    #     inputs=input_image,
    #     outputs=[output_image, output_text],
    #     fn=detect_objects,
    #     cache_examples=False,  # Don't cache for Zero GPU
    # )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)