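"""Gradio demo: visualize per-BPE-token similarity maps between input text and an image for TokenFD models."""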
import os
import argparse
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as T
from transformers import AutoTokenizer
import gradio as gr
from resnet50 import build_model
from utils import generate_similiarity_map, get_transform, post_process, load_tokenizer, build_transform_R50
from utils import IMAGENET_MEAN, IMAGENET_STD
from internvl.train.dataset import dynamic_preprocess
from internvl.model.internvl_chat import InternVLChatModel
import spaces
# Model configuration
CHECKPOINTS = {
    "TokenFD_4096_English_seg": "TongkunGuan/TokenFD_4096_English_seg",
    "TokenFD_2048_Bilingual_seg": "TongkunGuan/TokenFD_2048_Bilingual_seg",
}
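# NOTE: load_model() below also indexes CHECKPOINTS['R50'] and
# CHECKPOINTS['R50_siglip'], which are not defined here. To use those branches,
# add local checkpoint paths, e.g. (hypothetical paths):
#   CHECKPOINTS["R50"] = "checkpoints/R50.pth"
#   CHECKPOINTS["R50_siglip"] = "checkpoints/R50_siglip.pth"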
# Global variables
HF_TOKEN = os.getenv("HF_TOKEN")
def load_model(check_type):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if check_type == 'R50':
        tokenizer = load_tokenizer('tokenizer_path')
        model = build_model(argparse.Namespace()).eval()
        model.load_state_dict(torch.load(CHECKPOINTS['R50'], map_location='cpu')['model'])
        transform = build_transform_R50(normalize_type='imagenet')
    elif check_type == 'R50_siglip':
        tokenizer = load_tokenizer('tokenizer_path')
        model = build_model(argparse.Namespace()).eval()
        model.load_state_dict(torch.load(CHECKPOINTS['R50_siglip'], map_location='cpu')['model'])
        transform = build_transform_R50(normalize_type='imagenet')
    elif 'TokenFD' in check_type:
        model_path = CHECKPOINTS[check_type]
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False, use_auth_token=HF_TOKEN)
        model = InternVLChatModel.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, load_in_8bit=False, load_in_4bit=False).eval()
        transform = get_transform(is_train=False, image_size=model.config.force_image_size)
    return model.to(device), tokenizer, transform, device
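# Example (loads weights, so it needs the GPU/CPU selected above):
#   model, tokenizer, transform, device = load_model("TokenFD_2048_Bilingual_seg")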
def process_image(model, tokenizer, transform, device, check_type, image, text):
    src_size = image.size
    if 'TokenFD' in check_type:
        images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
                                                  image_size=model.config.force_image_size,
                                                  use_thumbnail=model.config.use_thumbnail,
                                                  return_ratio=True)
        pixel_values = torch.stack([transform(img) for img in images]).to(device)
    else:
        images = [image]  # keep a list so the visualization call below has tiles to draw on
        pixel_values = torch.stack([transform(image)]).to(device)
        target_ratio = (1, 1)

    # Text processing
    text_input = text
    if text_input[0] in '!"#$%&\'()*+,-./0123456789:;<=>?@^_{|}~':
        input_ids = tokenizer(text_input)['input_ids'][1:]
    else:
        input_ids = tokenizer(' ' + text_input)['input_ids'][1:]
    input_ids = torch.tensor(input_ids, device=device)
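    # The [1:] slice drops the BOS token the tokenizer prepends; prefixing a
    # space for non-punctuation input yields the word-initial token variant in
    # SentencePiece-style vocabularies (an assumption about this tokenizer).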
    # Get embeddings
    with torch.no_grad():
        if 'R50' in check_type:
            text_embeds = model.language_embedding(input_ids)
        else:
            text_embeds = model.tok_embeddings(input_ids).clone()
        vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device))
        print("vit_embeds", vit_embeds)
        print("vit_embeds.shape", vit_embeds.shape)
        print("target_ratio", target_ratio)
        print("check_type", check_type)
        vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)
    # Compute cosine similarity between every text token and every visual token
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
    vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
    similarity = text_embeds @ vit_embeds.T
    resized_size = size1 if size1 is not None else size2
    # Example shapes: text_embeds [4, 2048], vit_embeds [9728, 2048], similarity [4, 9728]

    # Generate visualization: reshape each token's similarity row to the ViT feature grid
    attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
    current_bpe = [tokenizer.decode([i]) for i in input_ids]
    current_vis = generate_similiarity_map(images, attn_map,
                                           current_bpe,
                                           [], target_ratio, src_size)
    return image, current_vis, current_bpe
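# process_image returns the original PIL image, a list of per-token heatmaps,
# and the matching list of decoded BPE strings; the UI below pages through the
# two lists with a shared index.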
# Event handler functions
# Previous / Next buttons
def update_index(direction, current_vis, current_bpe, current_index):
    # Clamp the new index to the valid range
    new_index = max(0, min(current_index + direction, len(current_vis) - 1))
    # Update the displayed visualization
    return (
        current_vis[new_index],
        format_bpe_display(current_bpe[new_index]),
        new_index  # updated index
    )

def format_bpe_display(bpe):
    # Use HTML to set font size and color, bold the label, and center it
    return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
# Gradio UI
with gr.Blocks(title="BPE Visualization Demo") as demo:
    gr.Markdown("## BPE Visualization Demo - Visualizing TokenFD base model capabilities")
    with gr.Row():
        with gr.Column(scale=1):
            model_type = gr.Dropdown(
                choices=["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg", "R50", "R50_siglip"],
                label="Select model type",
                value="TokenFD_4096_English_seg"  # default to the first option
            )
            image_input = gr.Image(label="Upload image", type="pil")
            text_input = gr.Textbox(label="Input text")
            run_btn = gr.Button("RUN")
            gr.Examples(
                examples=[
                    [os.path.join("examples", "examples0.jpg"), "Veterans and Benefits"],
                    [os.path.join("examples", "examples1.jpg"), "Refreshers"],
                    [os.path.join("examples", "examples2.png"), "Vision Transformer"]
                ],
                inputs=[image_input, text_input],
                label="Sample input"
            )
        with gr.Column(scale=4):
            gr.Markdown("<p style='font-size:20px;'><span style='color:red;'>If the input text is not included in the image</span>, the attention map will show a lot of noise (the actual response value is very low), since we normalize the attention map according to the relative value.</p>")
            with gr.Row():
                orig_img = gr.Image(label="Original picture", interactive=False)
                heatmap = gr.Image(label="BPE visualization", interactive=False)
            with gr.Row() as controls:
                prev_btn = gr.Button("⬅ Previous", visible=False)
                next_btn = gr.Button("⮕ Next", visible=False)
            bpe_display = gr.Markdown("Current BPE: ")

    current_vis_state = gr.State([])
    current_bpe_state = gr.State([])
    current_index_state = gr.State(0)

    # Event handling
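    # Running on ZeroGPU hardware is an assumption here: `spaces` is imported
    # above but never used, and `spaces.GPU` is the decorator a Space uses to
    # request a GPU for the duration of a call (a likely cause of the runtime
    # error if omitted). Drop the decorator on a dedicated-GPU Space.
    @spaces.GPU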
    def on_run_clicked(model_type, image, text):
        current_index = 0  # reset index whenever a new image is processed
        image, current_vis, current_bpe = process_image(*load_model(model_type), model_type, image, text)
        print("current_vis", len(current_vis))
        print("current_bpe", len(current_bpe))
        return (
            image,
            current_vis[current_index],
            format_bpe_display(current_bpe[current_index]),
            gr.update(visible=True),
            gr.update(visible=True),
            current_vis,    # store the full list
            current_bpe,    # store the full list
            current_index   # store the current index
        )
    run_btn.click(
        on_run_clicked,
        inputs=[model_type, image_input, text_input],
        outputs=[orig_img, heatmap, bpe_display, prev_btn, next_btn, current_vis_state, current_bpe_state, current_index_state]
    )
    prev_btn.click(
        update_index,
        inputs=[gr.State(-1), current_vis_state, current_bpe_state, current_index_state],
        outputs=[heatmap, bpe_display, current_index_state]
    )
    next_btn.click(
        update_index,
        inputs=[gr.State(1), current_vis_state, current_bpe_state, current_index_state],
        outputs=[heatmap, bpe_display, current_index_state]
    )
if __name__ == "__main__":
    demo.launch()