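"""Summarize sampled frame captions into whole-clip video captions with a vLLM-served LLM.

Results are appended to ``saved_path`` every ``saved_freq`` batches so that interrupted runs
can resume from the already-saved rows.
"""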

import argparse
import os
import re

import pandas as pd
from tqdm import tqdm
from vllm import LLM, SamplingParams

from utils.logger import logger


def parse_args():
    parser = argparse.ArgumentParser(description="Summarize the sampled frame captions into video captions.")
    parser.add_argument(
        "--video_metadata_path", type=str, required=True, help="The path to the video dataset metadata (csv/jsonl)."
    )
    parser.add_argument(
        "--video_path_column",
        type=str,
        default="video_path",
        help="The column containing the video path (an absolute path or a relative path w.r.t. the video_folder).",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="sampled_frame_caption",
        help="The column containing the sampled frame captions.",
    )
    parser.add_argument(
        "--remove_quotes",
        action="store_true",
        help="Whether to remove quoted substrings from the captions.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=10,
        required=False,
        help="The batch size for the caption summarization.",
    )
    parser.add_argument(
        "--summary_model_name",
        type=str,
        default="mistralai/Mistral-7B-Instruct-v0.2",
        help="The LLM used to summarize the sampled frame captions.",
    )
    parser.add_argument(
        "--summary_prompt",
        type=str,
        default=(
            "You are a helpful video description generator. I'll give you a description of the middle frame of the video clip, "
            "which you need to summarize into a description of the video clip. "
            "Please provide your video description following these requirements: "
            "1. Describe the basic and necessary information of the video in the third person, be as concise as possible. "
            "2. Output the video description directly. Begin with 'In this video'. "
            "3. Limit the video description within 100 words. "
            "Here is the mid-frame description: "
        ),
    )
    parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
    parser.add_argument("--saved_freq", type=int, default=1000, help="The frequency (in batches) to save the output results.")

    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    if args.video_metadata_path.endswith(".csv"):
        video_metadata_df = pd.read_csv(args.video_metadata_path)
    elif args.video_metadata_path.endswith(".jsonl"):
        video_metadata_df = pd.read_json(args.video_metadata_path, lines=True)
    else:
        raise ValueError("The video_metadata_path must end with .csv or .jsonl.")
    video_path_list = video_metadata_df[args.video_path_column].tolist()
    sampled_frame_caption_list = video_metadata_df[args.caption_column].tolist()

    if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
        raise ValueError("The saved_path must end with .csv or .jsonl.")

    if os.path.exists(args.saved_path):
        if args.saved_path.endswith(".csv"):
            saved_metadata_df = pd.read_csv(args.saved_path)
        elif args.saved_path.endswith(".jsonl"):
            saved_metadata_df = pd.read_json(args.saved_path, lines=True)
        # Filter out the videos that have already been processed.
        saved_video_path_list = saved_metadata_df[args.video_path_column].tolist()
        video_path_list = list(set(video_path_list) - set(saved_video_path_list))

        video_metadata_df.set_index(args.video_path_column, inplace=True)
        video_metadata_df = video_metadata_df.loc[video_path_list]
        sampled_frame_caption_list = video_metadata_df[args.caption_column].tolist()
        logger.info(
            f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed."
        )

    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256)
    summary_model = LLM(model=args.summary_model_name, trust_remote_code=True)

    result_dict = {"video_path": [], "summary_model": [], "summary_caption": []}
    for i in tqdm(range(0, len(sampled_frame_caption_list), args.batch_size)):
        batch_video_path = video_path_list[i : i + args.batch_size]
        batch_caption = sampled_frame_caption_list[i : i + args.batch_size]
        batch_prompt = []
        for caption in batch_caption:
            if args.remove_quotes:
                # Drop any single- or double-quoted substrings from the frame caption.
                caption = re.sub(r'(["\']).*?\1', "", caption)
            batch_prompt.append("user:" + args.summary_prompt + str(caption) + "\n assistant:")
        batch_output = summary_model.generate(batch_prompt, sampling_params)

        result_dict["video_path"].extend(batch_video_path)
        result_dict["summary_model"].extend([args.summary_model_name] * len(batch_caption))
        result_dict["summary_caption"].extend([output.outputs[0].text.rstrip() for output in batch_output])

        # Save the metadata every args.saved_freq batches.
        if i != 0 and ((i // args.batch_size) % args.saved_freq) == 0:
            result_df = pd.DataFrame(result_dict)
            if args.saved_path.endswith(".csv"):
                header = not os.path.exists(args.saved_path)
                result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
            elif args.saved_path.endswith(".jsonl"):
                result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
            logger.info(f"Save result to {args.saved_path}.")

            result_dict = {"video_path": [], "summary_model": [], "summary_caption": []}

    result_df = pd.DataFrame(result_dict)
    if args.saved_path.endswith(".csv"):
        header = not os.path.exists(args.saved_path)
        result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
    elif args.saved_path.endswith(".jsonl"):
        result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
    logger.info(f"Save the final result to {args.saved_path}.")


if __name__ == "__main__":
    main()
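
# Example invocation (the file name and paths below are illustrative, not taken from the repo):
#   python caption_summary.py \
#       --video_metadata_path path/to/sampled_frame_caption.jsonl \
#       --caption_column sampled_frame_caption \
#       --saved_path path/to/summary_caption.jsonl \
#       --saved_freq 100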