import os
import io
import math
import random

import av
import cv2
import decord
import imageio
import numpy as np
import torch
import torch.nn.functional as F
from decord import VideoReader

decord.bridge.set_bridge("torch")

from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("/fs-computility/video/heyinan/iv2hf/", trust_remote_code=True)
model = AutoModel.from_pretrained("/fs-computility/video/heyinan/iv2hf/", trust_remote_code=True).to(config.device)
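# NOTE: "/fs-computility/video/heyinan/iv2hf/" is a machine-specific local checkpoint
# directory; point it at your own copy of the weights. `config.device`, `model.transform`,
# `model.tokenizer`, `model.encode_vision`, and `model.encode_text` used below all come
# from the repository's trust_remote_code model class, not from the stock transformers API.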
					
						
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1,
                      max_num_frames=-1, start=None, end=None):
    # Clamp the sampling window to the requested [start, end] segment (given in seconds).
    start_frame, end_frame = 0, vlen
    if start is not None:
        start_frame = max(start_frame, int(start * input_fps))
    if end is not None:
        end_frame = min(end_frame, int(end * input_fps))

    if start_frame >= end_frame:
        raise ValueError("Start frame index must be less than end frame index")

    clip_length = end_frame - start_frame

    if sample in ["rand", "middle"]:
        acc_samples = min(num_frames, clip_length)

        # Split the clip into acc_samples equal intervals and take one frame from each.
        intervals = np.linspace(start=start_frame, stop=end_frame, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1] + 1)) for x in ranges]
            except IndexError:
                # random.choice raises IndexError on an empty interval; fall back to
                # a sorted random subset of the whole clip.
                frame_indices = np.random.permutation(clip_length)[:acc_samples] + start_frame
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

        # If the clip has fewer frames than requested, pad by repeating the last index.
        if len(frame_indices) < num_frames:
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif "fps" in sample:
        # e.g. sample='fps1' resamples the clip at 1 frame per second.
        output_fps = float(sample[3:])
        duration = float(clip_length) / input_fps
        delta = 1 / output_fps  # gap between sampled frames, in seconds
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int) + start_frame
        frame_indices = [e for e in frame_indices if e < end_frame]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError
    return frame_indices
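# A quick worked example of 'middle' sampling (hypothetical numbers, not from the demo
# videos): asking for 4 frames from a 100-frame clip splits it at [0, 25, 50, 75, 100]
# and returns the midpoint of each interval:
#
#   get_frame_indices(4, 100, sample='middle')  # -> [12, 37, 62, 87]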
					
						
def read_frames_decord(
        video_path, num_frames, sample='middle', fix_start=None,
        max_num_frames=-1, client=None, trimmed30=False, start=None, end=None
):
    # Force a single decoder thread for .webm files; 0 lets decord pick automatically.
    num_threads = 1 if video_path.endswith('.webm') else 0

    video_reader = VideoReader(video_path, num_threads=num_threads)
    vlen = len(video_reader)

    fps = video_reader.get_avg_fps()
    duration = vlen / float(fps)

    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames, start=start, end=end
    )

    frames = video_reader.get_batch(frame_indices)  # (T, H, W, C), uint8 via the torch bridge
    frames = frames.permute(0, 3, 1, 2)  # -> (T, C, H, W)
    return frames, frame_indices, duration
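# Example (hypothetical file name): frames, idx, dur = read_frames_decord("clip.mp4", 8)
# yields a uint8 tensor of shape (8, 3, H, W), the 8 sampled frame indices, and the
# clip duration in seconds.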
					
						
def get_text_feature(model, texts):
    text_input = model.tokenizer(texts).to(model.device)
    text_features = model.encode_text(text_input)
    return text_features


def get_similarity(video_feature, text_feature):
    video_feature = F.normalize(video_feature, dim=-1)
    text_feature = F.normalize(text_feature, dim=-1)
    sim_matrix = text_feature @ video_feature.T
    return sim_matrix
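# get_similarity is the standalone version of the normalize-then-matmul step that
# get_top_videos repeats inline below: for M texts and N videos it returns an (M, N)
# matrix of cosine similarities in [-1, 1].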
					
						
def get_top_videos(model, text_features, video_features, video_paths, texts):
    video_features = F.normalize(video_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)

    sim_matrix = text_features @ video_features.T

    # torch.topk requires k <= number of candidates, so cap top_k by the gallery size.
    top_k = min(5, len(video_paths))
    sim_matrix_top_k = torch.topk(sim_matrix, top_k, dim=1)[1]
    softmax_sim_matrix = F.softmax(sim_matrix, dim=1)  # computed for reference; ranking uses raw similarity

    retrieval_infos = {}
    for i in range(len(sim_matrix_top_k)):
        print("\n", texts[i])
        retrieval_infos[texts[i]] = []
        for j in range(top_k):
            print("top", j + 1, ":", video_paths[sim_matrix_top_k[i][j]],
                  "~prob:", sim_matrix[i][sim_matrix_top_k[i][j]].item())
            retrieval_infos[texts[i]].append({
                "video": video_paths[sim_matrix_top_k[i][j]],
                "prob": sim_matrix[i][sim_matrix_top_k[i][j]].item(),
                "rank": j + 1,
            })
    return retrieval_infos
					
						
if __name__ == "__main__":
    video_features = []
    demo_videos = ["video1.mp4", "video2.mp4"]
    texts = ['a person talking', 'a logo', 'a building']
    for video_path in demo_videos:
        # Sample 8 frames, preprocess, and add a batch dimension: (1, T, C, H, W).
        frames, frame_indices, video_duration = read_frames_decord(video_path, 8)
        frames = model.transform(frames).unsqueeze(0).to(model.device)
        with torch.no_grad():
            video_feature = model.encode_vision(frames, test=True)
        video_features.append(video_feature)

    text_features = get_text_feature(model, texts)
    video_features = torch.cat(video_features, dim=0).to(text_features.dtype).to(config.device)
    results = get_top_videos(model, text_features, video_features, demo_videos, texts)
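# For the demo inputs above, `results` maps each query text to its ranked videos,
# e.g. (illustrative values only):
#   {'a person talking': [{'video': 'video1.mp4', 'prob': 0.31, 'rank': 1},
#                         {'video': 'video2.mp4', 'prob': 0.12, 'rank': 2}], ...}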