import base64
import os
import shutil
import tempfile
from io import BytesIO

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from decord import VideoReader
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModel, AutoTokenizer

import spaces

title_markdown = ("""
<div style="display: flex; justify-content: flex-start; align-items: center; text-align: center;">
  <div style="margin-right: 20px; display: flex; align-items: center;">
    <a href="https://github.com/ShareGPT4Omni/ShareGPT4Video" style="text-decoration: none; display: flex; align-items: center;">
      <img src="https://raw.githubusercontent.com/ShareGPT4V/ShareGPT4V-Resources/master/images/share4video_tight.png" alt="ShareGPT4Video🚀" style="max-width: 120px; height: auto;">
    </a>
  </div>
  <div>
    <h1>ShareGPT4Video: Improving Video Understanding and Generation with Better Captions</h1>
    <h5 style="margin: 0;">If you like our project, please give us a star ✨ on GitHub for the latest update.</h5>
    <h5 style="margin: 0;"> <a href="https://sharegpt4video.github.io/">[Project Page]</a> <a href="https://github.com/ShareGPT4Omni/ShareGPT4Video">[Code]</a> <a href="https://arxiv.org/abs/2406.04325v1">[Paper]</a></h5>
  </div>
</div>
""")

block_css = """
#buttons button {
    min-width: min(120px,100%);
}
"""

learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")
new_path = 'Lin-Chen/ShareCaptioner-Video'
tokenizer = AutoTokenizer.from_pretrained(new_path, trust_remote_code=True)
model = AutoModel.from_pretrained(
    new_path, torch_dtype=torch.float16, trust_remote_code=True).cuda().eval()
model.tokenizer = tokenizer
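
# Pad the image height up to the next multiple of `pad` pixels with white
# borders, keeping the width unchanged (helper for HD_transform below).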
def padding_336(b, pad=336):
    width, height = b.size
    tar = int(np.ceil(height / pad) * pad)
    top_padding = int((tar - height) / 2)
    bottom_padding = tar - height - top_padding
    left_padding = 0
    right_padding = 0
    b = transforms.functional.pad(
        b, [left_padding, top_padding, right_padding, bottom_padding], fill=[255, 255, 255])
    return b
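
# Resize an image so it fits within roughly `hd_num` 336x336 tiles while
# preserving the aspect ratio, then pad the height to a multiple of 336.
# Portrait images are transposed first and transposed back at the end.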
def HD_transform(img, hd_num=25):
    width, height = img.size
    trans = False
    if width < height:
        img = img.transpose(Image.TRANSPOSE)
        trans = True
        width, height = img.size
    ratio = (width / height)
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1
    new_w = int(scale * 336)
    new_h = int(new_w / ratio)
    img = transforms.functional.resize(img, [new_h, new_w])
    img = padding_336(img, 336)
    width, height = img.size
    if trans:
        img = img.transpose(Image.TRANSPOSE)
    return img
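
# Uniformly sample `desired_num_frames` frame indices from [start, end), always
# keeping the first and last frame (not called elsewhere in this script).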
def get_seq_frames(total_num_frames, desired_num_frames, start=None, end=None):
    if start is None:
        assert end is None
        start, end = 0, total_num_frames
    print(f"{start=}, {end=}")
    desired_num_frames -= 2
    end = min(total_num_frames, end)
    start = max(start, 0)
    seg_size = float(end - start) / desired_num_frames
    seq = [start]
    for i in range(desired_num_frames):
        s = int(np.round(seg_size * i))
        e = int(np.round(seg_size * (i + 1)))
        seq.append(min(int(start + (s + e) // 2), total_num_frames - 1))
    return seq + [end - 1]
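
# Core generation helper: interleave the (optional) image embedding with the
# text embeddings, build the image mask the model expects, generate, and decode
# the response, stripping the model's end-of-turn markers.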
def model_gen(model, text, images, need_bos=True, hd_num=25, max_new_token=2048, beam=3, do_sample=False):
    pt1 = 0
    embeds = []
    im_mask = []
    if images is None:
        images = []
        images_loc = []
    else:
        images = [images]
        images_loc = [0]
    for i, pts in enumerate(images_loc + [len(text)]):
        subtext = text[pt1:pts]
        if need_bos or len(subtext) > 0:
            text_embeds = model.encode_text(
                subtext, add_special_tokens=need_bos)
            embeds.append(text_embeds)
            im_mask.append(torch.zeros(text_embeds.shape[:2]).cuda())
            need_bos = False
        if i < len(images):
            try:
                image = Image.open(images[i]).convert('RGB')
            except Exception:
                image = images[i].convert('RGB')
            image = HD_transform(image, hd_num=hd_num)
            image = model.vis_processor(image).unsqueeze(0).cuda()
            image_embeds = model.encode_img(image)
            print(image_embeds.shape)
            embeds.append(image_embeds)
            im_mask.append(torch.ones(image_embeds.shape[:2]).cuda())
        pt1 = pts
    embeds = torch.cat(embeds, dim=1)
    im_mask = torch.cat(im_mask, dim=1)
    im_mask = im_mask.bool()
    outputs = model.generate(inputs_embeds=embeds, im_mask=im_mask,
                             temperature=1.0, max_new_tokens=max_new_token, num_beams=beam,
                             do_sample=do_sample, repetition_penalty=1.00)
    output_token = outputs[0]
    if output_token[0] == 0 or output_token[0] == 1:
        output_token = output_token[1:]
    output_text = model.tokenizer.decode(
        output_token, add_special_tokens=False)
    output_text = output_text.split('[UNUSED_TOKEN_145]')[0].strip()
    output_text = output_text.split('<|im_end|>')[0].strip()
    return output_text
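
# Stack a list of frames vertically onto one white canvas, labelling each one
# with "<IMAGE idx>" and drawing a separator line between frames. Assumes a
# SimHei.ttf font file is available in the working directory.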
def img_process(imgs):
    new_w = 0
    new_h = 0
    for im in imgs:
        w, h = im.size
        new_w = max(new_w, w)
        new_h += h + 20
    pad = max(new_w // 4, 100)
    new_w += 20
    new_h += 20
    font = ImageFont.truetype("SimHei.ttf", pad // 5)
    new_img = Image.new('RGB', (new_w + pad, new_h), 'white')
    draw = ImageDraw.Draw(new_img)
    curr_h = 10
    for idx, im in enumerate(imgs):
        w, h = im.size
        new_img.paste(im, (pad, curr_h))
        draw.text((0, curr_h + h // 2),
                  f'<IMAGE {idx}>', font=font, fill='black')
        if idx + 1 < len(imgs):
            draw.line([(0, curr_h + h + 10), (new_w + pad,
                      curr_h + h + 10)], fill='black', width=2)
        curr_h += h + 20
    return new_img
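
# Decode the video with decord and sample one frame roughly every 2 seconds
# (optionally restricted to a [start, end] window in seconds), returning the
# sampled frames as PIL images.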
def load_quota_video(vis_path, start=None, end=None):
    vr = VideoReader(vis_path)
    total_frame_num = len(vr)
    fps = vr.get_avg_fps()
    if start is not None:
        assert end is not None
        start_frame = int(start * fps)
        end_frame = min(int(end * fps), total_frame_num)
    else:
        start_frame = 0
        end_frame = total_frame_num
    interval = int(2 * fps)
    frame_idx = list(range(start_frame, end_frame, interval))
    img_array = vr.get_batch(frame_idx).asnumpy()
    num_frm, H, W, _ = img_array.shape
    img_array = img_array.reshape(
        (1, num_frm, img_array.shape[-3], img_array.shape[-2], img_array.shape[-1]))
    clip_imgs = []
    for j in range(num_frm):
        clip_imgs.append(Image.fromarray(img_array[0, j]))
    return clip_imgs
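
# Downscale an image so its longer side is at most `max_size` pixels while
# preserving the aspect ratio (helper for encode_resized_image below).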
def resize_image(image_path, max_size=1024):
    with Image.open(image_path) as img:
        width, height = img.size
        if width > max_size or height > max_size:
            if width > height:
                new_width = max_size
                new_height = int(height * (max_size / width))
            else:
                new_height = max_size
                new_width = int(width * (max_size / height))
        else:
            new_width = width
            new_height = height
        resized_img = img.resize((new_width, new_height))
        print(f"resized_img_size: {resized_img.size}")
        return resized_img
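
# Resize an image and return it as a base64-encoded JPEG string, converting to
# RGB if the first save fails (e.g. RGBA input). Not wired into the demo UI below.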
def encode_resized_image(image_path, max_size=1024):
    resized_img = resize_image(image_path, max_size)
    try:
        with BytesIO() as buffer:
            resized_img.save(buffer, format="JPEG")
            return base64.b64encode(buffer.getvalue()).decode('utf-8')
    except Exception:
        with BytesIO() as buffer:
            rgb_img = resized_img.convert('RGB')
            rgb_img.save(buffer, format="JPEG")
            return base64.b64encode(buffer.getvalue()).decode('utf-8')
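
# Sliding captioning: caption the first frame, then describe what happens
# between each pair of consecutive sampled frames (stacked into one image),
# and finally ask the model to summarize all per-frame descriptions.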
def generate_slidingcaptioning(video_path):
    imgs = load_quota_video(video_path)
    q = 'This is the first frame of a video, describe it in detail.'
    query = f'[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
    img = imgs[0]
    with torch.cuda.amp.autocast():
        response = model_gen(model, query, img, hd_num=9)
    print(response)
    responses = [response]
    images = [img]
    for idx in range(len(imgs) - 1):
        image1 = imgs[idx]
        image2 = imgs[idx + 1]
        prompt = "Here are the Video frame {} at {}.00 Second(s) and Video frame {} at {}.00 Second(s) of a video, describe what happened between them. What happened before is: {}".format(
            idx, int(idx * 2), idx + 1, int((idx + 1) * 2), response)
        width, height = image1.size
        new_img = Image.new('RGB', (width, 2 * height + 50), 'white')
        new_img.paste(image1, (0, 0))
        new_img.paste(image2, (0, height + 50))
        query = f'[UNUSED_TOKEN_146]user\n{prompt}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
        with torch.cuda.amp.autocast():
            response = model_gen(model, query, new_img, hd_num=9)
        responses.append(response)
        images.append(new_img)
    prompt = 'Summarize the following per frame descriptions:\n'
    for idx, txt in enumerate(responses):
        prompt += 'Video frame {} at {}.00 Second(s) description: {}\n'.format(
            idx + 1, idx * 2, txt)
    query = f'[UNUSED_TOKEN_146]user\n{prompt}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
    print(query)
    with torch.cuda.amp.autocast():
        summ = model_gen(model, query, None, hd_num=16)
    print(summ)
    return summ
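
# Fast captioning: tile all sampled key frames into a single image and ask the
# model to describe the whole video in one pass.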
def generate_fastcaptioning(video_path):
    q = 'Here are a few key frames of a video, describe this video in detail.'
    query = f'[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
    imgs = load_quota_video(video_path)
    img = img_process(imgs)
    with torch.cuda.amp.autocast():
        response = model_gen(model, query, img, hd_num=16,
                             do_sample=False, beam=3)
    return response
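
# Prompt re-captioning: expand a short text-to-video generation prompt into a
# detailed caption (no video input involved).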
def generate_promptrecaptioning(text):
    q = f'Translate this brief generation prompt into a detailed caption: {text}'
    query = f'[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
    with torch.cuda.amp.autocast():
        response = model_gen(model, query, None)
    return response
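
# Copy an uploaded video into a local ./temp directory under a random name
# (not used by the current UI callbacks).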
def save_video_to_local(video_path):
    os.makedirs('temp', exist_ok=True)
    filename = os.path.join('temp', next(
        tempfile._get_candidate_names()) + '.mp4')
    shutil.copyfile(video_path, filename)
    return filename
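
# Gradio UI: a video/text input on the left, three buttons that trigger the
# captioning modes above, and a shared output textbox on the right.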
with gr.Blocks(title='ShareCaptioner-Video', theme=gr.themes.Default(), css=block_css) as demo:
    gr.Markdown(title_markdown)
    state = gr.State()
    state_ = gr.State()
    first_run = gr.State()

    with gr.Row():
        gr.Markdown("### ShareCaptioner-Video is an exceptional four-in-one video captioning model with the following capabilities:\n1. Fast Captioning, 2. Sliding Captioning, 3. Clip Summarizing, 4. Prompt Re-Captioning")
    with gr.Row():
        gr.Markdown("(THE DEMO OF \"Clip Summarizing\" IS COMING SOON...)")
    with gr.Row():
        with gr.Column(scale=6):
            with gr.Row():
                video = gr.Video(label="Input Video")
            with gr.Row():
                textbox = gr.Textbox(
                    show_label=False, placeholder="Input Text", container=False
                )
            with gr.Row():
                with gr.Column(scale=2, min_width=50):
                    submit_btn_sc = gr.Button(
                        value="Sliding Captioning", variant="primary", interactive=True
                    )
                with gr.Column(scale=2, min_width=50):
                    submit_btn_fc = gr.Button(
                        value="Fast Captioning", variant="primary", interactive=True
                    )
                with gr.Column(scale=2, min_width=50):
                    submit_btn_pr = gr.Button(
                        value="Prompt Re-captioning", variant="primary", interactive=True
                    )
        with gr.Column(scale=4, min_width=200):
            with gr.Row():
                textbox_out = gr.Textbox(
                    show_label=False, placeholder="Output", container=False
                )
    gr.Markdown(learn_more_markdown)

    submit_btn_sc.click(generate_slidingcaptioning, [video], [textbox_out])
    submit_btn_fc.click(generate_fastcaptioning, [video], [textbox_out])
    submit_btn_pr.click(generate_promptrecaptioning, [textbox], [textbox_out])

demo.launch()