Spaces:
Runtime error
Runtime error
| import os | |
| import spaces | |
# Resolve the Hugging Face access token. Prefer the HF_TOKEN environment
# variable; otherwise fall back to a placeholder that has to be replaced by
# hand (the message below asks the user to do so).
token = os.environ.get('HF_TOKEN')
if token is None:
    print("paste your hf token here!")
    token = "hf_xxxxxxxxxxxxxxxxxxx"
# Export the resolved value so anything reading HF_TOKEN from the environment
# (presumably huggingface_hub / the remote model code) sees the same token.
os.environ['HF_TOKEN'] = token
| import torch | |
| import gradio as gr | |
| from gradio.themes.utils import colors, fonts, sizes | |
| from transformers import AutoTokenizer, AutoModel | |
# ========================================
# Model Initialization
# ========================================
_MODEL_ID = 'OpenGVLab/InternVideo2_chat_8B_HD'

# NOTE: trust_remote_code=True executes Python code shipped in the model
# repository; only point _MODEL_ID at repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(
    _MODEL_ID,
    trust_remote_code=True,
    use_fast=False,
    token=token,
)

# Load the weights once (bfloat16) and pass the same token as the tokenizer;
# the GPU/CPU cases differ only in the final move to CUDA.
model = AutoModel.from_pretrained(
    _MODEL_ID,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    token=token,
)
if torch.cuda.is_available():
    model = model.cuda()
| from decord import VideoReader, cpu | |
| from PIL import Image | |
| import numpy as np | |
| import numpy as np | |
| import decord | |
| from decord import VideoReader, cpu | |
| import torch.nn.functional as F | |
| import torchvision.transforms as T | |
| from torchvision.transforms import PILToTensor | |
| from torchvision import transforms | |
| from torchvision.transforms.functional import InterpolationMode | |
# Configure decord so VideoReader results come back as torch tensors rather
# than decord's own array type; load_video relies on this when it calls
# .permute() directly on the frames returned by get_batch().
decord.bridge.set_bridge("torch")
# ========================================
# Define Utils
# ========================================
def get_index(num_frames, num_segments):
    """Choose ``num_segments`` frame indices spread evenly across a clip.

    The span ``[0, num_frames - 1]`` is divided into ``num_segments`` equal
    steps. Index ``k`` is ``round(step * k)`` shifted by half a step, with the
    half step truncated to an int, so samples sit near segment centres.

    Args:
        num_frames: total number of frames in the clip.
        num_segments: how many indices to return.

    Returns:
        1-D numpy integer array of length ``num_segments``.

    Raises:
        ZeroDivisionError: if ``num_segments`` is 0.
    """
    step = float(num_frames - 1) / num_segments
    half_step = int(step / 2)
    return half_step + np.round(step * np.arange(num_segments)).astype(int)
def load_video(video_path, num_segments=8, return_msg=False, resolution=224, hd_num=4, padding=False):
    """Decode a video into the tiled tensor fed to the HD vision encoder.

    ``num_segments`` frames are sampled with ``get_index`` and resized by
    ``HD_transform_padding`` (``padding=True``) or ``HD_transform_no_padding``.
    They are then scaled to [0, 1], normalized with ImageNet mean/std, and
    split into ``resolution`` x ``resolution`` tiles. One extra "global" view
    (the whole frame resized to a single tile) is appended.

    Args:
        video_path: path handed to decord's ``VideoReader``.
        num_segments: number of frames to sample.
        return_msg: when True, also return a text description of the sampled
            timestamps.
        resolution: tile edge length in pixels.
        hd_num: maximum tile count passed to the HD transform.
        padding: choose the padding HD transform instead of the plain resize.

    Returns:
        Tensor of shape ``(1, n_tiles + 1, num_segments, 3, resolution,
        resolution)``, or ``(tensor, msg)`` when ``return_msg`` is True.
        Assumes the HD transform returns a height and width divisible by
        ``resolution``, as the reshape below requires.
    """
    decord.bridge.set_bridge("torch")
    reader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    indices = get_index(len(reader), num_segments)

    normalize = transforms.Compose([
        transforms.Lambda(lambda x: x.float().div(255.0)),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # decord yields frames channels-last; move channels in front of H and W.
    clip = reader.get_batch(indices).permute(0, 3, 1, 2).float()
    hd_transform = HD_transform_padding if padding else HD_transform_no_padding
    clip = normalize(hd_transform(clip, image_size=resolution, hd_num=hd_num))

    n_frames, _, height, width = clip.shape
    rows, cols = height // resolution, width // resolution
    # Cut every frame into a rows x cols grid of tiles, tiles first.
    tiles = (
        clip.reshape(1, n_frames, 3, rows, resolution, cols, resolution)
        .permute(0, 3, 5, 1, 2, 4, 6)
        .reshape(-1, n_frames, 3, resolution, resolution)
        .contiguous()
    )
    # A single low-resolution view of the whole frame, appended after the tiles.
    overview = F.interpolate(
        clip.float(), size=(resolution, resolution), mode='bicubic', align_corners=False
    )
    overview = overview.to(tiles.dtype).unsqueeze(0)
    result = torch.cat([tiles, overview]).unsqueeze(0)

    if not return_msg:
        return result

    fps = float(reader.get_avg_fps())
    timestamps = ", ".join(str(round(i / fps, 1)) for i in indices)
    msg = f"The video contains {len(indices)} frames sampled at {timestamps} seconds."
    return result, msg
def HD_transform_padding(frames, image_size=224, hd_num=6):
    """Resize frames onto a grid of ``image_size`` tiles, padding instead of cropping.

    The longer spatial side is scaled to ``scale * image_size``. ``scale`` is
    the largest integer whose tile grid (``scale * ceil(scale / ratio)`` tiles)
    does not exceed ``hd_num``. The shorter side keeps the original aspect
    ratio and is then padded symmetrically with the constant 255 up to the next
    multiple of ``image_size``.

    Args:
        frames: 4-D tensor whose last two dims are treated as (H, W),
            e.g. (T, C, H, W) as passed by load_video.
        image_size: tile edge length in pixels.
        hd_num: maximum number of tiles (expected to be >= 1).

    Returns:
        Tensor with the same leading dims whose H and W are both multiples of
        ``image_size``. The portrait/landscape orientation of the input is kept.
    """
    def _pad_rows_to_multiple(x):
        # Pad the height (top/bottom) up to a multiple of image_size.
        rows = x.shape[-2]
        target = int(np.ceil(rows / image_size) * image_size)
        top = (target - rows) // 2
        bottom = target - rows - top
        return F.pad(x, pad=[0, 0, top, bottom], mode='constant', value=255)

    _, _, H, W = frames.shape
    trans = W < H
    if trans:
        # Portrait input: swap the spatial axes so the logic below only sees
        # landscape frames; swapped back at the end. (Flipping would reverse
        # the pixels but keep the portrait shape, so the resize below would
        # stretch the content into a landscape grid.)
        frames = frames.transpose(-2, -1)
        width, height = H, W
    else:
        width, height = W, H

    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1

    new_w = int(scale * image_size)
    new_h = int(new_w / ratio)
    resized = F.interpolate(
        frames, size=(new_h, new_w),
        mode='bicubic',
        align_corners=False,
    )
    padded = _pad_rows_to_multiple(resized)
    if trans:
        padded = padded.transpose(-2, -1)
    return padded.contiguous()
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the (cols, rows) grid from ``target_ratios`` closest to ``aspect_ratio``.

    Candidates are compared by ``abs(aspect_ratio - cols / rows)``. On an exact
    tie with the current best, the later candidate wins only if the image area
    ``width * height`` exceeds half of that candidate's total pixel area
    (``image_size**2 * cols * rows``).

    Returns:
        The chosen element of ``target_ratios`` (the object itself), or
        ``(1, 1)`` if ``target_ratios`` is empty.
    """
    chosen = (1, 1)
    smallest_gap = float('inf')
    pixel_area = width * height
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap = gap
            chosen = candidate
        elif gap == smallest_gap and pixel_area > 0.5 * image_size * image_size * cols * rows:
            chosen = candidate
    return chosen
def HD_transform_no_padding(frames, image_size=224, hd_num=6, fix_ratio=(2,1)):
    """Resize frames (without padding) to a grid of ``image_size`` tiles.

    The grid is ``fix_ratio`` as ``(cols, rows)`` when given; the default
    ``(2, 1)`` always yields a 2-wide, 1-high grid. When ``fix_ratio`` is
    falsy, the grid is the candidate ``(i, j)`` with ``1 <= i * j <= hd_num``
    whose ratio is closest to the frames' aspect ratio, chosen by
    ``find_closest_aspect_ratio``.

    Args:
        frames: 4-D tensor whose last two dims are treated as (H, W).
        image_size: tile edge length in pixels.
        hd_num: maximum tile count; only used when ``fix_ratio`` is falsy.
        fix_ratio: fixed ``(cols, rows)`` grid, or falsy to pick automatically.

    Returns:
        Tensor resized to ``(image_size * rows, image_size * cols)`` spatially.
    """
    _, _, orig_height, orig_width = frames.shape

    if fix_ratio:
        target_aspect_ratio = fix_ratio
    else:
        # Candidate grids are only needed for the automatic choice; build them
        # lazily. int() tolerates float slider values such as 4.0.
        min_num, max_num = 1, int(hd_num)
        target_ratios = set(
            (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
            if min_num <= i * j <= max_num)
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
        target_aspect_ratio = find_closest_aspect_ratio(
            orig_width / orig_height, target_ratios, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    return F.interpolate(
        frames, size=(target_height, target_width),
        mode='bicubic', align_corners=False
    )
# ========================================
# Gradio Setting
# ========================================
def gradio_reset(chat_state, img_list):
    """Return the UI to its pre-upload state.

    Returns six values:
    - ``None`` for the chat history.
    - A cleared, editable video input.
    - A disabled textbox asking for an upload.
    - A re-enabled upload button.
    - A fresh chat state (``[]``, or ``None`` if it was ``None``).
    - ``None`` for the image list.
    """
    fresh_state = None if chat_state is None else []
    return (
        None,
        gr.update(value=None, interactive=True),
        gr.update(placeholder='Please upload your video first', interactive=False),
        gr.update(value="Upload & Start Chat", interactive=True),
        fresh_state,
        None,
    )
def upload_img( gr_video, num_segments, hd_num, padding):
    """Decode the uploaded video into the tensor used by the chat handler.

    Returns exactly four values, one per output component wired to this
    handler: (up_video, text_input, upload_button, img_list).

    Args:
        gr_video: path of the uploaded video, or None/empty when nothing was
            uploaded.
        num_segments: number of frames to sample (slider value).
        hd_num: maximum HD tile count (slider value).
        padding: whether to use the padding HD transform.
    """
    if not gr_video:
        return (
            None,
            gr.update(interactive=True, placeholder='Please upload video/image first!'),
            gr.update(interactive=True),
            None,
        )

    # Slider values may arrive as floats; range() and tensor reshapes
    # downstream need ints.
    video_tensor, _msg = load_video(
        gr_video,
        num_segments=int(num_segments),
        return_msg=True,
        resolution=224,
        hd_num=int(hd_num),
        padding=padding,
    )
    # NOTE(review): _msg (sampled timestamps) is currently unused.
    # The tensor stays on CPU in the session state; gradio_answer moves it to
    # model.device right before inference.
    return (
        gr.update(interactive=True),
        gr.update(interactive=True, placeholder='Type and press Enter'),
        gr.update(value="Start Chatting", interactive=False),
        video_tensor,
    )
def clear_():
    """Return fresh, independent empty lists for the chat history and chat state."""
    empty_history = []
    empty_state = []
    return empty_history, empty_state
def gradio_ask(user_message, chatbot):
    """Append the user's message as a pending turn and clear the textbox.

    Returns two values, one per wired output (text_input, chatbot):
    - On empty input: a textbox update with a warning placeholder and the
      unchanged history.
    - Otherwise: ``''`` and a new history list ending in
      ``[user_message, None]``. The input list is not mutated.
    """
    if not user_message:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot
    return '', chatbot + [[user_message, None]]
def gradio_answer(chatbot, sys_prompt, user_prompt, video_tensor, chat_state, num_beams, temperature, do_sample=False):
    """Generate the model's reply to the latest pending chat turn.

    Args:
        chatbot: chat history as a list of ``[user, bot]`` pairs; a pending
            turn has ``bot is None``.
        sys_prompt: system prompt forwarded to ``model.chat``.
        user_prompt: textbox value. It is usually ``''`` here, because
            gradio_ask clears the textbox before this chained handler runs.
            In that case the pending turn's text is used instead.
        video_tensor: tensor produced by load_video, or None if nothing was
            uploaded.
        chat_state: history object passed to and returned by ``model.chat``.
        num_beams, temperature, do_sample: generation settings.

    Returns:
        ``(chatbot, chat_state)``. Both are returned unchanged when there is
        no video or no pending turn.
    """
    # NOTE(review): `spaces` is imported at the top of the file but never
    # used; if this Space runs on ZeroGPU this function probably needs
    # @spaces.GPU — confirm against the deployment.
    if video_tensor is None or not chatbot or chatbot[-1][1] is not None:
        return chatbot, chat_state

    user_prompt = user_prompt or chatbot[-1][0]
    video_tensor = video_tensor.to(model.device)
    response, chat_state = model.chat(
        tokenizer,
        sys_prompt,
        user_prompt,
        media_type='video',
        media_tensor=video_tensor,
        chat_history=chat_state,
        return_history=True,
        generation_config={
            # Slider values may be floats; beam count must be an int.
            "num_beams": int(num_beams),
            "temperature": temperature,
            "do_sample": do_sample,
        },
    )
    print(response)
    chatbot[-1][1] = response
    return chatbot, chat_state
class OpenGVLab(gr.themes.base.Base):
    """Gradio theme for the demo.

    It uses a blue/sky/gray palette, Noto Sans and IBM Plex Mono fonts, and a
    page background set to the neutral_50 shade.
    """

    def __init__(
        self,
        *,
        primary_hue=colors.blue,
        secondary_hue=colors.sky,
        neutral_hue=colors.gray,
        spacing_size=sizes.spacing_md,
        radius_size=sizes.radius_sm,
        text_size=sizes.text_md,
        font=(fonts.GoogleFont("Noto Sans"), "ui-sans-serif", "sans-serif"),
        font_mono=(fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace"),
    ):
        theme_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "spacing_size": spacing_size,
            "radius_size": radius_size,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**theme_options)
        super().set(body_background_fill="*neutral_50")
# Theme instance used by the Blocks app below.
gvlabtheme = OpenGVLab(
    primary_hue=colors.blue,
    secondary_hue=colors.sky,
    neutral_hue=colors.gray,
    spacing_size=sizes.spacing_md,
    radius_size=sizes.radius_sm,
    text_size=sizes.text_md,
)

# Page header (logo linking to the Ask-Anything repository).
title = """<h1 align="center"><a href="https://github.com/OpenGVLab/Ask-Anything"><img src="https://s1.ax1x.com/2023/05/07/p9dBMOU.png" alt="Ask-Anything" border="0" style="margin: 0 auto; height: 100px;" /></a> </h1>"""

# Short description rendered under the title. The trailing unclosed <p> that
# used to end this string has been dropped.
description = """
VideoChat2 powered by InternVideo!<br><p><a href='https://github.com/OpenGVLab/Ask-Anything'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p>
"""

# System prompt stored in the sys_prompt State and forwarded to model.chat;
# empty by default.
SYS_PROMPT = ""
# Build the UI: a left column with the video upload and generation settings,
# and a right column with the chat window and input row.
with gr.Blocks(
    title="InternVideo-VideoChat!",
    theme=gvlabtheme,
    # `visibility: hidden` is the valid CSS value for hiding the footer.
    css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: hidden}",
) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        with gr.Column(scale=0.5, visible=True) as video_upload:
            with gr.Column(elem_id="image", scale=0.5) as img_part:
                up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload")
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            restart = gr.Button("Restart")
            sys_prompt = gr.State(SYS_PROMPT)
            num_beams = gr.Slider(
                minimum=1,
                maximum=10,
                value=1,
                step=1,
                interactive=True,
                label="beam search numbers",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            num_segments = gr.Slider(
                minimum=8,
                maximum=64,
                value=8,
                step=1,
                interactive=True,
                label="Input Frames",
            )
            # NOTE(review): this slider is fixed at 224 and not wired to any
            # handler; upload_img always uses resolution=224.
            resolution = gr.Slider(
                minimum=224,
                maximum=224,
                value=224,
                step=1,
                interactive=True,
                label="Vision encoder resolution",
            )
            hd_num = gr.Slider(
                minimum=1,
                maximum=10,
                value=4,
                step=1,
                interactive=True,
                label="HD num",
            )
            padding = gr.Checkbox(
                label="padding",
                info=""
            )
        with gr.Column(visible=True) as input_raws:
            chat_state = gr.State([])
            # Holds the video tensor produced by upload_img.
            img_list = gr.State()
            chatbot = gr.Chatbot(elem_id="chatbot", label='VideoChat')
            with gr.Row():
                with gr.Column(scale=0.7):
                    text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False)
                with gr.Column(scale=0.15, min_width=0):
                    run = gr.Button("💭Send")
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("🔄Clear️")

    upload_button.click(upload_img, [up_video, num_segments, hd_num, padding], [up_video, text_input, upload_button, img_list])

    # Enter and the Send button share one pipeline: gradio_ask records the
    # message and clears the textbox, then gradio_answer generates the reply.
    # (A separate handler that also cleared the textbox was redundant.)
    answer_inputs = [chatbot, sys_prompt, text_input, img_list, chat_state, num_beams, temperature]
    text_input.submit(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
        gradio_answer, answer_inputs, [chatbot, chat_state]
    )
    run.click(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
        gradio_answer, answer_inputs, [chatbot, chat_state]
    )
    clear.click(clear_, None, [chatbot, chat_state])
    restart.click(gradio_reset, [chat_state, img_list], [chatbot, up_video, text_input, upload_button, chat_state, img_list], queue=False)

demo.launch()