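"""Gradio demo for D-FINE object detection on images and videos.

Runs the Hugging Face transformers object-detection pipeline on an uploaded
image, an image URL, or the frames of a short video, and renders the
detections with bounding boxes and confidence scores.
"""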
import logging
import os
from typing import Tuple, List, Optional
from pathlib import Path
import shutil
import tempfile

import numpy as np
import cv2
import gradio as gr
from PIL import Image
from transformers import pipeline
from transformers.image_utils import load_image
import tqdm
# Configuration constants
CHECKPOINTS = [
    "ustc-community/dfine_m_obj365",
    "ustc-community/dfine_n_coco",
    "ustc-community/dfine_s_coco",
    "ustc-community/dfine_m_coco",
    "ustc-community/dfine_l_coco",
    "ustc-community/dfine_x_coco",
    "ustc-community/dfine_s_obj365",
    "ustc-community/dfine_l_obj365",
    "ustc-community/dfine_x_obj365",
    "ustc-community/dfine_s_obj2coco",
    "ustc-community/dfine_m_obj2coco",
    "ustc-community/dfine_l_obj2coco_e25",
    "ustc-community/dfine_x_obj2coco",
]
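# The checkpoint suffixes appear to encode model size (n/s/m/l/x) and training
# data (coco, obj365, or obj2coco for Objects365-to-COCO transfer); this is
# inferred from the names rather than documented here.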
MAX_NUM_FRAMES = 300
DEFAULT_CHECKPOINT = CHECKPOINTS[0]
DEFAULT_CONFIDENCE_THRESHOLD = 0.3

IMAGE_EXAMPLES = [
    {"path": "./image.jpg", "use_url": False, "url": "", "label": "Local Image"},
    {
        "path": None,
        "use_url": True,
        "url": "https://live.staticflickr.com/65535/33021460783_1646d43c54_b.jpg",
        "label": "Flickr Image",
    },
]
VIDEO_EXAMPLES = [
    {"path": "./video.mp4", "label": "Local Video"},
]
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov"}

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

VIDEO_OUTPUT_DIR = Path("static/videos")
VIDEO_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
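
# Loading a transformers pipeline is expensive, and process_video() below
# calls detect_objects() once per frame, so constructing the pipeline on
# every call would re-load the model repeatedly. This minimal per-checkpoint
# cache is an addition to the original code, not part of the upstream API.
_PIPELINE_CACHE = {}


def get_pipeline(checkpoint: str):
    """Return a cached object-detection pipeline for the given checkpoint."""
    if checkpoint not in _PIPELINE_CACHE:
        _PIPELINE_CACHE[checkpoint] = pipeline(
            "object-detection",
            model=checkpoint,
            image_processor=checkpoint,
            device="cpu",
        )
    return _PIPELINE_CACHE[checkpoint]
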
def detect_objects(
    image: Optional[Image.Image],
    checkpoint: str,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    use_url: bool = False,
    url: str = "",
) -> Tuple[
    Optional[Tuple[Image.Image, List[Tuple[Tuple[int, int, int, int], str]]]],
    gr.Markdown,
]:
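    """Detect objects in a PIL image or an image URL.

    Returns an (image, annotations) pair suitable for gr.AnnotatedImage,
    where each annotation is ((xmin, ymin, xmax, ymax), label), plus a
    status Markdown. On failure, the first element is None.
    """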
    if use_url and url:
        try:
            input_image = load_image(url)
        except Exception as e:
            logger.error(f"Failed to load image from URL {url}: {str(e)}")
            return None, gr.Markdown(
                f"**Error**: Failed to load image from URL: {str(e)}", visible=True
            )
    elif image is not None:
        if not isinstance(image, Image.Image):
            logger.error("Input image is not a PIL Image")
            return None, gr.Markdown("**Error**: Invalid image format.", visible=True)
        input_image = image
    else:
        return None, gr.Markdown(
            "**Error**: Please provide an image or URL.", visible=True
        )
    try:
        pipe = get_pipeline(checkpoint)
    except Exception as e:
        logger.error(f"Failed to initialize model pipeline for {checkpoint}: {str(e)}")
        return None, gr.Markdown(
            f"**Error**: Failed to load model: {str(e)}", visible=True
        )
    results = pipe(input_image, threshold=confidence_threshold)
    img_width, img_height = input_image.size
    annotations = []
    for result in results:
        score = result["score"]
        if score < confidence_threshold:
            continue
        label = f"{result['label']} ({score:.2f})"
        box = result["box"]
        # Validate and convert box to (xmin, ymin, xmax, ymax)
        bbox_xmin = max(0, int(box["xmin"]))
        bbox_ymin = max(0, int(box["ymin"]))
        bbox_xmax = min(img_width, int(box["xmax"]))
        bbox_ymax = min(img_height, int(box["ymax"]))
        if bbox_xmax <= bbox_xmin or bbox_ymax <= bbox_ymin:
            continue
        bounding_box = (bbox_xmin, bbox_ymin, bbox_xmax, bbox_ymax)
        annotations.append((bounding_box, label))
    if not annotations:
        return (input_image, []), gr.Markdown(
            "**Warning**: No objects detected above the confidence threshold. Try lowering the threshold.",
            visible=True,
        )
    return (input_image, annotations), gr.Markdown(visible=False)
def annotate_frame(
    image: Image.Image, annotations: List[Tuple[Tuple[int, int, int, int], str]]
) -> np.ndarray:
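    """Draw bounding boxes and labels on an image; return a BGR array for OpenCV."""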
    # Convert to a 3-channel BGR array (OpenCV's drawing convention); the
    # explicit RGB conversion also guards against RGBA or grayscale inputs.
    image_np = np.array(image.convert("RGB"))
    image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    for (xmin, ymin, xmax, ymax), label in annotations:
        cv2.rectangle(image_bgr, (xmin, ymin), (xmax, ymax), (255, 255, 255), 2)
        text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
        # Shift the label down when the box touches the top edge so the
        # text background stays inside the frame.
        label_ymin = max(ymin, text_size[1] + 4)
        cv2.rectangle(
            image_bgr,
            (xmin, label_ymin - text_size[1] - 4),
            (xmin + text_size[0], label_ymin),
            (255, 255, 255),
            -1,
        )
        cv2.putText(
            image_bgr,
            label,
            (xmin, label_ymin - 4),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 0, 0),
            1,
        )
    return image_bgr
def process_video(
    video_path: str,
    checkpoint: str,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> Tuple[Optional[str], gr.Markdown]:
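    """Run detection on up to MAX_NUM_FRAMES frames and write an annotated MP4.

    Returns the output video path (or None on failure) and a status Markdown.
    """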
    if not video_path or not os.path.isfile(video_path):
        logger.error(f"Invalid video path: {video_path}")
        return None, gr.Markdown(
            "**Error**: Please provide a valid video file.", visible=True
        )
    ext = os.path.splitext(video_path)[1].lower()
    if ext not in ALLOWED_VIDEO_EXTENSIONS:
        logger.error(f"Unsupported video format: {ext}")
        return None, gr.Markdown(
            "**Error**: Unsupported video format. Use MP4, AVI, or MOV.", visible=True
        )
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error(f"Failed to open video: {video_path}")
            return None, gr.Markdown(
                "**Error**: Failed to open video file.", visible=True
            )
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            # Some containers report no FPS; fall back so the writer can open.
            fps = 25.0
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # mp4v is widely available; H.264 would be more browser-friendly but
        # requires an OpenCV build with that codec.
        # fourcc = cv2.VideoWriter_fourcc(*"H264")
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        writer = cv2.VideoWriter(temp_file.name, fourcc, fps, (width, height))
        if not writer.isOpened():
            logger.error("Failed to initialize video writer")
            cap.release()
            temp_file.close()
            os.unlink(temp_file.name)
            return None, gr.Markdown(
                "**Error**: Failed to initialize video writer.", visible=True
            )
        frame_count = 0
        for _ in tqdm.tqdm(
            range(min(MAX_NUM_FRAMES, num_frames)), desc="Processing video"
        ):
            ok, frame = cap.read()
            if not ok:
                break
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(rgb_frame)
            # detect_objects returns None (not a tuple) on failure, so check
            # the result before unpacking it.
            result, _ = detect_objects(
                pil_image, checkpoint, confidence_threshold, use_url=False, url=""
            )
            if result is None:
                continue
            annotated_image, annotations = result
            annotated_frame = annotate_frame(annotated_image, annotations)
            writer.write(annotated_frame)
            frame_count += 1
        writer.release()
        cap.release()
        if frame_count == 0:
            logger.warning("No valid frames processed in video")
            temp_file.close()
            os.unlink(temp_file.name)
            return None, gr.Markdown(
                "**Warning**: No valid frames processed. Try a different video or threshold.",
                visible=True,
            )
        temp_file.close()
        # Copy to a persistent directory so Gradio can serve the file
        output_filename = f"output_{os.path.basename(temp_file.name)}"
        output_path = VIDEO_OUTPUT_DIR / output_filename
        shutil.copy(temp_file.name, output_path)
        os.unlink(temp_file.name)  # Remove the temporary file
        logger.info(f"Video saved to {output_path}")
        return str(output_path), gr.Markdown(visible=False)
    except Exception as e:
        logger.error(f"Video processing failed: {str(e)}")
        if "temp_file" in locals():
            temp_file.close()
            if os.path.exists(temp_file.name):
                os.unlink(temp_file.name)
        return None, gr.Markdown(
            f"**Error**: Video processing failed: {str(e)}", visible=True
        )
def create_image_inputs() -> List[gr.components.Component]:
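    """Build the image-tab inputs: image, URL toggle, URL box, model, threshold."""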
    return [
        gr.Image(
            label="Upload Image",
            type="pil",
            sources=["upload", "webcam"],
            interactive=True,
            elem_classes="input-component",
        ),
        gr.Checkbox(label="Use Image URL Instead", value=False),
        gr.Textbox(
            label="Image URL",
            placeholder="https://example.com/image.jpg",
            visible=False,
            elem_classes="input-component",
        ),
        gr.Dropdown(
            choices=CHECKPOINTS,
            label="Select Model Checkpoint",
            value=DEFAULT_CHECKPOINT,
            elem_classes="input-component",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=DEFAULT_CONFIDENCE_THRESHOLD,
            step=0.1,
            label="Confidence Threshold",
            elem_classes="input-component",
        ),
    ]
def create_video_inputs() -> List[gr.components.Component]:
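    """Build the video-tab inputs: video upload, model dropdown, threshold slider."""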
    return [
        gr.Video(
            label="Upload Video",
            sources=["upload"],
            interactive=True,
            format="mp4",  # Ensure MP4 format
            elem_classes="input-component",
        ),
        gr.Dropdown(
            choices=CHECKPOINTS,
            label="Select Model Checkpoint",
            value=DEFAULT_CHECKPOINT,
            elem_classes="input-component",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=DEFAULT_CONFIDENCE_THRESHOLD,
            step=0.1,
            label="Confidence Threshold",
            elem_classes="input-component",
        ),
    ]
def create_button_row(is_image: bool) -> List[gr.Button]:
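    """Build a Detect/Clear button pair labeled for the image or video tab."""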
    prefix = "Image" if is_image else "Video"
    return [
        gr.Button(
            f"{prefix} Detect Objects", variant="primary", elem_classes="action-button"
        ),
        gr.Button(f"{prefix} Clear", variant="secondary", elem_classes="action-button"),
    ]

# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Real-Time Object Detection Demo
        Experience state-of-the-art object detection with USTC's D-FINE models. Upload an image or video,
        provide a URL, or try an example below. Select a model and adjust the confidence threshold to see detections in real time!
        """,
        elem_classes="header-text",
    )
    with gr.Tabs():
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    with gr.Group():
                        (
                            image_input,
                            use_url,
                            url_input,
                            image_checkpoint,
                            image_confidence_threshold,
                        ) = create_image_inputs()
                        image_detect_button, image_clear_button = create_button_row(
                            is_image=True
                        )
                with gr.Column(scale=2):
                    image_output = gr.AnnotatedImage(
                        label="Detection Results",
                        show_label=True,
                        color_map=None,
                        elem_classes="output-component",
                    )
                    image_error_message = gr.Markdown(
                        visible=False, elem_classes="error-text"
                    )
            # Keep the example/input order aligned with detect_objects's
            # signature so example rows map to the right arguments.
            gr.Examples(
                examples=[
                    [
                        example["path"],
                        DEFAULT_CHECKPOINT,
                        DEFAULT_CONFIDENCE_THRESHOLD,
                        example["use_url"],
                        example["url"],
                    ]
                    for example in IMAGE_EXAMPLES
                ],
                inputs=[
                    image_input,
                    image_checkpoint,
                    image_confidence_threshold,
                    use_url,
                    url_input,
                ],
                outputs=[image_output, image_error_message],
                fn=detect_objects,
                cache_examples=False,
                label="Select an image example to populate inputs",
            )
        with gr.Tab("Video"):
            gr.Markdown(
                f"The input video will be truncated to {MAX_NUM_FRAMES} frames."
            )
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    with gr.Group():
                        video_input, video_checkpoint, video_confidence_threshold = (
                            create_video_inputs()
                        )
                        video_detect_button, video_clear_button = create_button_row(
                            is_image=False
                        )
                with gr.Column(scale=2):
                    video_output = gr.Video(
                        label="Detection Results",
                        format="mp4",  # Explicit MP4 format
                        elem_classes="output-component",
                    )
                    video_error_message = gr.Markdown(
                        visible=False, elem_classes="error-text"
                    )
            gr.Examples(
                examples=[
                    [example["path"], DEFAULT_CHECKPOINT, DEFAULT_CONFIDENCE_THRESHOLD]
                    for example in VIDEO_EXAMPLES
                ],
                inputs=[video_input, video_checkpoint, video_confidence_threshold],
                outputs=[video_output, video_error_message],
                fn=process_video,
                cache_examples=False,
                label="Select a video example to populate inputs",
            )

    # Dynamic visibility for URL input
    use_url.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_url,
        outputs=url_input,
    )

    # Image clear button
    image_clear_button.click(
        fn=lambda: (
            None,
            False,
            "",
            DEFAULT_CHECKPOINT,
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
            gr.Markdown(visible=False),
        ),
        outputs=[
            image_input,
            use_url,
            url_input,
            image_checkpoint,
            image_confidence_threshold,
            image_output,
            image_error_message,
        ],
    )

    # Video clear button
    video_clear_button.click(
        fn=lambda: (
            None,
            DEFAULT_CHECKPOINT,
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
            gr.Markdown(visible=False),
        ),
        outputs=[
            video_input,
            video_checkpoint,
            video_confidence_threshold,
            video_output,
            video_error_message,
        ],
    )

    # Image detect button
    image_detect_button.click(
        fn=detect_objects,
        inputs=[
            image_input,
            image_checkpoint,
            image_confidence_threshold,
            use_url,
            url_input,
        ],
        outputs=[image_output, image_error_message],
    )

    # Video detect button
    video_detect_button.click(
        fn=process_video,
        inputs=[video_input, video_checkpoint, video_confidence_threshold],
        outputs=[video_output, video_error_message],
    )


if __name__ == "__main__":
    demo.queue(max_size=20).launch()