import gradio as gr
import torch
from transformers import AutoImageProcessor, TimesformerForVideoClassification
import cv2
from PIL import Image
import numpy as np
import time
from collections import deque
import base64
import io

HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
MODEL_INPUT_NUM_FRAMES = 8
TARGET_IMAGE_HEIGHT = 224
TARGET_IMAGE_WIDTH = 224
RAW_RECORDING_DURATION_SECONDS = 10.0
FRAMES_TO_SAMPLE_PER_CLIP = 20
DELAY_BETWEEN_PREDICTIONS_SECONDS = 120.0
| print(f"Loading model and processor from {HF_MODEL_REPO_ID}...") | |
| try: | |
| processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID) | |
| model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID) | |
| except Exception as e: | |
| print(f"Error loading model: {e}") | |
| exit() | |
| model.eval() | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model.to(device) | |
| print(f"Model loaded on {device}.") | |
| raw_frames_buffer = deque() | |
| current_clip_start_time = time.time() | |
| last_prediction_completion_time = time.time() | |
| app_state = "recording" | |


def sample_frames(frames_list, target_count):
    """Uniformly sample up to target_count frames from frames_list."""
    if not frames_list:
        return []
    if len(frames_list) <= target_count:
        return frames_list
    indices = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
    return [frames_list[int(i)] for i in indices]


def live_predict_stream(image_np_array):
    global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state

    current_time = time.time()
    pil_image = Image.fromarray(image_np_array)
    status_message = ""
    prediction_result = ""

    if app_state == "recording":
        raw_frames_buffer.append(pil_image)
        elapsed_recording_time = current_time - current_clip_start_time
        status_message = (
            f"Recording: {elapsed_recording_time:.1f}/{RAW_RECORDING_DURATION_SECONDS}s. "
            f"Raw frames: {len(raw_frames_buffer)}"
        )
        prediction_result = "Buffering..."
        if elapsed_recording_time >= RAW_RECORDING_DURATION_SECONDS:
            app_state = "predicting"
            status_message = "Preparing to predict..."
            prediction_result = "Processing..."
            print("DEBUG: Transitioning to 'predicting' state.")
    elif app_state == "predicting":
        if raw_frames_buffer:
            print("DEBUG: Starting prediction.")
            try:
                # Downsample the raw clip, then sample exactly the number of
                # frames the model expects.
                sampled_raw_frames = sample_frames(list(raw_frames_buffer), FRAMES_TO_SAMPLE_PER_CLIP)
                frames_for_model = sample_frames(sampled_raw_frames, MODEL_INPUT_NUM_FRAMES)
                if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
                    prediction_result = "Error: Not enough frames for model."
                    status_message = "Error during frame sampling."
                    print(f"ERROR: Insufficient frames for model input: "
                          f"{len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}")
                    # Reset to the recording state and start a fresh cycle.
                    app_state = "recording"
                    raw_frames_buffer.clear()
                    current_clip_start_time = time.time()
                    last_prediction_completion_time = time.time()
                    return status_message, prediction_result

                processed_input = processor(images=frames_for_model, return_tensors="pt")
                pixel_values = processed_input.pixel_values.to(device)
                with torch.no_grad():
                    outputs = model(pixel_values)
                    logits = outputs.logits
                predicted_class_id = logits.argmax(-1).item()
                predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")
                confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
                prediction_result = f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
                status_message = "Prediction complete."
                print(f"DEBUG: Prediction Result: {prediction_result}")

                raw_frames_buffer.clear()
                last_prediction_completion_time = current_time
                app_state = "processing_delay"
                print("DEBUG: Transitioning to 'processing_delay' state.")
            except Exception as e:
                prediction_result = f"Error during prediction: {e}"
                status_message = "Prediction error."
                print(f"ERROR during prediction: {e}")
                app_state = "processing_delay"  # Move to the delay state to avoid repeated errors.
        else:
            status_message = "Waiting for frames..."
            prediction_result = "..."
    elif app_state == "processing_delay":
        elapsed_delay = current_time - last_prediction_completion_time
        status_message = f"Delaying next prediction: {int(elapsed_delay)}/{int(DELAY_BETWEEN_PREDICTIONS_SECONDS)}s"
        if elapsed_delay >= DELAY_BETWEEN_PREDICTIONS_SECONDS:
            app_state = "recording"
            current_clip_start_time = current_time
            status_message = "Starting new recording..."
            prediction_result = "Ready..."
            print("DEBUG: Transitioning back to 'recording' state.")
    return status_message, prediction_result


def reset_app_state_manual():
    global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
    raw_frames_buffer.clear()
    current_clip_start_time = time.time()
    last_prediction_completion_time = time.time()
    app_state = "recording"
    print("DEBUG: Manual reset triggered.")
    return "Ready to record...", "Ready for new prediction."


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # TimesFormer Crime Detection - Hugging Face Space Host
        This Space hosts the `owinymarvin/timesformer-crime-detection` model.
        Live webcam demo with recording and prediction phases.
        """
    )
    with gr.Tab("Live Webcam Demo"):
        gr.Markdown(
            f"""
            Continuously captures the live webcam feed for **{RAW_RECORDING_DURATION_SECONDS} seconds**,
            then makes a prediction. There is a **{DELAY_BETWEEN_PREDICTIONS_SECONDS / 60:.0f} minute delay** afterwards.
            """
        )
        with gr.Row():
            with gr.Column():
                webcam_input = gr.Image(
                    sources=["webcam"],
                    streaming=True,
                    label="Live Webcam Feed"
                )
                status_output = gr.Textbox(label="Current Status", value="Initializing...")
                reset_button = gr.Button("Reset / Start New Cycle")
            with gr.Column():
                prediction_output = gr.Textbox(label="Prediction Result", value="Waiting...")

        webcam_input.stream(
            live_predict_stream,
            inputs=[webcam_input],
            outputs=[status_output, prediction_output]
        )
        reset_button.click(
            reset_app_state_manual,
            inputs=[],
            outputs=[status_output, prediction_output]
        )
| with gr.Tab("API Endpoint for External Clients"): | |
| gr.Markdown( | |
| """ | |
| Use this API endpoint to send base64-encoded frames for prediction. | |
| """ | |
| ) | |
| # Re-adding a slightly more representative API interface | |
| # Gradio's automatic API documentation will use this to show inputs/outputs | |
| gr.Interface( | |
| fn=lambda frames_list: f"Received {len(frames_list)} frames. This is a dummy response. Integrate predict_from_frames_api here.", | |
| inputs=gr.Json(label="List of Base64-encoded image strings"), | |
| outputs=gr.Textbox(label="API Response"), | |
| live=False, | |
| allow_flagging="never" # For API endpoints, flagging is usually not desired | |
| ) | |
| # Note: The actual `predict_from_frames_api` function is defined above, | |
| # but for a clean API tab, we can use a dummy interface here that Gradio will | |
| # use to generate the interactive API documentation. The actual API call | |
| # from your local script directly targets the /run/predict_from_frames_api endpoint. | |

if __name__ == "__main__":
    demo.launch()
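
# Hedged usage sketch for an external client, assuming the real
# `predict_from_frames_api` handler is exposed with
# api_name="predict_from_frames_api" (the Space id below is a placeholder):
#
#     from gradio_client import Client
#     import base64
#
#     client = Client("your-username/your-space-name")
#     with open("frame_000.jpg", "rb") as fh:
#         encoded = base64.b64encode(fh.read()).decode("utf-8")
#     print(client.predict([encoded], api_name="/predict_from_frames_api"))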