import gradio as gr
import torch
from transformers.generation.streamers import TextIteratorStreamer
from PIL import Image
import requests
from io import BytesIO
from threading import Thread
import os

# Import LLaVA modules
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token

# Point the Hugging Face cache at a local directory before any downloads
os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(os.getcwd(), "weights")

# Load LLaVA-1.5-13B
disable_torch_init()
model_id = "Yanqing0327/LLaVA-project"  # replace with your Hugging Face model repo
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_id, model_name="llava-v1.5-13b", model_base=None, load_8bit=False, load_4bit=False
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)


def load_image(image_file):
    """Return `image_file` as an RGB `PIL.Image`."""
    if isinstance(image_file, Image.Image):
        return image_file.convert("RGB")  # already a `PIL.Image`
    elif isinstance(image_file, str) and image_file.startswith(("http://", "https://")):
        response = requests.get(image_file)
        return Image.open(BytesIO(response.content)).convert("RGB")
    else:  # `image_file` is a local file path
        return Image.open(image_file).convert("RGB")
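# load_image accepts all three input shapes the UI can produce; for example
# (hypothetical inputs, shown for illustration only):
#   load_image("https://example.com/cat.jpg")  # remote URL -> downloaded and decoded
#   load_image("samples/cat.png")              # local path -> opened from disk
#   load_image(Image.new("RGB", (224, 224)))   # PIL image -> passed through as RGB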
def llava_infer(image, text, temperature, top_p, max_tokens):
    """Run LLaVA inference on an image/text pair."""
    if image is None or text.strip() == "":
        return "Please provide both an image and a text prompt."

    # Preprocess the image
    image_data = load_image(image)
    image_tensor = image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].half().to(device)

    # Build the conversation prompt
    conv_mode = "llava_v1"
    conv = conv_templates[conv_mode].copy()

    # Prepend the image token to the user text
    inp = DEFAULT_IMAGE_TOKEN + '\n' + text
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=20.0)

    # Run generation in a background thread so we can consume the streamer.
    # torch.inference_mode() is thread-local, so it must be entered inside the
    # worker thread, not around the Thread() constructor.
    def generate():
        with torch.inference_mode():
            model.generate(
                inputs=input_ids,
                images=image_tensor,
                do_sample=temperature > 0,  # sampling with temperature=0 is invalid
                temperature=temperature,
                top_p=top_p,
                max_new_tokens=max_tokens,
                streamer=streamer,
                use_cache=True,
            )

    thread = Thread(target=generate)
    thread.start()

    # Accumulate streamed chunks, stripping the stop string and restoring the
    # single leading space the tokenizer emits between chunks.
    response = ""
    prepend_space = False
    for new_text in streamer:
        if new_text == " ":
            prepend_space = True
            continue
        if new_text.endswith(stop_str):
            new_text = new_text[:-len(stop_str)].strip()
            prepend_space = False
        elif prepend_space:
            new_text = " " + new_text
            prepend_space = False
        response += new_text
    if prepend_space:
        response += " "
    thread.join()
    return response
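# A minimal smoke test for llava_infer (a sketch, not part of the original app).
# It only runs when LLAVA_SMOKE_TEST is set, so launching the UI is unaffected;
# the sample image URL is an assumption borrowed from LLaVA's own demo material.
if os.environ.get("LLAVA_SMOKE_TEST"):
    sample = load_image("https://llava-vl.github.io/static/images/view.jpg")
    print(llava_infer(sample, "Describe this image.", temperature=0.2, top_p=1.0, max_tokens=128))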
# Build the Gradio web UI
with gr.Blocks(title="LLaVA 1.5-13B Web UI") as demo:
    gr.Markdown("# 🌋 LLaVA-1.5-13B Web Interface")
    gr.Markdown("Upload an image and enter a prompt; LLaVA will return an answer.")

    with gr.Row():
        with gr.Column(scale=3):
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(placeholder="Enter text...", label="Prompt")
            temperature = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
            top_p = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Top P")
            max_tokens = gr.Slider(10, 1024, value=512, step=10, label="Max Tokens")
            submit_button = gr.Button("Submit")
        with gr.Column(scale=7):
            chatbot_output = gr.Textbox(label="LLaVA Output", interactive=False)

    submit_button.click(fn=llava_infer, inputs=[image_input, text_input, temperature, top_p, max_tokens], outputs=chatbot_output)

# Launch the Gradio web UI
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
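# Usage note: run the app with `python app.py` (assuming the LLaVA repo and its
# dependencies are installed). The server listens on port 7860; with share=True,
# Gradio additionally requests a temporary public *.gradio.live URL.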