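# Gradio demo for the Object-Centric Pretrained grounding model (OpenFlamingo-based,
# with a Pythia-1.4B language backbone). The script installs the bundled `multimodal`
# package, loads the `chendl/compositional_test` checkpoint from the Hugging Face Hub,
# and exposes captioning, grounding, counting, and visual question answering tabs.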
import os

# Install the bundled multimodal package that provides open_flamingo.
os.system("cd multimodal && pip install .")

import base64
import string
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from huggingface_hub import hf_hub_download, login
from open_flamingo.src.factory import create_model_and_transforms
flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
    "ViT-L-14",
    "datacomp_xl_s13b_b90k",
    "EleutherAI/pythia-1.4b",
    "EleutherAI/pythia-1.4b",
    add_visual_grounding=True,
    location_token_num=1000,
    add_visual_token=True,
    use_format_v2=True,
)
checkpoint_path = hf_hub_download("chendl/compositional_test", "pythiaS.pt")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model_state_dict = {}
for key in checkpoint.keys():
    model_state_dict[key.replace("module.", "")] = checkpoint[key]
if "vision_encoder.logit_scale" in model_state_dict:
    # previous checkpoint has some unnecessary weights
    del model_state_dict["vision_encoder.logit_scale"]
    del model_state_dict["vision_encoder.visual.proj"]
    del model_state_dict["vision_encoder.visual.ln_post.weight"]
    del model_state_dict["vision_encoder.visual.ln_post.bias"]
flamingo.load_state_dict(model_state_dict, strict=True)
def get_outputs(
    model,
    batch_images,
    attention_mask,
    max_generation_length,
    min_generation_length,
    num_beams,
    length_penalty,
    input_ids,
    image_start_index_list=None,
    image_nums=None,
    bad_words_ids=None,
):
    # and torch.cuda.amp.autocast(dtype=torch.float16)
    with torch.inference_mode():
        outputs = model.generate(
            batch_images,
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_generation_length,
            min_length=min_generation_length,
            num_beams=num_beams,
            length_penalty=length_penalty,
            image_start_index_list=image_start_index_list,
            image_nums=image_nums,
            bad_words_ids=bad_words_ids,
        )
    return outputs
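# `get_iou` is used below but is neither defined nor imported in this file. The helper
# here is a minimal sketch assuming boxes are [x1, y1, x2, y2] arrays in pixel
# coordinates; substitute the project's own implementation if it differs.
def get_iou(box1, box2):
    # Intersection rectangle.
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    # Union = sum of the two areas minus the intersection.
    area1 = max(0.0, box1[2] - box1[0]) * max(0.0, box1[3] - box1[1])
    area2 = max(0.0, box2[2] - box2[0]) * max(0.0, box2[3] - box2[1])
    union = area1 + area2 - inter
    return inter / union if union > 0 else 0.0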
def evaluate_refcoco(
    model,
    tokenizer,
    image_processor,
    batch_size,
    tsvfile,
    max_generation_length=20,
    num_beams=3,
    length_penalty=-2.0,
    device=-1,
    vis_embed_size=None,
    rank=0,
    world_size=1,
    id=0,
):
    # Evaluate grounding accuracy on a RefCOCO-style TSV file, where each line is
    # "uniq_id \t image_id \t referring expression \t x1,y1,x2,y2 \t base64 image".
    model.eval().cuda()
    loc_token_ids = []
    for i in range(1000):
        loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
    # all_ids = set(range(model.lang_encoder.lm_head.out_features))
    # bad_words_ids = list(all_ids - set(loc_token_ids))
    # bad_words_ids = [[b] for b in bad_words_ids]
    # min_loc_token_id = min(loc_token_ids)
    # max_loc_token_id = max(loc_token_ids)
    total = 0
    correct = 0
    ious = []
    if "refcocog" in tsvfile:
        dataset_name = "refcocog"
    elif "refcocoplus" in tsvfile:
        dataset_name = "refcocoplus"
    else:
        dataset_name = "refcoco"
    with open(tsvfile, "r") as f:
        lines = f.readlines()
    pbar = tqdm(lines, disable=(rank != 0))
    for ii, line in enumerate(pbar):
        if ii % world_size != rank:
            continue
        total += 1
        line = line.rstrip()
        uniq_id, image_id, text, region_coord, image = line.split("\t")
        image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
        # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
        # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/cat.png").convert("RGB")
        # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/262148000.png")
        gt_box = np.array(list(map(float, region_coord.split(","))))
        width = image.width
        height = image.height
        image = image.resize((224, 224))
        gt_box = gt_box / np.array([width, height, width, height]) * 224
        batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
        prompt = [
            f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|>{text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
        # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>the cat<|#visual#|>"]
        # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"]
        # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>a man<|#visual#|> is doing a trick on a skateboard<|#visual#|>"]
        encodings = tokenizer(
            prompt,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=2000,
        )
        input_ids = encodings["input_ids"]
        attention_mask = encodings["attention_mask"]
        # attention_mask[input_ids == prebox_token_id] = 0
        image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
        image_start_index_list = [[x] for x in image_start_index_list]
        image_nums = [1] * len(input_ids)
        vision_x = batch_images.cuda()
        lang_x = input_ids.cuda()
        attention_mask = attention_mask.cuda()
        model.debug_id = 0
        # `with A and B:` only enters the second context manager; combine them in one
        # `with` statement so both inference_mode and autocast are active.
        with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
            outputs = model(
                vision_x=vision_x,
                lang_x=lang_x,
                attention_mask=attention_mask,
                labels=None,
                image_nums=image_nums,
                image_start_index_list=image_start_index_list,
                added_bbox_list=None,
                add_box=False,
            )
| boxes = outputs["boxes"] | |
| scores = outputs["scores"] | |
| if len(scores) > 0: | |
| box = boxes[scores.argmax()] | |
| iou = get_iou(box, gt_box) | |
| else: | |
| iou = 0.0 | |
| # tqdm.write(f"output: {tokenizer.batch_decode(outputs)}") | |
| tqdm.write(f"no output for: {uniq_id}, {image_id}, {text}") | |
| if iou >= 0.5: | |
| correct += 1 | |
| pbar.set_description(f"iou: {iou:.2f} score: {correct / total:.4f}") | |
| # open_cv_image = np.array(image) | |
| # # Convert RGB to BGR | |
| # open_cv_image = open_cv_image[:, :, ::-1].copy() | |
| # for box, score in zip(boxes, scores): | |
| # open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2) | |
| # cv2.imwrite("output.jpg", open_cv_image) | |
| # print(boxes) | |
| # print(scores) | |
| # exit() | |
def generate(
    idx,
    image,
    text,
    vis_embed_size=256,
    rank=0,
    world_size=1,
):
    # idx selects the demo mode wired up below: 0 = captioning/counting,
    # 1 = grounding, 2 = visual question answering. The Gradio callbacks pass only
    # (idx, image, text), so the vis_embed_size default is assumed to match the
    # value returned by create_model_and_transforms above.
    if image is None:
        raise gr.Error("Please upload an image.")
    flamingo.eval()
    loc_token_ids = []
    for i in range(1000):
        loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
    image_ori = image
    image = image.convert("RGB")
    width = image.width
    height = image.height
    image = image.resize((224, 224))
    batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
    if idx == 1:
        # Grounding prompt: ends with <|#visual#|> so the model predicts a box.
        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|>{text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
        bad_words_ids = None
        max_generation_length = 5
    else:
        prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
        # Ban location tokens in free-form text generation. The original referenced an
        # undefined `loc_word_ids`; use loc_token_ids in the [[id], ...] format that
        # bad_words_ids expects.
        bad_words_ids = [[b] for b in loc_token_ids]
        max_generation_length = 30
    encodings = tokenizer(
        prompt,
        padding="longest",
        truncation=True,
        return_tensors="pt",
        max_length=2000,
    )
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]
    image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
    image_start_index_list = [[x] for x in image_start_index_list]
    image_nums = [1] * len(input_ids)
    outputs = get_outputs(
        model=flamingo,
        batch_images=batch_images,
        attention_mask=attention_mask,
        max_generation_length=max_generation_length,
        min_generation_length=4,
        num_beams=1,
        length_penalty=1.0,
        input_ids=input_ids,
        bad_words_ids=bad_words_ids,
        image_start_index_list=image_start_index_list,
        image_nums=image_nums,
    )
| boxes = outputs["boxes"] | |
| scores = outputs["scores"] | |
| if len(scores) > 0: | |
| box = boxes[scores.argmax()] | |
| iou = get_iou(box, gt_box) | |
| else: | |
| iou = 0.0 | |
| # tqdm.write(f"output: {tokenizer.batch_decode(outputs)}") | |
| tqdm.write(f"no output for: {uniq_id}, {image_id}, {text}") | |
| if iou >= 0.5: | |
| correct += 1 | |
| gen_text = tokenizer.batch_decode(outputs) | |
| if idx == 1: | |
| return f"Output:{gen_text}", out_image | |
| elif idx == 2: | |
| return (f"Question: {text.strip()} Answer: {gen_text}") | |
| else: | |
| return (f"Output:{gen_text}") | |
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Object Centric Pretraining Demo
In this demo we showcase the in-context learning and grounding capabilities of the Object-Centric Pretrained model, a large multimodal model. Note that we add two additional demonstrations to the ones presented to improve the demo experience.
The model is trained on an interleaved mixture of text, images and bounding boxes, and is able to generate text conditioned on sequences of images/text.
"""
    )
    with gr.Accordion("See terms and conditions"):
        gr.Markdown(
            """**Please read the following information carefully before proceeding.** This demo does NOT store any personal information on its users, and it does NOT store user queries.""")
| with gr.Tab("π· Image Captioning"): | |
| with gr.Row(): | |
| query_image = gr.Image(type="pil") | |
| with gr.Row(): | |
| chat_input = gr.Textbox(lines=1, label="Chat Input") | |
| text_output = gr.Textbox(value="Output:", label="Model output") | |
| run_btn = gr.Button("Run model") | |
| def on_click_fn(img,text): return generate(0, img, text) | |
| run_btn.click(on_click_fn, inputs=[query_image,chat_input], outputs=[text_output]) | |
| with gr.Tab("π¦ Grounding"): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| query_image = gr.Image(type="pil") | |
| with gr.Column(scale=1): | |
| out_image = gr.Image(type="pil") | |
| with gr.Row(): | |
| chat_input = gr.Textbox(lines=1, label="Chat Input") | |
| text_output = gr.Textbox(value="Output:", label="Model output") | |
| run_btn = gr.Button("Run model") | |
| def on_click_fn(img, text): return generate(1, img, text) | |
| run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output, out_image]) | |
| with gr.Tab("π’ Counting objects"): | |
| with gr.Row(): | |
| query_image = gr.Image(type="pil") | |
| with gr.Row(): | |
| chat_input = gr.Textbox(lines=1, label="Chat Input") | |
| text_output = gr.Textbox(value="Output:", label="Model output") | |
| run_btn = gr.Button("Run model") | |
| def on_click_fn(img,text): return generate(0, img, text) | |
| run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output]) | |
| with gr.Tab("π΅οΈ Visual Question Answering"): | |
| with gr.Row(): | |
| query_image = gr.Image(type="pil") | |
| with gr.Row(): | |
| question = gr.Textbox(lines=1, label="Question") | |
| text_output = gr.Textbox(value="Output:", label="Model output") | |
| run_btn = gr.Button("Run model") | |
| def on_click_fn(img, txt): return generate(2, img, txt) | |
| run_btn.click( | |
| on_click_fn, inputs=[query_image, question], outputs=[text_output] | |
| ) | |
| with gr.Tab("π Custom"): | |
| gr.Markdown( | |
| """### Customize the demonstration by uploading your own images and text samples. | |
| ### **Note: Any text prompt you use will be prepended with an 'Output:', so you don't need to include it in your prompt.**""" | |
| ) | |
| with gr.Row(): | |
| query_image = gr.Image(type="pil") | |
| with gr.Row(): | |
| question = gr.Textbox(lines=1, label="Question") | |
| text_output = gr.Textbox(value="Output:", label="Model output") | |
| run_btn = gr.Button("Run model") | |
| def on_click_fn(img, txt): return generate(2, img, txt) | |
| run_btn.click( | |
| on_click_fn, inputs=[query_image, question], outputs=[text_output] | |
| ) | |
demo.queue(concurrency_count=1)
demo.launch()