import torch
import gradio as gr
import json
import urllib.request
from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
)
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample,
    UniformCropVideo,
)
import numpy as np
# Load the pretrained `slowfast_r50` model from the PyTorchVideo hub
model = torch.hub.load('facebookresearch/pytorchvideo', 'slowfast_r50', pretrained=True)

# Run inference on CPU (no GPU is assumed on this Space)
device = "cpu"
model = model.eval()
model = model.to(device)
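# SlowFast takes a *list* of two clip tensors per sample: a temporally
# subsampled "slow" pathway and a full-frame-rate "fast" pathway. The
# PackPathway transform below builds that pair from a single clip.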
# --- Class Name Loading (from the PyTorchVideo notebook) ---
json_url = "https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json"
json_filename = "kinetics_classnames.json"
urllib.request.urlretrieve(json_url, json_filename)

with open(json_filename, "r") as f:
    kinetics_classnames = json.load(f)

# Build an id -> class-name lookup for the Kinetics 400 labels
kinetics_id_to_classname = {}
for k, v in kinetics_classnames.items():
    kinetics_id_to_classname[v] = str(k).replace('"', "")
# --- Input Transform Parameters (from the PyTorchVideo notebook) ---
side_size = 256
mean = [0.45, 0.45, 0.45]
std = [0.225, 0.225, 0.225]
crop_size = 256
num_frames = 32
sampling_rate = 2
frames_per_second = 30
slowfast_alpha = 4
# num_clips = 10   # not used in the single-clip inference below
# num_crops = 3    # not used in the single-clip inference below
class PackPathway(torch.nn.Module):
    """
    Transform that converts a clip tensor into the list of
    [slow_pathway, fast_pathway] tensors expected by SlowFast.
    """
    def __init__(self):
        super().__init__()

    def forward(self, frames: torch.Tensor):
        # The fast pathway keeps every frame
        fast_pathway = frames
        # The slow pathway keeps every `slowfast_alpha`-th frame,
        # evenly spaced along the temporal dimension
        slow_pathway = torch.index_select(
            frames,
            1,
            torch.linspace(
                0, frames.shape[1] - 1, frames.shape[1] // slowfast_alpha
            ).long(),
        )
        frame_list = [slow_pathway, fast_pathway]
        return frame_list
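# With num_frames = 32 and slowfast_alpha = 4, the fast pathway carries all
# 32 frames while the slow pathway carries 32 // 4 = 8 evenly spaced frames.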
transform = ApplyTransformToKey(
    key="video",
    transform=Compose(
        [
            UniformTemporalSubsample(num_frames),
            Lambda(lambda x: x / 255.0),
            NormalizeVideo(mean, std),
            ShortSideScale(size=side_size),
            CenterCropVideo(crop_size),
            PackPathway(),
        ]
    ),
)
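# A clip spans num_frames * sampling_rate source frames; at 30 fps that is
# (32 * 2) / 30 ≈ 2.13 seconds of video.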
clip_duration = (num_frames * sampling_rate) / frames_per_second
# Download an example video (used as a Gradio example and for local testing)
url_link = "https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4"
video_path = 'archery.mp4'
urllib.request.urlretrieve(url_link, video_path)
def inference(in_vid):
    if in_vid is None:
        return "Please upload a video or use the webcam."
    try:
        # Initialize an EncodedVideo helper class and load the video
        video = EncodedVideo.from_path(in_vid)

        # Ensure the video is long enough for one clip
        if video.duration < clip_duration:
            return f"Video is too short. Minimum duration is {clip_duration:.2f} seconds."

        # Select the clip to load by specifying the start and end seconds
        start_sec = 0
        end_sec = start_sec + clip_duration

        # Load the desired clip
        video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)

        # Apply the transform to normalize the video input
        video_data = transform(video_data)

        # Move the inputs to the desired device and add a batch dimension
        inputs = video_data["video"]
        inputs = [i.to(device)[None, ...] for i in inputs]

        # Pass the input clip through the model without tracking gradients
        with torch.no_grad():
            preds = model(inputs)

        # Convert logits to probabilities and take the top-5 classes
        post_act = torch.nn.Softmax(dim=1)
        preds = post_act(preds)
        pred_classes = preds.topk(k=5).indices[0]

        # Map the predicted class ids to label names
        pred_class_names = [kinetics_id_to_classname[int(i)] for i in pred_classes]
        return "Top 5 predicted labels: %s" % ", ".join(pred_class_names)
    except Exception as e:
        # Catch common failures such as video decoding issues
        return f"An error occurred during inference: {e}"
# --- Gradio interface (4.x syntax; gr.inputs / gr.outputs were removed) ---
inputs_gradio = gr.Video(label="Upload Video or Use Webcam", sources=["upload", "webcam"], format="mp4")
outputs_gradio = gr.Textbox(label="Top 5 Predicted Labels")

title = "PyTorchVideo SlowFast Action Recognition"
description = """
Demo for PyTorchVideo's SlowFast model, pretrained on the Kinetics 400 dataset for action recognition.
Upload a video or use your webcam to classify the action.
"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1812.03982' target='_blank'>SlowFast Networks for Video Recognition</a> | <a href='https://github.com/facebookresearch/pytorchvideo' target='_blank'>PyTorchVideo GitHub Repo</a></p>"

# Use the downloaded archery.mp4 as a clickable example
examples = [[video_path]]
gr.Interface(
    fn=inference,
    inputs=inputs_gradio,
    outputs=outputs_gradio,
    title=title,
    description=description,
    article=article,
    examples=examples,
    analytics_enabled=False,
).launch()
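# Assumed Space dependencies (requirements.txt), version pins untested:
#   torch, torchvision, pytorchvideo, gradio>=4.0, numpy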