import gradio as gr
import cv2
from PIL import Image
from pathlib import Path
from threading import Thread
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
import spaces
TITLE = "google/gemma-3-270m-it"
DESCRIPTION = """
It's so small
"""
IS_RTL = False
TEXT_ALIGN = "left"
# Model config: load the instruction-tuned Gemma 3 270M checkpoint
model_name = "google/gemma-3-270m-it"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="eager",
).eval()
processor = AutoProcessor.from_pretrained(model_name)
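# Note (assumption): gemma-3-270m-it is a text-only checkpoint, so AutoProcessor
# should fall back to the model's tokenizer here, which is also what
# TextIteratorStreamer below expects. The image/video handling in this file is
# scaffolding for larger multimodal Gemma 3 variants; with this checkpoint,
# image and video inputs are effectively ignored, as the textbox placeholder warns.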
# TODO: add frame timestamps later
def extract_video_frames(video_path, num_frames=8):
    """Sample num_frames frames evenly across a video and return them as PIL images."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    step = max(total_frames // num_frames, 1)
    for i in range(num_frames):
        cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes to BGR; convert to RGB before wrapping in PIL
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame))
    cap.release()
    return frames
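# Usage sketch (hypothetical path): sampling a clip yields up to 8 RGB frames.
#   frames = extract_video_frames("demo.mp4")  # -> list of PIL.Image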
def format_message(content, files):
    """Build a multimodal content list, honoring inline <image> placeholders."""
    message_content = []
    if content:
        # Split on <image> tokens so uploaded files can be interleaved with text
        parts = content.split('<image>')
        for i, part in enumerate(parts):
            if part.strip():
                message_content.append({"type": "text", "text": part.strip()})
            if i < len(parts) - 1 and files:
                img = Image.open(files.pop(0))
                message_content.append({"type": "image", "image": img})
    # Any files not consumed by placeholders are appended after the text
    for file in files:
        file_path = file if isinstance(file, str) else file.name
        if Path(file_path).suffix.lower() in ['.jpg', '.jpeg', '.png']:
            img = Image.open(file_path)
            message_content.append({"type": "image", "image": img})
        elif Path(file_path).suffix.lower() in ['.mp4', '.mov']:
            # Videos are flattened into a sequence of sampled frames
            frames = extract_video_frames(file_path)
            for frame in frames:
                message_content.append({"type": "image", "image": frame})
    return message_content
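# Example: "Compare <image> with <image>" plus two uploads yields
# [text, image, text, image] entries in that order; uploads beyond the
# placeholders are appended at the end.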
def format_conversation_history(chat_history):
    """Regroup Gradio's flat message list into alternating user/assistant turns."""
    messages = []
    current_user_content = []
    for item in chat_history:
        role = item["role"]
        content = item["content"]
        if role == "user":
            # Consecutive user entries (e.g. text plus files) merge into one turn
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            elif isinstance(content, list):
                current_user_content.extend(content)
            else:
                current_user_content.append({"type": "text", "text": str(content)})
        elif role == "assistant":
            # Flush the accumulated user content before the assistant reply
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(content)}]})
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})
    return messages
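# Example: [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]
# becomes [{"role": "user", "content": [{"type": "text", "text": "hi"}]},
#          {"role": "assistant", "content": [{"type": "text", "text": "hello"}]}]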
@spaces.GPU  # ZeroGPU: request a GPU for the duration of this call
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    # Gradio's multimodal textbox sends {"text": ..., "files": [...]}
    if isinstance(input_data, dict) and "text" in input_data:
        text = input_data["text"]
        files = input_data.get("files", [])
    else:
        text = str(input_data)
        files = []
    new_message_content = format_message(text, files)
    new_message = {"role": "user", "content": new_message_content}
    system_message = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}] if system_prompt else []
    processed_history = format_conversation_history(chat_history)
    messages = system_message + processed_history
    # Merge into the trailing user turn if the history already ends with one
    if messages and messages[-1]["role"] == "user":
        messages[-1]["content"].extend(new_message["content"])
    else:
        messages.append(new_message)
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    # Stream tokens from a background thread so the UI updates incrementally
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    outputs = []
    for chunk in streamer:  # yield the partial response as it grows
        outputs.append(chunk)
        yield "".join(outputs)
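# The Thread + TextIteratorStreamer pairing is the usual transformers recipe for
# token streaming: model.generate() blocks, so it runs off the main thread while
# this generator drains the streamer and re-yields the growing string to Gradio.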
chat_interface = gr.ChatInterface(
    fn=generate_response,
    chatbot=gr.Chatbot(rtl=IS_RTL, show_copy_button=True, type="messages"),
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
        gr.Textbox(
            label="System Prompt",
            value="You are a very helpful multimodal assistant",
            lines=4,
            placeholder="Change the settings",
            text_align=TEXT_ALIGN,
            rtl=IS_RTL,
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.2),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.4),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=30),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1),
    ],
    examples=[
        [{"text": "Write a poem which describes potatoes"}],
    ],
    textbox=gr.MultimodalTextbox(
        rtl=IS_RTL,
        label="Input",
        file_types=["image", "video"],
        file_count="multiple",
        placeholder="Type text; any image or video will be ignored",
    ),
    cache_examples=False,
    type="messages",
    fill_height=True,
    stop_btn="Stop",
    css_paths=["style.css"],
    multimodal=True,
    title=TITLE,
    description=DESCRIPTION,
    theme=gr.themes.Soft(),
)
if __name__ == "__main__":
    chat_interface.queue(max_size=20).launch()