Spaces: Running on Zero
| from threading import Thread | |
| from typing import Dict | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, TextIteratorStreamer | |
# Page heading shown at the top of the demo (passed to gr.HTML in the UI block).
TITLE = "<h1><center>Chat with PaliGemma-3B-Chat-v0.2</center></h1>"
# Sub-heading linking to the model card on the Hugging Face Hub (also rendered with gr.HTML).
DESCRIPTION = "<h3><center>Visit <a href='https://huggingface.co/BUAADreamer/PaliGemma-3B-Chat-v0.2' target='_blank'>our model page</a> for details.</center></h3>"
# Extra stylesheet handed to gr.Blocks(css=...). The `.duplicate-button` rule targets the
# DuplicateButton, which is given elem_classes="duplicate-button" in the UI block.
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
"""
# Hugging Face Hub repo that supplies the tokenizer, the image processor and the weights.
model_id = "BUAADreamer/PaliGemma-3B-Chat-v0.2"

# Loaded at import time so every chat request reuses the same objects.
# The tuple is evaluated left to right: tokenizer first, then processor.
tokenizer, processor = (
    AutoTokenizer.from_pretrained(model_id),
    AutoProcessor.from_pretrained(model_id),
)

# torch_dtype="auto" / device_map="auto" defer dtype and device placement to transformers
# (presumably the checkpoint's stored dtype and whatever accelerators are visible — verify on target hardware).
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
)
def _split_image_and_turns(message: dict, history: list):
    """Pick the image for this request and return the text-only chat turns.

    Gradio's multimodal ChatInterface records each uploaded file as its own history
    entry whose user side is a tuple of file paths (assistant side ``None``). Those
    entries are not text, so they are removed from the returned turns wherever they
    occur, not just at index 0. The image attached to the current message wins;
    otherwise the most recently uploaded image in ``history`` is used.

    Returns:
        ``(image_path_or_None, [(user_text, assistant_text), ...])``
    """
    image_path = None
    turns = []
    for prompt, answer in history:
        if isinstance(prompt, tuple):
            # File-upload entry: keep the latest path seen, never send it as chat text.
            if prompt:
                image_path = prompt[0]
            continue
        turns.append((prompt, answer))
    if message["files"]:
        # A freshly attached file takes precedence over anything uploaded earlier.
        image_path = message["files"][0]
    return image_path, turns


# On ZeroGPU hardware a GPU is attached only while a @spaces.GPU-decorated function runs.
@spaces.GPU
def stream_chat(message: dict, history: list):
    """Stream a PaliGemma reply for Gradio's multimodal ChatInterface.

    Args:
        message: the current user turn, ``{"text": str, "files": [path, ...]}``.
        history: earlier turns as ``[user, assistant]`` pairs; uploads appear as
            ``[(path,), None]`` entries. Example:

            Turn 1: message={'text': 'what is this', 'files': ['image-xxx.jpg']}, history=[]
            Turn 2: message={'text': 'continue?', 'files': []},
                    history=[[('image-xxx.jpg',), None], ['what is this', 'a image.']]

    Yields:
        The reply generated so far; each yield is the full text accumulated up to that point.
    """
    image_path, turns = _split_image_and_turns(message, history)

    if image_path is not None:
        # `with` closes the file handle that Image.open keeps for lazy loading.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
    else:
        # The prompt below always carries an image prefix, so supply a blank white placeholder.
        image = Image.new("RGB", (100, 100), (255, 255, 255))
    pixel_values = processor(images=[image], return_tensors="pt").to(model.device)["pixel_values"]

    conversation = []
    for prompt, answer in turns:
        conversation.append({"role": "user", "content": prompt})
        conversation.append({"role": "assistant", "content": answer})
    conversation.append({"role": "user", "content": message["text"]})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")

    # Prepend one <image> placeholder token per image embedding the processor produces.
    image_token_id = tokenizer.convert_tokens_to_ids("<image>")
    image_prefix = torch.full((1, processor.image_seq_length), image_token_id, dtype=input_ids.dtype)
    input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        pixel_values=pixel_values,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
    )
    # generate() blocks, so it runs in a worker thread while this generator drains the streamer.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    output = ""
    for new_text in streamer:
        output += new_text
        yield output
    # The streamer is exhausted once generate() finishes; reap the worker thread.
    thread.join()
# Chat display widget. It is created before the Blocks context is opened and handed
# to ChatInterface, presumably so ChatInterface (not the outer Blocks) places it.
chatbot = gr.Chatbot(height=450)

with gr.Blocks(css=CSS) as demo:
    # Header: title, then description.
    for header_html in (TITLE, DESCRIPTION):
        gr.HTML(header_html)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_classes="duplicate-button",
    )
    # Multimodal chat (text + file upload) backed by the streaming handler.
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        multimodal=True,
        cache_examples=False,
        fill_height=True,
    )

if __name__ == "__main__":
    demo.launch()