# app.py — Hugging Face Space by vbt2025
# Last commit: a6a7e77 "Update app.py" (verified)
import gradio as gr
import spaces
from transformers import AutoImageProcessor, DFineForObjectDetection
from PIL import Image, ImageDraw, ImageFont
import torch
# Hugging Face Hub checkpoint that supplies both the image processor and the detector.
MODEL_ID = "ustc-community/dfine-medium-obj2coco"

# Load the processor and model once, at import time.
# For Zero GPU the weights deliberately stay on the CPU here; the inference
# function moves the model onto the accelerator only while it is running.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = DFineForObjectDetection.from_pretrained(MODEL_ID).to("cpu")
# Pixels per millimetre used to convert detection sizes to mm.
# NOTE(review): 11.91 px/mm is about 302.5 DPI -- confirm this matches the
# resolution of the images users upload; mm sizes are only meaningful then.
PIXELS_PER_MM = 11.91
# Detections with a confidence score below this are discarded.
SCORE_THRESHOLD = 0.3
# Class name (compared case-insensitively) kept by the post-detection filter.
TARGET_LABEL = "logo"
# RGB colour for bounding boxes and label backgrounds.
BOX_COLOR = (45, 136, 58)
# Font size for the on-image labels.
LABEL_FONT_SIZE = 24
# Maximum number of detections listed in the text summary.
MAX_SUMMARY_ITEMS = 10
# TrueType fonts tried in order before falling back to Pillow's default font.
_FONT_CANDIDATES = (
    "DejaVuSans.ttf",
    "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf",
)


def _load_font(size=LABEL_FONT_SIZE):
    """Return the first loadable font from _FONT_CANDIDATES, else Pillow's default font."""
    for font_path in _FONT_CANDIDATES:
        try:
            return ImageFont.truetype(font_path, size)
        except (OSError, ImportError):
            # OSError: font file missing/unreadable. ImportError: presumably a
            # Pillow build without FreeType support. Either way, try the next one.
            continue
    return ImageFont.load_default()


def _find_label_id(label_name):
    """Return the model class id whose name equals label_name (case-insensitive), or None."""
    wanted = label_name.lower()
    for label_id, name in model.config.id2label.items():
        if name.lower() == wanted:
            return label_id
    return None


def _keep_only_label(result, label_id):
    """Restrict a post-processed result dict, in place, to detections of label_id.

    The dict is left untouched when label_id is None (the model has no such
    class) or when it contains no detections.
    """
    if label_id is None or len(result["boxes"]) == 0:
        return
    mask = result["labels"] == label_id
    for key in ("boxes", "labels", "scores"):
        result[key] = result[key][mask]


def _run_inference(image):
    """Run the detector on image and return post-processed results as CPU tensors.

    The model is moved to CUDA (or kept on CPU when CUDA is unavailable) only for
    the duration of this call, and is always moved back to CPU, even on error.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    try:
        inputs = processor(images=image, return_tensors="pt")
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        # PIL's image.size is (width, height); reversed to (height, width) for target_sizes.
        results = processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([image.size[::-1]]),
            threshold=SCORE_THRESHOLD,
        )
        # Copy results off the accelerator before it is released.
        return [
            {key: value.cpu() if torch.is_tensor(value) else value for key, value in result.items()}
            for result in results
        ]
    finally:
        # Zero GPU: return the weights to CPU and clear the CUDA cache.
        model.to("cpu")
        torch.cuda.empty_cache()


def _format_summary(detections, max_items=MAX_SUMMARY_ITEMS):
    """Build the text summary: total count, then score, size and box corners per detection."""
    parts = [f"Detected {len(detections)} object(s)\n\n"]
    # Only the first max_items detections (in post-processing order) are listed.
    for det in detections[:max_items]:
        x0, y0, x1, y1 = det["box"]
        parts.append(f"{det['label']}: {det['score']:.2%}\n")
        parts.append(
            f" Size: {det['width_px']} × {det['height_px']} px | "
            f"{det['width_mm']} × {det['height_mm']} mm\n\n"
        )
        parts.append(
            f" Bounding Box: TL({x0}, {y0}) TR({x1}, {y0}) "
            f"BR({x1}, {y1}) BL({x0}, {y1})\n\n"
        )
    return "".join(parts)


@spaces.GPU(duration=15)  # Zero GPU: request a GPU for up to 15 s per call
def detect_objects(image):
    """Detect objects in image, draw labelled boxes on a copy, and summarise their sizes.

    Args:
        image: PIL image from the Gradio input, or None if nothing was uploaded.

    Returns:
        (annotated_image, summary): a copy of image with one box and an
        "Object N" label per kept detection, and a text summary listing each
        detection's score, size in px and mm, and corner coordinates.
        When image is None, returns (None, a prompt to upload an image).
    """
    if image is None:
        return None, "Please upload an image first."

    results = _run_inference(image)
    result = results[0] if results else None
    if result is not None:
        # NOTE(review): the checkpoint name suggests COCO classes, which may not
        # include "logo"; if the label is absent the filter is skipped and every
        # detection is kept -- confirm against model.config.id2label.
        _keep_only_label(result, _find_label_id(TARGET_LABEL))

    image_with_boxes = image.copy()
    draw = ImageDraw.Draw(image_with_boxes)
    font = _load_font()

    detection_results = []
    if result is not None:
        detections = zip(result["boxes"], result["labels"], result["scores"])
        for index, (box, label, score) in enumerate(detections, start=1):
            x0, y0, x1, y1 = box.tolist()
            # Sizes come from the unrounded coordinates.
            width_px = x1 - x0
            height_px = y1 - y0
            rounded_box = [round(coord, 2) for coord in (x0, y0, x1, y1)]
            object_name = f"Object {index}"

            draw.rectangle(rounded_box, outline=BOX_COLOR, width=4)
            # Label only (no score, no size) on a filled background at the top-left corner.
            text_origin = (rounded_box[0], rounded_box[1] - 2)
            left, top, right, bottom = draw.textbbox(text_origin, object_name, font=font)
            draw.rectangle([left - 2, top - 2, right + 2, bottom + 2], fill=BOX_COLOR)
            draw.text(text_origin, object_name, fill="white", font=font)

            label_id = int(label.item())
            detection_results.append({
                "label": object_name,
                # Real class name kept for reference; the UI only shows the generic name.
                "actual_label": model.config.id2label[label_id],
                "score": score.item(),
                "box": rounded_box,
                "width_px": int(width_px),
                "height_px": int(height_px),
                "width_mm": round(width_px / PIXELS_PER_MM, 2),
                "height_mm": round(height_px / PIXELS_PER_MM, 2),
            })

    return image_with_boxes, _format_summary(detection_results)
# Custom CSS: replaces Gradio's default orange accent with a green palette.
CUSTOM_CSS = """
.green-button {
background-color: rgb(145, 236, 158) !important;
border-color: rgb(145, 236, 158) !important;
color: #333 !important;
}
.green-button:hover {
background-color: rgb(125, 216, 138) !important;
border-color: rgb(125, 216, 138) !important;
}
/* Override Gradio's orange with green */
.gr-button-primary {
background-color: rgb(145, 236, 158) !important;
border-color: rgb(145, 236, 158) !important;
}
/* Progress bars */
.progress-bar {
background-color: rgb(145, 236, 158) !important;
}
/* Input focus states */
.gr-input:focus, .gr-textarea:focus {
border-color: rgb(145, 236, 158) !important;
outline-color: rgb(145, 236, 158) !important;
}
/* Override orange in various Gradio elements */
.gr-check-radio:checked {
background-color: rgb(145, 236, 158) !important;
border-color: rgb(145, 236, 158) !important;
}
/* Links */
a {
color: rgb(45, 136, 58) !important;
}
/* Loading spinner */
.gr-loading {
color: rgb(145, 236, 158) !important;
}
/* Slider handles and tracks */
.gr-slider input[type="range"]::-webkit-slider-thumb {
background-color: rgb(145, 236, 158) !important;
}
.gr-slider input[type="range"]::-moz-range-thumb {
background-color: rgb(145, 236, 158) !important;
}
/* Any element using Gradio's primary color */
[style*="rgb(249, 115, 22)"] {
color: rgb(145, 236, 158) !important;
}
[style*="background-color: rgb(249, 115, 22)"] {
background-color: rgb(145, 236, 158) !important;
}
"""

# Create Gradio interface: image input + button on the left,
# annotated image + text summary on the right.
with gr.Blocks(title="Logo Detection", css=CUSTOM_CSS) as demo:
    # The feature list describes what the app actually shows: sizes are reported
    # in the Detection Summary text, not drawn on the image.
    gr.Markdown("""
# Logo Detection with Size Measurements
Upload an image to detect logos.
This Space uses Zero GPU for efficient inference.
**Features:**
- Logo detection only
- Size of each detection in pixels, listed in the Detection Summary
- Size of each detection in millimeters (converted using 11.91 pixels/mm), listed in the Detection Summary
- Objects are labeled generically as "Object 1", "Object 2", etc.
""")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            detect_btn = gr.Button("Detect Objects", variant="primary", elem_classes="green-button")
        with gr.Column():
            output_image = gr.Image(label="Detection Results")
            output_text = gr.Textbox(label="Detection Summary", lines=12)

    # Run detection when the button is clicked.
    detect_btn.click(
        fn=detect_objects,
        inputs=input_image,
        outputs=[output_image, output_text],
    )

    # Add examples (uncomment once example images exist in the Space).
    # gr.Examples(
    #     examples=[
    #         ["example1.jpg"],
    #         ["example2.jpg"],
    #     ],
    #     inputs=input_image,
    #     outputs=[output_image, output_text],
    #     fn=detect_objects,
    #     cache_examples=False,  # Don't cache for Zero GPU
    # )
def _main():
    """Start the Gradio app, listening on all network interfaces (0.0.0.0) at port 7860."""
    demo.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    _main()