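# app.py for the 360+x dataset demo Space: given a text query, it retrieves the clip
# from the quchenyuan/360x_dataset_LR dataset whose text description scores highest
# under a GPT-2 logits dot product, downloads it, and shows it in a Gradio UI.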
import gradio as gr
import os
import json
import torch
from pathlib import Path
from huggingface_hub import hf_hub_download
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
# Load GPT-2 once at startup; its output logits are used as fixed-size text features.
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

# temp_folder = 'temp'
# os.makedirs(temp_folder, exist_ok=True)

# Precompute logits for every text description in the index and keep them in memory.
logit = {}  # maps video UUID -> logits tensor for its text description
json_file = "index.json"
with open(json_file, "r") as file:
    data = json.load(file)
    for key, value in data.items():
        text_description = value["text_description"]
        inputs = tokenizer(text_description, return_tensors="pt",
                           padding="max_length", max_length=32, truncation=True)
        with torch.no_grad():
            outputs = model(**inputs)
        logit[key] = outputs.logits
        # torch.save(outputs.logits, os.path.join(temp_folder, f"{key}.pt"))
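# Retrieval: the query is encoded exactly like the indexed descriptions (padded to
# 32 tokens so the tensor shapes match), and each candidate is scored by the summed
# element-wise product of the two logit tensors; the highest-scoring UUID wins.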
def search_index(query):
    inputs = tokenizer(query, return_tensors="pt",
                       padding="max_length", max_length=32, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    max_similarity = float("-inf")
    max_similarity_uuid = None
    # for file in os.listdir(temp_folder):
    #     uuid = file.split('.')[0]
    #     logits = torch.load(os.path.join(temp_folder, file))
    for uuid, logits in logit.items():
        similarity = (outputs.logits * logits).sum().item()
        if similarity > max_similarity:
            max_similarity = similarity
            max_similarity_uuid = uuid
    gr.Info(f"Found the most similar video with UUID: {max_similarity_uuid}.\nDownloading the video...")
    return max_similarity_uuid
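# Downloading: hf_hub_download pulls the clip from the Hugging Face dataset repo into
# videos/. The commented-out block inside the function sketches a disk-quota policy
# (evict the oldest file once the folder exceeds 40 GB) that is left disabled.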
def download_video(uuid):
    dataset_name = "quchenyuan/360x_dataset_LR"
    dataset_path = "360_dataset/binocular/"
    video_filename = f"{uuid}.mp4"
    storage_dir = Path("videos")
    storage_dir.mkdir(exist_ok=True)
    # storage_limit = 40 * 1024 * 1024 * 1024
    # current_storage = sum(f.stat().st_size for f in storage_dir.glob('*') if f.is_file())
    # if current_storage + os.path.getsize(video_filename) > storage_limit:
    #     oldest_file = min(storage_dir.glob('*'), key=os.path.getmtime)
    #     oldest_file.unlink()
    # The repo is a dataset, so repo_type="dataset" is required; local_dir places the
    # file under videos/, and the actual downloaded path is returned.
    downloaded_file_path = hf_hub_download(repo_id=dataset_name,
                                           filename=dataset_path + video_filename,
                                           repo_type="dataset",
                                           local_dir=storage_dir)
    return downloaded_file_path
def search_and_show_video(query):
    uuid = search_index(query)
    video_path = download_video(uuid)
    # The submit handler is wired to three video outputs, so return a value for each;
    # the same downloaded clip fills all three slots.
    return video_path, video_path, video_path
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Column():
            with gr.Row():
                gr.HTML("<h1><i>360+x</i> : A Panoptic Multi-modal Scene Understanding Dataset</h1>")
            with gr.Row():
                gr.HTML("<p><a href='https://x360dataset.github.io/'>Official Website</a> <a href='https://arxiv.org/abs/2404.00989'>Paper</a></p>")
            with gr.Row():
                gr.HTML("<h2>Search for a video by entering a query below:</h2>")
            with gr.Row():
                search_input = gr.Textbox(label="Query", placeholder="Enter a query to search for a video.")
            with gr.Row():
                with gr.Column():
                    video_output_1 = gr.Video()
                with gr.Column():
                    video_output_2 = gr.Video()
                with gr.Column():
                    video_output_3 = gr.Video()
            with gr.Row():
                submit_button = gr.Button(value="Search")
            submit_button.click(search_and_show_video,
                                inputs=search_input,
                                outputs=[video_output_1, video_output_2, video_output_3])
    demo.launch()
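# Note (assumptions beyond this file): running this as a Space also needs a
# requirements.txt covering gradio, torch, transformers, and huggingface_hub, plus an
# index.json next to app.py mapping each video UUID to {"text_description": ...}.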