import os
import sys
from pathlib import Path

# Install the multimodal package and its YOLOX dependency at Space startup.
os.system("python -m pip install --upgrade pip")
os.system("cd multimodal && pip install .")
os.system("cd multimodal/YOLOX && pip install .")

import tempfile
import string

import numpy as np
import torch
import cv2
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download, login
from open_flamingo.src.factory import create_model_and_transforms
from open_flamingo.chat.conversation import ChatBOT, CONV_VISION

sys.path.append(str(Path(__file__).parent.parent.parent))

TEMP_FILE_DIR = Path(__file__).parent / 'temp'
TEMP_FILE_DIR.mkdir(parents=True, exist_ok=True)

SHARED_UI_WARNING = f'''### [NOTE] You may be waiting in a lengthy queue.
You can duplicate this Space and run it on a paid private GPU.
<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Vision-CAIR/minigpt4?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
Alternatively, you can use the demo on our [project page](https://compositionalvlm.github.io/).
'''
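
# Build the model: a CLIP ViT-L/14 vision encoder (datacomp_xl_s13b_b90k weights) paired
# with an EleutherAI Pythia-1.4b language model, extended with 1000 <loc_i> location
# tokens and extra visual/box tokens used for grounding.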
flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
    "ViT-L-14",
    "datacomp_xl_s13b_b90k",
    "EleutherAI/pythia-1.4b",
    "EleutherAI/pythia-1.4b",
    location_token_num=1000,
    lora=False,
    lora_r=16,
    use_sam=None,
    add_visual_token=True,
    use_format_v2=True,
    add_box=True,
    add_pe=False,
    add_relation=False,
    enhance_data=False,
)

model_name = "pythiaS"
checkpoint_path = hf_hub_download("chendl/compositional_test", "pythiaS.pt")
checkpoint = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
model_state_dict = {}
for key in checkpoint.keys():
    model_state_dict[key.replace("module.", "")] = checkpoint[key]
if "vision_encoder.logit_scale" in model_state_dict:
    # previous checkpoints carry some unnecessary weights
    del model_state_dict["vision_encoder.logit_scale"]
    del model_state_dict["vision_encoder.visual.proj"]
    del model_state_dict["vision_encoder.visual.ln_post.weight"]
    del model_state_dict["vision_encoder.visual.ln_post.bias"]
flamingo.load_state_dict(model_state_dict, strict=True)
chat = ChatBOT(flamingo, image_processor, tokenizer, vis_embed_size, model_name)
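

# get_outputs runs a single forward pass of the model in inference mode and returns its
# raw outputs (for the grounding task these include predicted boxes and their scores).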
def get_outputs(
    model,
    batch_images,
    attention_mask,
    max_generation_length,
    min_generation_length,
    num_beams,
    length_penalty,
    input_ids,
    image_start_index_list=None,
    image_nums=None,
    bad_words_ids=None,
):
    # and torch.cuda.amp.autocast(dtype=torch.float16)
    with torch.inference_mode():
        outputs = model(
            vision_x=batch_images,
            lang_x=input_ids,
            attention_mask=attention_mask,
            labels=None,
            image_nums=image_nums,
            image_start_index_list=image_start_index_list,
            added_bbox_list=None,
            add_box=False,
        )
    # outputs = model.generate(
    #     batch_images,
    #     input_ids,
    #     attention_mask=attention_mask,
    #     max_new_tokens=max_generation_length,
    #     min_length=min_generation_length,
    #     num_beams=num_beams,
    #     length_penalty=length_penalty,
    #     image_start_index_list=image_start_index_list,
    #     image_nums=image_nums,
    #     bad_words_ids=bad_words_ids,
    # )
    return outputs
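

# generate() drives the stateless single-turn tabs of the legacy UI: idx == 1 runs
# referring-expression grounding (returns the predicted box plus an annotated image),
# idx == 2 runs visual question answering, and any other idx returns plain generated
# text (e.g. captioning or counting).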
def generate(
    idx,
    image,
    text,
    vis_embed_size=256,
    rank=0,
    world_size=1,
):
    if image is None:
        raise gr.Error("Please upload an image.")
    flamingo.eval()
    loc_token_ids = []
    for i in range(1000):
        loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
    image_ori = image
    image = image.convert("RGB")
    width = image.width
    height = image.height
    image = image.resize((224, 224))
    batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
    if idx == 1:
        prompt = [
            f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|> {text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
        bad_words_ids = None
        max_generation_length = 5
    else:
        prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
        bad_words_ids = loc_token_ids  # keep <loc_i> tokens out of free-form generation
        max_generation_length = 30
    encodings = tokenizer(
        prompt,
        padding="longest",
        truncation=True,
        return_tensors="pt",
        max_length=2000,
    )
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]
    image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
    image_start_index_list = [[x] for x in image_start_index_list]
    image_nums = [1] * len(input_ids)
    outputs = get_outputs(
        model=flamingo,
        batch_images=batch_images,
        attention_mask=attention_mask,
        max_generation_length=max_generation_length,
        min_generation_length=4,
        num_beams=1,
        length_penalty=1.0,
        input_ids=input_ids,
        bad_words_ids=bad_words_ids,
        image_start_index_list=image_start_index_list,
        image_nums=image_nums,
    )
    boxes = outputs["boxes"]
    scores = outputs["scores"]
    box = None
    if len(scores) > 0:
        box = boxes[scores.argmax()] / 224
        print(f"{box}")
    if idx == 1:
        if box is None:
            raise gr.Error("No box was predicted for this expression.")
        open_cv_image = np.array(image_ori)
        # Convert RGB to BGR
        open_cv_image = open_cv_image[:, :, ::-1].copy()
        box = box * [width, height, width, height]
        # for box in boxes:
        open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
        out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
        return f"Output:{box}", out_image
    elif idx == 2:
        gen_text = tokenizer.batch_decode(outputs)
        return f"Question: {text.strip()} Answer: {gen_text}"
    else:
        gen_text = tokenizer.batch_decode(outputs)
        return f"Output:{gen_text}"
title = """<h1 align="center">Demo of Compositional-VLM</h1>"""

description = """<h3>This is the demo of Compositional-VLM. Upload your images and start chatting!</h3>"""
article = """<div style='display:flex; gap: 0.25rem; '><a href='https://vis-www.cs.umass.edu/CoVLM/'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://github.com/UMass-Foundation-Model/CoVLM'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://arxiv.org/abs/2311.03354'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
"""

# TODO show examples below

# ========================================
#             Gradio Setting
# ========================================
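
# gradio_reset clears the conversation state and image list and re-enables the image
# upload widgets when the user presses "Restart".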
def gradio_reset(chat_state, img_list):
    if chat_state is not None:
        chat_state = []
    if img_list is not None:
        img_list = []
    return (None,
            gr.update(value=None, interactive=True),
            gr.update(placeholder='Please upload your image first', interactive=False),
            gr.update(value="Upload & Start Chat", interactive=True),
            chat_state,
            img_list)


def build_image(image):
    if image is None:
        return None
    # res = draw_bounding_boxes(image=image, boxes=boxes_to_draw, colors=color_to_draw, width=8)
    from torchvision.transforms import ToPILImage
    # res = ToPILImage()(res)
    _, path = tempfile.mkstemp(suffix='.jpg', dir=TEMP_FILE_DIR)
    image.save(path)
    return path
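

# build_image saves a PIL image to a temporary .jpg under TEMP_FILE_DIR so it can be
# displayed inside the chatbot; upload_img registers the uploaded image with the chat
# session and locks the image/upload widgets until "Restart" is pressed.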
def upload_img(gr_img, text_input, chat_state, chatbot):
    if gr_img is None:
        # must return six values to match the outputs wired to the upload button
        return None, None, gr.update(interactive=True), chat_state, None, chatbot
    chat_state = []
    img_list = []
    path = build_image(gr_img)
    chatbot = chatbot + [[(path,), None]]
    llm_message = chat.upload_img(gr_img, chat_state, img_list)
    return (gr.update(interactive=False),
            gr.Textbox(placeholder='Type and press Enter', interactive=True),
            gr.update(value="Start Chatting", interactive=False),
            chat_state, img_list, chatbot)
def gradio_ask(user_message, chatbot, chat_state, radio):
    # if len(user_message) == 0:
    #     return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
    chat.ask(user_message, chat_state, radio)
    chatbot = chatbot + [[user_message, None]]
    return chatbot, chat_state
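

# generate_ans records the user's turn and immediately queries the chat model; if the
# answer comes with an image (e.g. a grounding visualization), it is appended to the
# chat as an extra message.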
def generate_ans(user_message, chatbot, chat_state, img_list, radio, text, num_beams, temperature):
    # if len(user_message) == 0:
    #     return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
    chat.ask(user_message, chat_state, radio)
    chatbot = chatbot + [[user_message, None]]
    image = None
    llm_message, image = \
        chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature,
                    max_length=2000, radio=radio, text_input=text)
    chatbot[-1][1] = llm_message
    if chat_state[-1]["from"] == "gpt":
        chat_state[-1]["value"] = llm_message
    if image is None:
        return "", chatbot, chat_state, img_list
    else:
        path = build_image(image)
        chatbot = chatbot + [[None, (path,)]]
        return "", chatbot, chat_state, img_list


def gradio_answer(chatbot, chat_state, img_list, radio, text, num_beams, temperature):
    image = None
    llm_message, image = \
        chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature,
                    max_length=2000, radio=radio, text_input=text)
    chatbot[-1][1] = llm_message
    if chat_state[-1]["from"] == "gpt":
        chat_state[-1]["value"] = llm_message
    if image is None:
        return "", chatbot, chat_state, img_list
    else:
        path = build_image(image)
        chatbot = chatbot + [[None, (path,)]]
        return "", chatbot, chat_state, img_list
task_template = {
    "Cap": "Summarize the content of the photo <image>.",
    "VQA": "For this image <image>, I want a simple and direct answer to my question: <question>",
    "REC": "Can you point out <expr> in the image <image> and provide the coordinates of its location?",
    "GC": "Can you give me a description of the region <boxes> in image <image>?",
    "Advanced": "<question>",
}
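
# A minimal sketch of how a template might be instantiated (fill_template is a
# hypothetical helper shown for illustration only, not part of this demo):
#
#   def fill_template(task: str, **slots: str) -> str:
#       prompt = task_template[task]
#       for name, value in slots.items():
#           prompt = prompt.replace(f"<{name}>", value)
#       return prompt
#
#   fill_template("VQA", question="What color is the car?")
#   # -> "For this image <image>, I want a simple and direct answer to my question: What color is the car?"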

with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(SHARED_UI_WARNING)
    gr.Markdown(description)
    gr.Markdown(article)

    with gr.Row():
        with gr.Column(scale=0.5):
            image = gr.Image(type="pil")
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            clear = gr.Button("Restart")
            radio = gr.Radio(
                ["Cap", "VQA", "REC", "Advanced"], label="Task Template", value='Cap',
            )
            num_beams = gr.Slider(
                minimum=1,
                maximum=5,
                value=1,
                step=1,
                interactive=True,
                label="Number of beams",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

        with gr.Column():
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(label='Compositional-VLM')
            # template = gr.Textbox(label='Template', show_label=True, lines=1, interactive=False,
            #                       value='Provide a comprehensive description of the image <image> and specify the positions of any mentioned objects in square brackets.')
            # text_input = gr.Textbox(label='<question>', show_label=True, placeholder="Please upload your image first, then input...", lines=3,
            #                         value=None, visible=False, interactive=False)
            # with gr.Row():
            text_input = gr.Textbox(label='User', placeholder='Please upload your image first, then input...',
                                    interactive=False)
            # submit_button = gr.Button(value="Submit", interactive=True, variant="primary")

    upload_button.click(upload_img, [image, text_input, chat_state, chatbot],
                        [image, text_input, upload_button, chat_state, img_list, chatbot])
    # submit_button.click(gradio_ask, [text_input, chatbot, chat_state, radio], [chatbot, chat_state]).then(
    #     gradio_answer, [chatbot, chat_state, img_list, radio, text_input, num_beams, temperature],
    #     [text_input, chatbot, chat_state, img_list]
    # )
    text_input.submit(generate_ans,
                      [text_input, chatbot, chat_state, img_list, radio, text_input, num_beams, temperature],
                      [text_input, chatbot, chat_state, img_list])
    # text_input.submit(gradio_ask, [text_input, chatbot, chat_state, radio], [chatbot, chat_state]).then(
    #     gradio_answer, [chatbot, chat_state, img_list, radio, text_input, num_beams, temperature],
    #     [text_input, chatbot, chat_state, img_list]
    # )
    clear.click(gradio_reset, [chat_state, img_list],
                [chatbot, image, text_input, upload_button, chat_state, img_list],
                queue=False)

demo.launch(share=True)
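
# ---------------------------------------------------------------------------
# The commented-out block below is an earlier single-turn, tab-based UI
# (captioning, grounding, counting, VQA and a custom prompt tab) that calls
# generate() directly instead of the chat wrapper; it appears to be kept for
# reference only.
# ---------------------------------------------------------------------------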
#
# with gr.Blocks() as demo:
#     gr.Markdown(
#         """
#         Object Centric Pretraining Demo
#         In this demo we showcase the in-context learning and grounding capabilities of the Object-Centric Pretrained model, a large multimodal model. Note that we add two additional demonstrations to the ones presented, to improve the demo experience.
#         The model is trained on an interleaved mixture of text, images and bounding boxes, and is able to generate text conditioned on sequences of images/text.
#         """
#     )
#
#     with gr.Accordion("See terms and conditions"):
#         gr.Markdown(
#             """**Please read the following information carefully before proceeding.** This demo does NOT store any personal information on its users, and it does NOT store user queries.""")
#
#     with gr.Tab("Image Captioning"):
#         with gr.Row():
#             query_image = gr.Image(type="pil")
#         with gr.Row():
#             chat_input = gr.Textbox(lines=1, label="Chat Input")
#             text_output = gr.Textbox(value="Output:", label="Model output")
#
#         run_btn = gr.Button("Run model")
#
#         def on_click_fn(img, text): return generate(0, img, text)
#
#         run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output])
#
#     with gr.Tab("Grounding"):
#         with gr.Row():
#             with gr.Column(scale=1):
#                 query_image = gr.Image(type="pil")
#             with gr.Column(scale=1):
#                 out_image = gr.Image(type="pil")
#         with gr.Row():
#             chat_input = gr.Textbox(lines=1, label="Chat Input")
#             text_output = gr.Textbox(value="Output:", label="Model output")
#
#         run_btn = gr.Button("Run model")
#
#         def on_click_fn(img, text): return generate(1, img, text)
#
#         run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output, out_image])
#
#     with gr.Tab("Counting objects"):
#         with gr.Row():
#             query_image = gr.Image(type="pil")
#         with gr.Row():
#             chat_input = gr.Textbox(lines=1, label="Chat Input")
#             text_output = gr.Textbox(value="Output:", label="Model output")
#
#         run_btn = gr.Button("Run model")
#
#         def on_click_fn(img, text): return generate(0, img, text)
#
#         run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output])
#
#     with gr.Tab("Visual Question Answering"):
#         with gr.Row():
#             query_image = gr.Image(type="pil")
#         with gr.Row():
#             question = gr.Textbox(lines=1, label="Question")
#             text_output = gr.Textbox(value="Output:", label="Model output")
#
#         run_btn = gr.Button("Run model")
#
#         def on_click_fn(img, txt): return generate(2, img, txt)
#
#         run_btn.click(
#             on_click_fn, inputs=[query_image, question], outputs=[text_output]
#         )
#
#     with gr.Tab("Custom"):
#         gr.Markdown(
#             """### Customize the demonstration by uploading your own images and text samples.
#             ### **Note: Any text prompt you use will be prepended with an 'Output:', so you don't need to include it in your prompt.**"""
#         )
#         with gr.Row():
#             query_image = gr.Image(type="pil")
#         with gr.Row():
#             question = gr.Textbox(lines=1, label="Question")
#             text_output = gr.Textbox(value="Output:", label="Model output")
#
#         run_btn = gr.Button("Run model")
#
#         def on_click_fn(img, txt): return generate(2, img, txt)
#
#         run_btn.click(
#             on_click_fn, inputs=[query_image, question], outputs=[text_output]
#         )
#
# demo.queue(concurrency_count=1)
# demo.launch()