import os

import cv2
import numpy as np
import torch
from transformers import CLIPProcessor, CLIPVisionModelWithProjection
from transformers import logging

logging.set_verbosity_error()


class FeatureExtractor():
    def __init__(self, args):
        self.device = args.feature_extractor_device
        self.beta = args.beta  # sampling interval in seconds between extracted frames
        self.processor = CLIPProcessor.from_pretrained(args.feature_extractor)
        self.model = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor).to(self.device)
        self.data_dir = args.data_dir
        self.tmp_dir = args.tmp_dir

    def __call__(self, video_path, video_id):
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_length = frame_count / fps
        # Sample one frame every `beta` seconds of video.
        sample_rate = int(fps) * self.beta

        # Reuse cached features if this video was already processed.
        save_path = os.path.join(self.tmp_dir, video_id + '.npz')
        if os.path.exists(save_path):
            data = np.load(save_path)
            clip_features = data['features']
            cap.release()
            return clip_features, video_length

        clip_features = []
        print("Extracting CLIP features.")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if cap.get(cv2.CAP_PROP_POS_FRAMES) % sample_rate == 0:
                # OpenCV decodes frames as BGR; CLIP expects RGB input.
                image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                inputs = self.processor(images=image, return_tensors="pt").pixel_values
                inputs = inputs.to(self.device)
                with torch.no_grad():
                    feat = self.model(inputs)['image_embeds']
                clip_features.append(feat.cpu().numpy())
        cap.release()
        print("Finished.")

        clip_features = np.concatenate(clip_features, axis=0)
        # Cache the features so repeated calls on the same video hit the fast path above.
        np.savez_compressed(save_path, features=clip_features)
        return clip_features, video_length
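

# A minimal usage sketch, not part of the original file: the checkpoint name,
# video path, video id, and directory values below are hypothetical
# placeholders. Since __init__ only reads attributes off `args`, a
# types.SimpleNamespace is enough to stand in for the real argument parser.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        feature_extractor="openai/clip-vit-base-patch32",  # assumed CLIP checkpoint
        feature_extractor_device="cuda" if torch.cuda.is_available() else "cpu",
        beta=1,            # sample roughly one frame per second
        data_dir="data",   # placeholder
        tmp_dir="tmp",     # placeholder cache directory for the .npz features
    )
    os.makedirs(args.tmp_dir, exist_ok=True)

    extractor = FeatureExtractor(args)
    features, length = extractor("example.mp4", "example")  # placeholder video
    print(features.shape, length)  # (num_sampled_frames, embed_dim), seconds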