import os
import json
import uuid
import time
import copy
import base64
import logging
import argparse
import math
import multiprocessing as mp
from io import BytesIO
from typing import Generator, Any, Dict, Optional
import spaces
import torch
import gradio as gr
import numpy as np
from PIL import Image
from decord import VideoReader, cpu
from scipy.spatial import cKDTree
# import modelscope_studio as mgr
# Import model-related modules
try:
    from models import ModelMiniCPMV4_5
except ImportError:
    print("Warning: models module not found. Please ensure models.py is available.")
    class ModelMiniCPMV4_5:
        def __init__(self, model_path):
            self.model_path = model_path
            self.model = None
        def __call__(self, query):
            return "Model not loaded", 0
# Global configuration
ERROR_MSG = "Error, please retry"
model_name = 'MiniCPM-V 4.5'
disable_text_only = False  # allow text-only messages, convenient for testing
DOUBLE_FRAME_DURATION = 30
MAX_NUM_FRAMES = 180
MAX_NUM_PACKING = 3
TIME_SCALE = 0.1
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}
ENABLE_PARALLEL_ENCODING = True
PARALLEL_PROCESSES = None
# Global model instance
global_model = None
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global model configuration
model_config = {
    'model_path': None,
    'model_type': None,
    'instance_id': 0
}
# Global model cache (lives in the GPU process)
_gpu_model_cache = None
def _initialize_gpu_model():
    """Fetch the model inside the GPU process and move it to the GPU."""
    global _gpu_model_cache
    if _gpu_model_cache is None:
        logger.info(f"Initializing model in GPU process: {model_config['model_type']}")
        match model_config['model_type'].lower():
            case 'minicpmv4_5':
                _gpu_model_cache = ModelMiniCPMV4_5(model_config['model_path'])
            case _:
                raise ValueError(f"Unsupported model type: {model_config['model_type']}")
        logger.info("Model initialized on CPU")
    # Move the model to the GPU on every inference call
    if hasattr(_gpu_model_cache, 'model') and hasattr(_gpu_model_cache.model, 'to'):
        logger.info("Moving model to GPU...")
        _gpu_model_cache.model.to('cuda')
    elif hasattr(_gpu_model_cache, 'model') and hasattr(_gpu_model_cache.model, 'model') and hasattr(_gpu_model_cache.model.model, 'to'):
        logger.info("Moving model to GPU (nested model)...")
        _gpu_model_cache.model.model.to('cuda')
    return _gpu_model_cache
def gpu_handler(query):
    """GPU inference handler."""
    model = _initialize_gpu_model()
    res, output_tokens = model({
        "image": query["image"],
        "question": query["question"],
        "params": query.get("params", "{}"),
        "temporal_ids": query.get("temporal_ids", None)
    })
    return {
        "result": res,
        "usage": {"output_tokens": output_tokens}
    }
def gpu_stream_handler(query):
    """GPU streaming inference handler."""
    model = _initialize_gpu_model()
    params = json.loads(query.get("params", "{}"))
    params["stream"] = True
    query["params"] = json.dumps(params)
    try:
        generator = model({
            "image": query["image"],
            "question": query["question"],
            "params": query["params"],
            "temporal_ids": query.get("temporal_ids", None)
        })
        # Collect all generator output to avoid cross-process serialization issues
        full_response = ""
        for chunk in generator:
            full_response += chunk
        return full_response
    except Exception as e:
        logger.error(f"GPU stream handler error: {e}")
        return f"Stream error: {str(e)}"
class Model:
    """Model wrapper class; does not hold the actual model object."""
    def __init__(self, model_path: str, model_type: str, instance_id: int = 0):
        self.instance_id = instance_id
        self.model_path = model_path
        self.model_type = model_type
        # Set the global configuration
        model_config['model_path'] = model_path
        model_config['model_type'] = model_type
        model_config['instance_id'] = instance_id
        logger.info(f"Instance {instance_id}: configured model type {model_type}")
        logger.info(f"Instance {instance_id}: model path {model_path}")
    def handler(self, query):
        """Non-streaming inference handler."""
        return gpu_handler(query)
    def stream_handler(self, query):
        """Streaming inference handler."""
        return gpu_stream_handler(query)
def initialize_model():
    """Initialize the global model."""
    global global_model, _gpu_model_cache
    # Default configuration
    model_path = os.getenv('MODEL_PATH', 'openbmb/MiniCPM-V-4_5')
    model_type = os.getenv('MODEL_TYPE', 'minicpmv4_5')
    logger.info("=" * 50)
    logger.info("Starting MiniCPM-V service")
    logger.info(f"Model path: {model_path}")
    logger.info(f"Model type: {model_type}")
    logger.info("=" * 50)
    # Create the model wrapper
    global_model = Model(model_path, model_type, 0)
    # Preload the model on CPU in the main process (optional, for a faster first inference)
    try:
        logger.info("Preloading model on CPU in the main process...")
        match model_type.lower():
            case 'minicpmv4_5':
                _gpu_model_cache = ModelMiniCPMV4_5(model_path)
            case _:
                raise ValueError(f"Unsupported model type: {model_type}")
        logger.info("Model preloaded on CPU in the main process")
    except Exception as e:
        logger.warning(f"Failed to preload model in the main process; it will be loaded in the GPU process instead: {e}")
        _gpu_model_cache = None
    return global_model
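# A minimal startup sketch (an assumption: the Gradio entry point further down
# in this file calls this; MODEL_PATH / MODEL_TYPE are the environment
# variables read above):
#   os.environ.setdefault('MODEL_PATH', 'openbmb/MiniCPM-V-4_5')
#   os.environ.setdefault('MODEL_TYPE', 'minicpmv4_5')
#   initialize_model()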
# Utility functions
def get_file_extension(filename):
    return os.path.splitext(filename)[1].lower()
def is_image(filename):
    return get_file_extension(filename) in IMAGE_EXTENSIONS
def is_video(filename):
    return get_file_extension(filename) in VIDEO_EXTENSIONS
def map_to_nearest_scale(values, scale):
    tree = cKDTree(np.asarray(scale)[:, None])
    _, indices = tree.query(np.asarray(values)[:, None])
    return np.asarray(scale)[indices]
def group_array(arr, size):
    return [arr[i:i+size] for i in range(0, len(arr), size)]
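# Illustrative sketch (never called; sample values hypothetical): how the two
# helpers above are combined in encode_video. Frame timestamps are snapped to
# the 0.1 s grid defined by TIME_SCALE, then chunked into packing groups.
def _example_timestamp_grouping():
    scale = np.arange(0, 2.0, TIME_SCALE)  # grid: 0.0, 0.1, ..., 1.9
    snapped = map_to_nearest_scale([0.04, 0.51, 1.26], scale)  # ~ [0.0, 0.5, 1.3]
    return group_array(list(snapped), 2)  # chunks of 2: ~ [[0.0, 0.5], [1.3]]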
def encode_image(image):
    """Encode a single image."""
    if not isinstance(image, Image.Image):
        if hasattr(image, 'path'):
            image = Image.open(image.path)
        elif hasattr(image, 'file') and hasattr(image.file, 'path'):
            image = Image.open(image.file.path)
        elif hasattr(image, 'name'):
            image = Image.open(image.name)
        else:
            image_path = getattr(image, 'url', getattr(image, 'orig_name', str(image)))
            image = Image.open(image_path)
    # Resize the image
    max_size = 448 * 16
    if max(image.size) > max_size:
        w, h = image.size
        if w > h:
            new_w = max_size
            new_h = int(h * max_size / w)
        else:
            new_h = max_size
            new_w = int(w * max_size / h)
        image = image.resize((new_w, new_h), resample=Image.BICUBIC)
    # Convert to base64
    buffered = BytesIO()
    image.save(buffered, format="png")
    im_b64 = base64.b64encode(buffered.getvalue()).decode()
    return [{"type": "image", "pairs": im_b64}]
def encode_image_parallel(image_data):
    """Wrapper for parallel image encoding."""
    try:
        return encode_image(image_data)
    except Exception as e:
        print(f"[Parallel encoding error] Image encoding failed: {e}")
        return None
def encode_images_parallel(frames, num_processes=None):
    """Multi-process parallel image encoding."""
    if not ENABLE_PARALLEL_ENCODING:
        print("[Parallel encoding] Parallel encoding disabled, using serial processing")
        encoded_frames = []
        for frame in frames:
            encoded = encode_image(frame)
            if encoded:
                encoded_frames.extend(encoded)
        return encoded_frames
    if num_processes is None:
        cpu_cores = mp.cpu_count()
        if PARALLEL_PROCESSES:
            num_processes = PARALLEL_PROCESSES
        else:
            if len(frames) >= 50:
                num_processes = min(cpu_cores, len(frames), 32)
            elif len(frames) >= 20:
                num_processes = min(cpu_cores, len(frames), 16)
            else:
                num_processes = min(cpu_cores, len(frames), 8)
    print(f"[Parallel encoding] Starting parallel encoding of {len(frames)} frame images, using {num_processes} processes")
    if len(frames) <= 2:
        print(f"[Parallel encoding] Few images ({len(frames)} frames), using serial processing")
        encoded_frames = []
        for frame in frames:
            encoded = encode_image(frame)
            if encoded:
                encoded_frames.extend(encoded)
        return encoded_frames
    start_time = time.time()
    try:
        with mp.Pool(processes=num_processes) as pool:
            results = pool.map(encode_image_parallel, frames)
        encoded_frames = []
        for result in results:
            if result:
                encoded_frames.extend(result)
        total_time = time.time() - start_time
        print(f"[Parallel encoding] Parallel encoding completed, total time: {total_time:.3f}s, encoded {len(encoded_frames)} images")
        return encoded_frames
    except Exception as e:
        print(f"[Parallel encoding] Parallel processing failed, falling back to serial processing: {e}")
        encoded_frames = []
        for frame in frames:
            encoded = encode_image(frame)
            if encoded:
                encoded_frames.extend(encoded)
        return encoded_frames
def encode_video(video, choose_fps=None):
    """Encode a video file."""
    def uniform_sample(l, n):
        gap = len(l) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [l[i] for i in idxs]
    if hasattr(video, 'path'):
        video_path = video.path
    elif hasattr(video, 'file') and hasattr(video.file, 'path'):
        video_path = video.file.path
    elif hasattr(video, 'name'):
        video_path = video.name
    else:
        video_path = getattr(video, 'url', getattr(video, 'orig_name', str(video)))
    vr = VideoReader(video_path, ctx=cpu(0))
    fps = vr.get_avg_fps()
    video_duration = len(vr) / fps
    frame_idx = [i for i in range(0, len(vr))]
    effective_fps = choose_fps if choose_fps else 1
    if video_duration < DOUBLE_FRAME_DURATION and effective_fps <= 5:
        effective_fps = effective_fps * 2
        packing_nums = 2
        choose_frames = round(min(effective_fps, round(fps)) * min(MAX_NUM_FRAMES, video_duration))
    elif effective_fps * int(video_duration) <= MAX_NUM_FRAMES:
        packing_nums = 1
        choose_frames = round(min(effective_fps, round(fps)) * min(MAX_NUM_FRAMES, video_duration))
    else:
        packing_size = math.ceil(video_duration * effective_fps / MAX_NUM_FRAMES)
        if packing_size <= MAX_NUM_PACKING:
            choose_frames = round(video_duration * effective_fps)
            packing_nums = packing_size
        else:
            choose_frames = round(MAX_NUM_FRAMES * MAX_NUM_PACKING)
            packing_nums = MAX_NUM_PACKING
    choose_idx = choose_frames
    frame_idx = np.array(uniform_sample(frame_idx, choose_idx))
    frames = vr.get_batch(frame_idx).asnumpy()
    frame_idx_ts = frame_idx / fps
    scale = np.arange(0, video_duration, TIME_SCALE)
    frame_ts_id = map_to_nearest_scale(frame_idx_ts, scale) / TIME_SCALE
    frame_ts_id = frame_ts_id.astype(np.int32)
    assert len(frames) == len(frame_ts_id)
    frames = [Image.fromarray(v.astype('uint8')).convert('RGB') for v in frames]
    frame_ts_id_group = group_array(frame_ts_id.tolist(), packing_nums)
    print(f"[Performance] Starting image encoding, total {len(frames)} frames")
    if ENABLE_PARALLEL_ENCODING:
        print(f"[Image encoding] Using multi-process parallel encoding, CPU cores: {mp.cpu_count()}")
        encoded_frames = encode_images_parallel(frames, PARALLEL_PROCESSES)
    else:
        print("[Warning] Parallel encoding disabled, using serial processing")
        encoded_frames = []
        for frame in frames:
            encoded = encode_image(frame)
            if encoded:
                encoded_frames.extend(encoded)
    return encoded_frames, frame_ts_id_group
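# Worked example of the frame-budget branches above (numbers hypothetical): a
# 400 s video at choose_fps=3 targets 3 * 400 = 1200 frames > MAX_NUM_FRAMES
# (180), and packing_size = ceil(1200 / 180) = 7 exceeds MAX_NUM_PACKING (3),
# so the fallback samples round(180 * 3) = 540 frames packed 3 per temporal
# group. A 60 s clip at the same fps needs only 180 frames and fits with
# packing_nums = 1.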
# Response processing functions
def parse_thinking_response(response_text):
    """Parse response text containing <think> tags; supports streaming parsing."""
    import re
    # Match complete thinking tags
    complete_think_pattern = r'<think>(.*?)</think>'
    thinking_matches = re.findall(complete_think_pattern, response_text, re.DOTALL)
    if thinking_matches:
        # Complete thinking tags are present
        thinking_content = "\n\n".join(thinking_matches).strip()
        print("thinking_content---:", thinking_content)
        formal_answer = re.sub(complete_think_pattern, '', response_text, flags=re.DOTALL).strip()
        return thinking_content, formal_answer
    else:
        # Check for an unfinished thinking tag
        partial_think_match = re.search(r'<think>(.*?)$', response_text, re.DOTALL)
        if partial_think_match:
            # Opening tag without a closing tag: thinking content is still being emitted.
            # Return a special marker indicating the thinking phase is in progress.
            return "STREAMING", ""
        else:
            # No thinking tags; return the original text as the formal answer
            return "", response_text.strip()
def parse_thinking_response_for_final(response_text):
    """Final parse of a thinking response, used for formatting on completion."""
    import re
    # First try to match complete thinking tags
    think_pattern = r'<think>(.*?)</think>'
    thinking_matches = re.findall(think_pattern, response_text, re.DOTALL)
    if thinking_matches:
        thinking_content = "\n\n".join(thinking_matches).strip()
        formal_answer = re.sub(think_pattern, '', response_text, flags=re.DOTALL).strip()
        print(f"[parse_final] Found complete thinking tags, thinking length: {len(thinking_content)}, answer length: {len(formal_answer)}")
    else:
        # No complete tags; check for an unclosed <think> tag
        if '<think>' in response_text:
            think_start = response_text.find('<think>')
            if think_start != -1:
                # Extract the thinking content (everything after <think> to the end of the string)
                thinking_content = response_text[think_start + 7:].strip()  # skip past <think>
                # formal_answer is whatever precedes <think>
                formal_answer = response_text[:think_start].strip()
                # If formal_answer is empty, the entire response is thinking content
                if not formal_answer:
                    formal_answer = ""  # no formal answer
                print("[parse_final] Found unclosed thinking tag")
                print(f"[parse_final] thinking content: '{thinking_content[:100]}...'")
                print(f"[parse_final] formal_answer: '{formal_answer[:100]}...'")
            else:
                thinking_content = ""
                formal_answer = response_text.strip()
                print(f"[parse_final] No thinking tags, answer length: {len(formal_answer)}")
        else:
            thinking_content = ""
            formal_answer = response_text.strip()
            print(f"[parse_final] No thinking tags, answer length: {len(formal_answer)}")
    return thinking_content, formal_answer
def normalize_text_for_html(text):
    """Lightweight text normalization."""
    import re
    if not text:
        return ""
    text = re.sub(r"[\u200B\u200C\u200D\uFEFF]", "", text)
    lines = [line.strip() for line in text.split("\n")]
    text = "\n".join(lines)
    text = text.strip()
    return text
def format_response_with_thinking(thinking_content, formal_answer):
    """Format a response that includes a thinking process."""
    print(f"[format_thinking] thinking_content length: {len(thinking_content) if thinking_content else 0}")
    print(f"[format_thinking] formal_answer length: {len(formal_answer) if formal_answer else 0}")
    print(f"[format_thinking] first 100 chars of thinking_content: '{thinking_content[:100] if thinking_content else 'None'}...'")
    print(f"[format_thinking] first 100 chars of formal_answer: '{formal_answer[:100] if formal_answer else 'None'}...'")
    # Check whether the content is empty
    if not thinking_content and not formal_answer:
        print("[format_thinking] Warning: both thinking_content and formal_answer are empty!")
    elif not formal_answer:
        print("[format_thinking] Warning: formal_answer is empty!")
    elif not thinking_content:
        print("[format_thinking] Note: thinking_content is empty, using the simplified format")
    # Add a unique ID to force the frontend to re-render
    import uuid
    unique_id = uuid.uuid4().hex[:8]
    # If there is thinking content, render the full thinking layout
    if thinking_content and thinking_content.strip():
        formatted_response = f"""
<div class="response-container" id="response-{unique_id}">
  <div class="thinking-section">
    <div class="thinking-header">🤔 think</div>
    <div class="thinking-content">{thinking_content}</div>
  </div>
  <div class="formal-section">
    <div class="formal-header">💡 answer</div>
    <div class="formal-content">{formal_answer if formal_answer else '(no formal answer)'}</div>
  </div>
</div>
"""
    else:
        # No thinking content; show the answer directly
        content_to_show = formal_answer if formal_answer and formal_answer.strip() else "(empty answer)"
        formatted_response = f"""
<div class="response-container" id="response-{unique_id}">
  <div class="formal-section">
    <div class="formal-content">{content_to_show}</div>
  </div>
</div>
"""
    return "\n" + formatted_response.strip() + "\n"
def check_mm_type(mm_file):
    """Check the type of a multimedia file."""
    if hasattr(mm_file, 'path'):
        path = mm_file.path
    elif hasattr(mm_file, 'file') and hasattr(mm_file.file, 'path'):
        path = mm_file.file.path
    elif hasattr(mm_file, 'name'):
        path = mm_file.name
    else:
        path = getattr(mm_file, 'url', getattr(mm_file, 'orig_name', str(mm_file)))
    if is_image(path):
        return "image"
    if is_video(path):
        return "video"
    return None
def encode_mm_file(mm_file, choose_fps=None):
    """Encode a multimedia file."""
    if check_mm_type(mm_file) == 'image':
        return encode_image(mm_file), None
    if check_mm_type(mm_file) == 'video':
        encoded_frames, frame_ts_id_group = encode_video(mm_file, choose_fps)
        return encoded_frames, frame_ts_id_group
    return None, None
def encode_message(_question, choose_fps=None):
    """Encode a message."""
    import re
    files = _question.files if _question.files else []
    question = _question.text if _question.text else ""
    message = []
    temporal_ids = []
    # Check whether the legacy placeholder format is used
    pattern = r"\[mm_media\]\d+\[/mm_media\]"
    if re.search(pattern, question):
        # Legacy format: placeholders in the text
        matches = re.split(pattern, question)
        if len(matches) != len(files) + 1:
            gr.Warning("Number of Images not match the placeholder in text, please refresh the page to restart!")
            # Handle the mismatch instead of using assert
            if len(matches) > len(files) + 1:
                matches = matches[:len(files) + 1]
            else:
                while len(matches) < len(files) + 1:
                    matches.append("")
        text = matches[0].strip()
        if text:
            message.append({"type": "text", "pairs": text})
        for i in range(len(files)):
            encoded_content, frame_ts_id_group = encode_mm_file(files[i], choose_fps)
            if encoded_content:
                message += encoded_content
            if frame_ts_id_group:
                temporal_ids.extend(frame_ts_id_group)
            if i + 1 < len(matches):
                text = matches[i + 1].strip()
                if text:
                    message.append({"type": "text", "pairs": text})
    else:
        # New format: plain text plus a file list
        if question.strip():
            message.append({"type": "text", "pairs": question.strip()})
        for file in files:
            encoded_content, frame_ts_id_group = encode_mm_file(file, choose_fps)
            if encoded_content:
                message += encoded_content
            if frame_ts_id_group:
                temporal_ids.extend(frame_ts_id_group)
    return message, temporal_ids if temporal_ids else None
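# Shape of the result (contents abridged, values hypothetical): for the text
# "describe this" plus one video, encode_message returns roughly
#   message      = [{"type": "text", "pairs": "describe this"},
#                   {"type": "image", "pairs": "<base64 frame>"}, ...]
#   temporal_ids = [[0, 5, 13], ...]  # packing_nums timestamp ids per group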
def check_has_videos(_question):
    """Count the images and videos attached to a message."""
    images_cnt = 0
    videos_cnt = 0
    files = _question.files if _question.files else []
    for file in files:
        if check_mm_type(file) == "image":
            images_cnt += 1
        else:
            videos_cnt += 1
    return images_cnt, videos_cnt
def save_media_to_persistent_cache(_question, session_id):
    """Save images and videos to the persistent cache and return info about the saved paths."""
    import os
    import shutil
    import uuid
    from pathlib import Path
    files = _question.files if _question.files else []
    saved_media = []
    # Create a session-specific media cache directory
    cache_dir = Path("./media_cache") / session_id
    cache_dir.mkdir(parents=True, exist_ok=True)
    for file in files:
        file_type = check_mm_type(file)
        if file_type in ["image", "video"]:
            try:
                # Get the original file path
                original_path = None
                if hasattr(file, 'name'):
                    original_path = file.name
                elif hasattr(file, 'path'):
                    original_path = file.path
                elif hasattr(file, 'file') and hasattr(file.file, 'path'):
                    original_path = file.file.path
                else:
                    continue
                if original_path and os.path.exists(original_path):
                    # Generate a unique filename
                    file_ext = os.path.splitext(original_path)[1]
                    prefix = "img" if file_type == "image" else "vid"
                    unique_filename = f"{prefix}_{uuid.uuid4().hex[:8]}{file_ext}"
                    cached_path = cache_dir / unique_filename
                    # Copy the file into the cache directory
                    shutil.copy2(original_path, cached_path)
                    saved_media.append({
                        'type': file_type,
                        'original_path': original_path,
                        'cached_path': str(cached_path),
                        'filename': unique_filename
                    })
                    print(f"[save_media_to_persistent_cache] {file_type} saved: {cached_path}")
            except Exception as e:
                print(f"[save_media_to_persistent_cache] Failed to save {file_type}: {e}")
                continue
    return saved_media
def format_user_message_with_files(_question, session_id=None):
    """Format a user message that includes files, supporting image and video display."""
    user_text = _question.text if _question.text else ""
    files = _question.files if _question.files else []
    if not files:
        return user_text, []
    # Save media files to the persistent cache
    saved_media = []
    if session_id:
        saved_media = save_media_to_persistent_cache(_question, session_id)
    if len(files) == 1:
        file = files[0]
        file_type = check_mm_type(file)
        # The file is an image or video and has been saved to the cache
        if file_type in ["image", "video"] and saved_media:
            media_info = saved_media[0]
            if file_type == "image":
                if user_text:
                    return f"🖼️ {user_text}", saved_media
                else:
                    return "🖼️ Image", saved_media
            elif file_type == "video":
                if user_text:
                    return f"🎬 {user_text}", saved_media
                else:
                    return "🎬 Video", saved_media
        else:
            # Other file types: fall back to a text description
            return f"[1 file uploaded] {user_text}", saved_media
    else:
        # Multiple files: count each type
        image_count = len([m for m in saved_media if m['type'] == 'image'])
        video_count = len([m for m in saved_media if m['type'] == 'video'])
        other_count = len(files) - image_count - video_count
        # Build the description text
        parts = []
        if image_count > 0:
            parts.append(f"{image_count} image{'s' if image_count > 1 else ''}")
        if video_count > 0:
            parts.append(f"{video_count} video{'s' if video_count > 1 else ''}")
        if other_count > 0:
            parts.append(f"{other_count} other file{'s' if other_count > 1 else ''}")
        if parts:
            files_desc = ", ".join(parts)
            return f"[{files_desc} uploaded] {user_text}", saved_media
        else:
            return f"[{len(files)} files uploaded] {user_text}", saved_media
def update_media_gallery(app_session):
    """Update the media gallery display (images and videos)."""
    import os
    media_cache = app_session.get('media_cache', [])
    if not media_cache:
        return gr.update(value=[], visible=False)
    # Collect paths of all cached media files (both images and videos are supported)
    media_paths = [media_info['cached_path'] for media_info in media_cache if os.path.exists(media_info['cached_path'])]
    if media_paths:
        return gr.update(value=media_paths, visible=True)
    else:
        return gr.update(value=[], visible=False)
def format_fewshot_user_message(image_path, user_text):
    """Format a few-shot user message, supporting image display."""
    if image_path and user_text:
        return (user_text, image_path)
    elif image_path:
        return ("", image_path)
    else:
        return user_text
# Main chat functions
def chat_direct(img_b64, msgs, ctx, params=None, vision_hidden_states=None, temporal_ids=None, session_id=None):
    """Call the model directly for chat (non-streaming)."""
    default_params = {"num_beams": 3, "repetition_penalty": 1.2, "max_new_tokens": 16284}
    if params is None:
        params = default_params
    use_streaming = params.get('stream', False)
    if use_streaming:
        return chat_stream_direct(img_b64, msgs, ctx, params, vision_hidden_states, temporal_ids, session_id)
    else:
        # Build the request payload
        query = {
            "image": img_b64,
            "question": json.dumps(msgs, ensure_ascii=True),
            "params": json.dumps(params, ensure_ascii=True),
        }
        if temporal_ids:
            query["temporal_ids"] = json.dumps(temporal_ids, ensure_ascii=True)
        if session_id:
            query["session_id"] = session_id
        try:
            # Call the model directly
            result = global_model.handler(query)
            raw_result = result['result']
            # Clean up the result
            import re
            cleaned_result = re.sub(r'(<box>.*</box>)', '', raw_result)
            cleaned_result = cleaned_result.replace('<ref>', '')
            cleaned_result = cleaned_result.replace('</ref>', '')
            cleaned_result = cleaned_result.replace('<box>', '')
            cleaned_result = cleaned_result.replace('</box>', '')
            # Parse the thinking process
            thinking_content_raw, formal_answer_raw = parse_thinking_response_for_final(cleaned_result)
            thinking_content_fmt = normalize_text_for_html(thinking_content_raw)
            formal_answer_fmt = normalize_text_for_html(formal_answer_raw)
            formatted_result = format_response_with_thinking(thinking_content_fmt, formal_answer_fmt)
            context_result = formal_answer_raw if formal_answer_raw else cleaned_result
            return 0, formatted_result, context_result, None
        except Exception as e:
            print(f"Chat error: {e}")
            import traceback
            traceback.print_exc()
            return -1, ERROR_MSG, None, None
def chat_stream_direct(img_b64, msgs, ctx, params=None, vision_hidden_states=None, temporal_ids=None, session_id=None):
    """Call the model directly for streaming chat."""
    try:
        # Build the request payload
        query = {
            "image": img_b64,
            "question": json.dumps(msgs, ensure_ascii=True),
            "params": json.dumps(params, ensure_ascii=True),
        }
        if temporal_ids:
            query["temporal_ids"] = json.dumps(temporal_ids, ensure_ascii=True)
        if session_id:
            query["session_id"] = session_id
        # Call the streaming handler directly
        generator = global_model.stream_handler(query)
        full_response = ""
        for chunk in generator:
            full_response += chunk
        if not full_response:
            return -1, ERROR_MSG, None, None
        # Clean up the result
        import re
        cleaned_result = re.sub(r'(<box>.*</box>)', '', full_response)
        cleaned_result = cleaned_result.replace('<ref>', '')
        cleaned_result = cleaned_result.replace('</ref>', '')
        cleaned_result = cleaned_result.replace('<box>', '')
        cleaned_result = cleaned_result.replace('</box>', '')
        # Parse the thinking process
        thinking_content_raw, formal_answer_raw = parse_thinking_response_for_final(cleaned_result)
        thinking_content_fmt = normalize_text_for_html(thinking_content_raw)
        formal_answer_fmt = normalize_text_for_html(formal_answer_raw)
        formatted_result = format_response_with_thinking(thinking_content_fmt, formal_answer_fmt)
        context_result = formal_answer_raw if formal_answer_raw else cleaned_result
        return 0, formatted_result, context_result, None
    except Exception as e:
        print(f"Stream chat error: {e}")
        import traceback
        traceback.print_exc()
        return -1, ERROR_MSG, None, None
def chat_stream_character_generator(img_b64, msgs, ctx, params=None, vision_hidden_states=None, temporal_ids=None, stop_control=None, session_id=None):
    """Character-level streaming generator."""
    print("[chat_stream_character_generator] Starting character-level streaming")
    print(f"[chat_stream_character_generator] stop_control: {stop_control}")
    try:
        # Build the request payload
        query = {
            "image": img_b64,
            "question": json.dumps(msgs, ensure_ascii=True),
            "params": json.dumps(params, ensure_ascii=True),
        }
        if temporal_ids:
            query["temporal_ids"] = json.dumps(temporal_ids, ensure_ascii=True)
        if session_id:
            query["session_id"] = session_id
        # Call the streaming handler - it now returns the full response rather than a generator
        full_response = global_model.stream_handler(query)
        # Clean up the response
        import re
        clean_response = re.sub(r'(<box>.*</box>)', '', full_response)
        clean_response = clean_response.replace('<ref>', '')
        clean_response = clean_response.replace('</ref>', '')
        clean_response = clean_response.replace('<box>', '')
        clean_response = clean_response.replace('</box>', '')
        # Yield character by character to simulate streaming output
        char_count = 0
        for char in clean_response:
            # Check the stop flag
            if stop_control and stop_control.get('stop_streaming', False):
                print(f"[chat_stream_character_generator] *** Stop signal received at character {char_count} ***")
                break
            char_count += 1
            if char_count % 10 == 0:
                print(f"[chat_stream_character_generator] Emitted {char_count} characters, stop_flag: {stop_control.get('stop_streaming', False) if stop_control else 'None'}")
            yield char
            # Add a small delay to simulate a streaming effect
            import time
            time.sleep(0.01)
        print(f"[chat_stream_character_generator] Streaming finished, emitted {char_count} characters in total")
    except Exception as e:
        print(f"[chat_stream_character_generator] Exception: {e}")
        error_msg = f"Stream error: {str(e)}"
        for char in error_msg:
            yield char
# UI component factory functions
def create_component(params, comp='Slider'):
    if comp == 'Slider':
        return gr.Slider(
            minimum=params['minimum'],
            maximum=params['maximum'],
            value=params['value'],
            step=params['step'],
            interactive=params['interactive'],
            label=params['label']
        )
    elif comp == 'Radio':
        return gr.Radio(
            choices=params['choices'],
            value=params['value'],
            interactive=params['interactive'],
            label=params['label']
        )
    elif comp == 'Button':
        return gr.Button(
            value=params['value'],
            interactive=True
        )
    elif comp == 'Checkbox':
        return gr.Checkbox(
            value=params['value'],
            interactive=params['interactive'],
            label=params['label'],
            info=params.get('info', None)
        )
def create_multimodal_input(upload_image_disabled=False, upload_video_disabled=False):
    # Use standard Gradio components instead of MultimodalInput, with preview support
    return gr.File(
        file_count="multiple",
        file_types=["image", "video"],
        label="Upload Images/Videos",
        interactive=not (upload_image_disabled and upload_video_disabled),
        show_label=True,
        height=200  # set a height so previews are shown
    )
# UI control functions
def update_streaming_mode_state(params_form):
    """Update the streaming-mode state based on the decode type."""
    if params_form == 'Beam Search':
        return gr.update(value=False, interactive=False, info="Beam Search mode does not support streaming output")
    else:
        return gr.update(value=True, interactive=True, info="Enable real-time streaming response")
def stop_streaming(_app_cfg):
    """Stop streaming output."""
    _app_cfg['stop_streaming'] = True
    print("[stop_streaming] Set stop flag to True")
    return _app_cfg
def reset_stop_flag(_app_cfg):
    """Reset the stop flag."""
    _app_cfg['stop_streaming'] = False
    print("[reset_stop_flag] Reset stop flag to False")
    return _app_cfg
def check_and_handle_stop(_app_cfg, context="unknown"):
    """Check the stop flag."""
    should_stop = _app_cfg.get('stop_streaming', False)
    is_streaming = _app_cfg.get('is_streaming', False)
    if should_stop:
        print(f"[check_and_handle_stop] *** Stop signal detected at {context} ***")
        print(f"[check_and_handle_stop] stop_streaming: {should_stop}, is_streaming: {is_streaming}")
        return True
    return False
def stop_button_clicked(_app_cfg):
    """Handle a stop-button click."""
    print("[stop_button_clicked] *** Stop button clicked ***")
    print(f"[stop_button_clicked] Current state - is_streaming: {_app_cfg.get('is_streaming', False)}")
    print(f"[stop_button_clicked] Current state - stop_streaming: {_app_cfg.get('stop_streaming', False)}")
    _app_cfg['stop_streaming'] = True
    _app_cfg['is_streaming'] = False
    print("[stop_button_clicked] Set stop_streaming = True, is_streaming = False")
    return _app_cfg, gr.update(visible=False)
# Main response functions
def respond_stream(_question, _chat_bot, _app_cfg, params_form, thinking_mode, streaming_mode, fps_setting):
    """Streaming response generator."""
    print(f"[respond_stream] Called with streaming_mode: {streaming_mode}, fps_setting: {fps_setting}")
    _app_cfg['is_streaming'] = True
    _app_cfg['stop_streaming'] = False
    if params_form == 'Beam Search':
        streaming_mode = False
        print("[respond_stream] Beam Search mode, forcing streaming off")
        _app_cfg['is_streaming'] = False
    _context = _app_cfg['ctx'].copy()
    encoded_message, temporal_ids = encode_message(_question, fps_setting)
    _context.append({'role': 'user', 'contents': encoded_message})
    images_cnt = _app_cfg['images_cnt']
    videos_cnt = _app_cfg['videos_cnt']
    files_cnts = check_has_videos(_question)
    if files_cnts[1] + videos_cnt > 1 or (files_cnts[1] + videos_cnt == 1 and files_cnts[0] + images_cnt > 0):
        gr.Warning("Only supports single video file input right now!")
        yield create_multimodal_input(True, True), _chat_bot, _app_cfg, gr.update(visible=False)
        return
    if disable_text_only and files_cnts[1] + videos_cnt + files_cnts[0] + images_cnt <= 0:
        gr.Warning("Please chat with at least one image or video.")
        yield create_multimodal_input(False, False), _chat_bot, _app_cfg, gr.update(visible=False)
        return
    if params_form == 'Beam Search':
        params = {
            'sampling': False,
            'num_beams': 3,
            'repetition_penalty': 1.2,
            "max_new_tokens": 16284,
            "enable_thinking": thinking_mode,
            "stream": False
        }
    else:
        params = {
            'sampling': True,
            'top_p': 0.8,
            'top_k': 100,
            'temperature': 0.7,
            'repetition_penalty': 1.03,
            "max_new_tokens": 16284,
            "enable_thinking": thinking_mode,
            "stream": streaming_mode
        }
    if files_cnts[1] + videos_cnt > 0:
        params["max_inp_length"] = 2048 * 10
        params["use_image_id"] = False
        params["max_slice_nums"] = 1
    images_cnt += files_cnts[0]
    videos_cnt += files_cnts[1]
    # Build the user message for display (streaming mode)
    user_message, saved_images = format_user_message_with_files(_question, _app_cfg.get('session_id'))
    # Store media info in the session state
    if saved_images:
        if 'media_cache' not in _app_cfg:
            _app_cfg['media_cache'] = []
        _app_cfg['media_cache'].extend(saved_images)
    _chat_bot.append((user_message, ""))
    _context.append({"role": "assistant", "contents": [{"type": "text", "pairs": ""}]})
    gen = chat_stream_character_generator("", _context[:-1], None, params, None, temporal_ids, _app_cfg, _app_cfg['session_id'])
    upload_image_disabled = videos_cnt > 0
    upload_video_disabled = videos_cnt > 0 or images_cnt > 0
    yield create_multimodal_input(upload_image_disabled, upload_video_disabled), _chat_bot, _app_cfg, gr.update(visible=True)
    print("[respond_stream] Starting the character-level streaming loop")
    char_count = 0
    accumulated_content = ""
    for _char in gen:
        char_count += 1
        if check_and_handle_stop(_app_cfg, f"character {char_count}"):
            break
        accumulated_content += _char
        _context[-1]["contents"][0]["pairs"] += _char
        # Show content in real time (thinking mode is displayed live too)
        if thinking_mode:
            # Try to parse the accumulated content so far
            thinking_content_raw, formal_answer_raw = parse_thinking_response(accumulated_content)
            # If complete thinking content was parsed out, use the formatted display
            if thinking_content_raw and thinking_content_raw != "STREAMING" and formal_answer_raw:
                thinking_content_fmt = normalize_text_for_html(thinking_content_raw)
                formal_answer_fmt = normalize_text_for_html(formal_answer_raw)
                formatted_display = format_response_with_thinking(thinking_content_fmt, formal_answer_fmt)
                _chat_bot[-1] = (user_message, formatted_display)
            else:
                # Thinking is still in progress or the tags are incomplete; show the raw content (live streaming)
                _chat_bot[-1] = (user_message, accumulated_content)
        else:
            # Not in thinking mode; show the accumulated content directly
            _chat_bot[-1] = (user_message, accumulated_content)
        if char_count % 5 == 0:  # update frequently for a smoother streaming experience
            print(f"[respond_stream] Processed {char_count} characters, stop_flag: {_app_cfg.get('stop_streaming', False)}")
            yield create_multimodal_input(upload_image_disabled, upload_video_disabled), _chat_bot, _app_cfg, gr.update(visible=True)
            time.sleep(0.02)  # a slightly longer delay to avoid overly frequent updates
        else:
            yield create_multimodal_input(upload_image_disabled, upload_video_disabled), _chat_bot, _app_cfg, gr.update(visible=True)
    if _app_cfg.get('stop_streaming', False):
        print("[respond_stream] Streaming output stopped")
    # Final thinking formatting
    final_content = accumulated_content
    if thinking_mode:
        thinking_content_raw, formal_answer_raw = parse_thinking_response_for_final(final_content)
        thinking_content_fmt = normalize_text_for_html(thinking_content_raw)
        formal_answer_fmt = normalize_text_for_html(formal_answer_raw)
        formatted_result = format_response_with_thinking(thinking_content_fmt, formal_answer_fmt)
        _chat_bot[-1] = (user_message, formatted_result)
        _context[-1]["contents"][0]["pairs"] = formal_answer_raw if formal_answer_raw else final_content
    else:
        _chat_bot[-1] = (user_message, final_content)
        _context[-1]["contents"][0]["pairs"] = final_content
    _app_cfg['ctx'] = _context
    _app_cfg['images_cnt'] = images_cnt
    _app_cfg['videos_cnt'] = videos_cnt
    _app_cfg['is_streaming'] = False
    upload_image_disabled = videos_cnt > 0
    upload_video_disabled = videos_cnt > 0 or images_cnt > 0
    yield create_multimodal_input(upload_image_disabled, upload_video_disabled), _chat_bot, _app_cfg, gr.update(visible=False)
def respond(_question, _chat_bot, _app_cfg, params_form, thinking_mode, streaming_mode, fps_setting):
    """Main response function."""
    if 'session_id' not in _app_cfg:
        _app_cfg['session_id'] = uuid.uuid4().hex[:16]
        print(f"[session] Generated session_id for an existing session: {_app_cfg['session_id']}")
    # Track changes in the thinking-mode state
    prev_thinking_mode = _app_cfg.get('last_thinking_mode', False)
    _app_cfg['last_thinking_mode'] = thinking_mode
    if prev_thinking_mode != thinking_mode:
        print(f"[respond] Thinking mode switched: {prev_thinking_mode} -> {thinking_mode}")
        # Forcefully clear any possibly cached state
        if 'thinking_cache' in _app_cfg:
            del _app_cfg['thinking_cache']
        # Additional state resets
        if thinking_mode and not prev_thinking_mode:
            print("[respond] Thinking mode enabled, resetting related state")
            _app_cfg['thinking_enabled'] = True
        elif not thinking_mode and prev_thinking_mode:
            print("[respond] Thinking mode disabled")
            _app_cfg['thinking_enabled'] = False
    if params_form == 'Beam Search':
        streaming_mode = False
        print("[respond] Beam Search mode, forcing streaming off")
    if streaming_mode:
        print("[respond] Using streaming mode")
        yield from respond_stream(_question, _chat_bot, _app_cfg, params_form, thinking_mode, streaming_mode, fps_setting)
        return
    # Non-streaming mode
    _context = _app_cfg['ctx'].copy()
    encoded_message, temporal_ids = encode_message(_question, fps_setting)
    _context.append({'role': 'user', 'contents': encoded_message})
    images_cnt = _app_cfg['images_cnt']
    videos_cnt = _app_cfg['videos_cnt']
    files_cnts = check_has_videos(_question)
    if files_cnts[1] + videos_cnt > 1 or (files_cnts[1] + videos_cnt == 1 and files_cnts[0] + images_cnt > 0):
        gr.Warning("Only supports single video file input right now!")
        upload_image_disabled = videos_cnt > 0
        upload_video_disabled = videos_cnt > 0 or images_cnt > 0
        yield create_multimodal_input(upload_image_disabled, upload_video_disabled), _chat_bot, _app_cfg, gr.update(visible=False)
        return
    if disable_text_only and files_cnts[1] + videos_cnt + files_cnts[0] + images_cnt <= 0:
        gr.Warning("Please chat with at least one image or video.")
        upload_image_disabled = videos_cnt > 0
        upload_video_disabled = videos_cnt > 0 or images_cnt > 0
        yield create_multimodal_input(upload_image_disabled, upload_video_disabled), _chat_bot, _app_cfg, gr.update(visible=False)
        return
    if params_form == 'Beam Search':
        params = {
            'sampling': False,
            'num_beams': 3,
            'repetition_penalty': 1.2,
            "max_new_tokens": 16284,
            "enable_thinking": thinking_mode,
            "stream": False
        }
    else:
        params = {
            'sampling': True,
            'top_p': 0.8,
            'top_k': 100,
            'temperature': 0.7,
            'repetition_penalty': 1.03,
            "max_new_tokens": 16284,
            "enable_thinking": thinking_mode,
            "stream": False
        }
    if files_cnts[1] + videos_cnt > 0:
        params["max_inp_length"] = 2048 * 10
        params["use_image_id"] = False
        params["max_slice_nums"] = 1
    # Call the chat function
    code, _answer, _context_answer, sts = chat_direct("", _context, None, params, None, temporal_ids, _app_cfg['session_id'])
    images_cnt += files_cnts[0]
    videos_cnt += files_cnts[1]
    if code == 0:
        context_content = _context_answer if _context_answer else _answer
        _context.append({"role": "assistant", "contents": [{"type": "text", "pairs": context_content}]})
        # Decide based on thinking_mode whether to apply thinking formatting
        if thinking_mode:
            thinking_content_raw, formal_answer_raw = parse_thinking_response_for_final(_answer)
            thinking_content_fmt = normalize_text_for_html(thinking_content_raw)
            formal_answer_fmt = normalize_text_for_html(formal_answer_raw)
            print(f"[respond] Non-streaming mode - thinking_mode: {thinking_mode}, thinking_content: '{thinking_content_raw[:50]}...'")
            formatted_result = format_response_with_thinking(thinking_content_fmt, formal_answer_fmt)
        else:
            print(f"[respond] Non-streaming mode - thinking_mode: {thinking_mode}, using the raw answer")
            formatted_result = _answer
        # Build the user message for display
        user_message, saved_images = format_user_message_with_files(_question, _app_cfg.get('session_id'))
        # Store media info in the session state
        if saved_images:
            if 'media_cache' not in _app_cfg:
                _app_cfg['media_cache'] = []
            _app_cfg['media_cache'].extend(saved_images)
        _chat_bot.append((user_message, formatted_result))
        _app_cfg['ctx'] = _context
        _app_cfg['sts'] = sts
    else:
        _context.append({"role": "assistant", "contents": [{"type": "text", "pairs": "Error occurred during processing"}]})
        # Build the user message for display (error case)
        user_message, saved_images = format_user_message_with_files(_question, _app_cfg.get('session_id'))
        # Store media info in the session state
        if saved_images:
            if 'media_cache' not in _app_cfg:
                _app_cfg['media_cache'] = []
            _app_cfg['media_cache'].extend(saved_images)
        _chat_bot.append((user_message, "Error occurred during processing"))
    _app_cfg['images_cnt'] = images_cnt
    _app_cfg['videos_cnt'] = videos_cnt
    _app_cfg['is_streaming'] = False
    upload_image_disabled = videos_cnt > 0
    upload_video_disabled = videos_cnt > 0 or images_cnt > 0
    # Always yield the result so this stays compatible with streaming mode
    yield create_multimodal_input(upload_image_disabled, upload_video_disabled), _chat_bot, _app_cfg, gr.update(visible=False)
# Few-shot helpers
def fewshot_add_demonstration(_image, _user_message, _assistant_message, _chat_bot, _app_cfg):
    if 'session_id' not in _app_cfg:
        _app_cfg['session_id'] = uuid.uuid4().hex[:16]
        print(f"[session] Generated session_id for few-shot demonstration: {_app_cfg['session_id']}")
    ctx = _app_cfg["ctx"]
    # Build the user message
    user_msg = ""
    if _image is not None:
        image = Image.open(_image).convert("RGB")
        ctx.append({"role": "user", "contents": [
            *encode_image(image),
            {"type": "text", "pairs": _user_message}
        ]})
        user_msg = f"[Image uploaded] {_user_message}"
    else:
        if _user_message:
            ctx.append({"role": "user", "contents": [{"type": "text", "pairs": _user_message}]})
            user_msg = _user_message
    # Build the assistant message
    if _assistant_message:
        ctx.append({"role": "assistant", "contents": [{"type": "text", "pairs": _assistant_message}]})
    # Only add to the chat log when both a user and an assistant message exist
    if user_msg and _assistant_message:
        formatted_user_msg = format_fewshot_user_message(_image, _user_message) if _image else user_msg
        _chat_bot.append([formatted_user_msg, _assistant_message])
    return None, "", "", _chat_bot, _app_cfg
def fewshot_respond(_image, _user_message, _chat_bot, _app_cfg, params_form, thinking_mode, streaming_mode, fps_setting):
    """Few-shot response function."""
    print(f"[fewshot_respond] Called with streaming_mode: {streaming_mode}")
    if 'session_id' not in _app_cfg:
        _app_cfg['session_id'] = uuid.uuid4().hex[:16]
        print(f"[session] Generated session_id for few-shot session: {_app_cfg['session_id']}")
    if params_form == 'Beam Search':
        streaming_mode = False
        print("[fewshot_respond] Beam Search mode, forcing streaming off")
    user_message_contents = []
    _context = _app_cfg["ctx"].copy()
    images_cnt = _app_cfg["images_cnt"]
    temporal_ids = None
    if _image:
        image = Image.open(_image).convert("RGB")
        user_message_contents += encode_image(image)
        images_cnt += 1
    if _user_message:
        user_message_contents += [{"type": "text", "pairs": _user_message}]
    if user_message_contents:
        _context.append({"role": "user", "contents": user_message_contents})
    if params_form == 'Beam Search':
        params = {
            'sampling': False,
            'num_beams': 3,
            'repetition_penalty': 1.2,
            "max_new_tokens": 16284,
            "enable_thinking": thinking_mode,
            "stream": False
        }
    else:
        params = {
            'sampling': True,
            'top_p': 0.8,
            'top_k': 100,
            'temperature': 0.7,
            'repetition_penalty': 1.03,
            "max_new_tokens": 16284,
            "enable_thinking": thinking_mode,
            "stream": streaming_mode
        }
    if disable_text_only and images_cnt == 0:
        gr.Warning("Please chat with at least one image or video.")
        yield _image, _user_message, '', _chat_bot, _app_cfg
        return
    if streaming_mode:
        print("[fewshot_respond] Using streaming mode")
        _app_cfg['is_streaming'] = True
        _app_cfg['stop_streaming'] = False
        if _image:
            user_msg = format_fewshot_user_message(_image, _user_message)
            _chat_bot.append([user_msg, ""])
        else:
            _chat_bot.append([_user_message, ""])
        _context.append({"role": "assistant", "contents": [{"type": "text", "pairs": ""}]})
        _app_cfg['stop_streaming'] = False
        gen = chat_stream_character_generator("", _context[:-1], None, params, None, temporal_ids, _app_cfg, _app_cfg['session_id'])
        yield _image, _user_message, '', _chat_bot, _app_cfg
        accumulated_content = ""
        for _char in gen:
            if _app_cfg.get('stop_streaming', False):
                print("[fewshot_respond] Stop signal received, interrupting the streaming response")
                break
            accumulated_content += _char
            _context[-1]["contents"][0]["pairs"] += _char
            # Parse and format thinking content in real time
            if thinking_mode:
                # Try to parse the accumulated content so far
                thinking_content_raw, formal_answer_raw = parse_thinking_response(accumulated_content)
                # If complete thinking content was parsed out, use the formatted display
                if thinking_content_raw and thinking_content_raw != "STREAMING" and formal_answer_raw:
                    thinking_content_fmt = normalize_text_for_html(thinking_content_raw)
                    formal_answer_fmt = normalize_text_for_html(formal_answer_raw)
                    formatted_display = format_response_with_thinking(thinking_content_fmt, formal_answer_fmt)
                    _chat_bot[-1] = (_chat_bot[-1][0], formatted_display)
                else:
                    # Thinking is still in progress or the tags are incomplete; show the raw content (live streaming)
                    _chat_bot[-1] = (_chat_bot[-1][0], accumulated_content)
            else:
                # Not in thinking mode; show the accumulated content directly
                _chat_bot[-1] = (_chat_bot[-1][0], accumulated_content)
            yield _image, _user_message, '', _chat_bot, _app_cfg
        final_content = _context[-1]["contents"][0]["pairs"]
        _app_cfg['ctx'] = _context
        _app_cfg['images_cnt'] = images_cnt
        _app_cfg['is_streaming'] = False
        yield _image, '', '', _chat_bot, _app_cfg
    else:
        # Non-streaming mode
        code, _answer, _context_answer, sts = chat_direct("", _context, None, params, None, temporal_ids, _app_cfg['session_id'])
        context_content = _context_answer if _context_answer else _answer
        _context.append({"role": "assistant", "contents": [{"type": "text", "pairs": context_content}]})
        if _image:
            user_msg = format_fewshot_user_message(_image, _user_message)
            _chat_bot.append([user_msg, _answer])
        else:
            _chat_bot.append([_user_message, _answer])
        if code == 0:
            _app_cfg['ctx'] = _context
            _app_cfg['sts'] = sts
        _app_cfg['images_cnt'] = images_cnt
        _app_cfg['is_streaming'] = False
        yield None, '', '', _chat_bot, _app_cfg
# Other UI functions
def regenerate_button_clicked(_question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg, params_form, thinking_mode, streaming_mode, fps_setting):
    print(f"[regenerate] streaming_mode: {streaming_mode}")
    print(f"[regenerate] thinking_mode: {thinking_mode}")
    print(f"[regenerate] chat_type: {_app_cfg.get('chat_type', 'unknown')}")
    if params_form == 'Beam Search':
        streaming_mode = False
        print("[regenerate] Beam Search mode, forcing streaming off")
    if len(_chat_bot) <= 1 or not _chat_bot[-1][1]:
        gr.Warning('No question for regeneration.')
        yield _question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg
        return
    if _app_cfg["chat_type"] == "Chat":
        images_cnt = _app_cfg['images_cnt']
        videos_cnt = _app_cfg['videos_cnt']
        _question = _chat_bot[-1][0]
        _chat_bot = _chat_bot[:-1]
        _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
        files_cnts = check_has_videos(_question)
        images_cnt -= files_cnts[0]
        videos_cnt -= files_cnts[1]
        _app_cfg['images_cnt'] = images_cnt
        _app_cfg['videos_cnt'] = videos_cnt
        print(f"[regenerate] About to call respond with streaming_mode: {streaming_mode}")
        for result in respond(_question, _chat_bot, _app_cfg, params_form, thinking_mode, streaming_mode, fps_setting):
            new_input, _chat_bot, _app_cfg, _stop_button = result
            _question = new_input
            yield _question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg
    else:
        # Under the tuples format, _chat_bot[-1][0] is a string
        last_user_message = _chat_bot[-1][0]
        last_image = None
        # Check whether the message carries the image marker
        if "[Image uploaded]" in last_user_message:
            # Extract the actual user message from the marker text
            last_user_message = last_user_message.replace("[Image uploaded] ", "")
            # Note: under the simplified tuples format we cannot recover the image
            # file directly; handle this as needed for the actual use case.
        _chat_bot = _chat_bot[:-1]
        _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
        print(f"[regenerate] About to call fewshot_respond with streaming_mode: {streaming_mode}")
        for result in fewshot_respond(last_image, last_user_message, _chat_bot, _app_cfg, params_form, thinking_mode, streaming_mode, fps_setting):
            _image, _user_message, _assistant_message, _chat_bot, _app_cfg = result
            yield _question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg
def flushed():
    return gr.update(interactive=True)
def clear_media_cache(session_id):
    """Clear the media cache for the given session."""
    import shutil
    from pathlib import Path
    try:
        cache_dir = Path("./media_cache") / session_id
        if cache_dir.exists():
            shutil.rmtree(cache_dir)
            print(f"[clear_media_cache] Cleared media cache for session {session_id}")
    except Exception as e:
        print(f"[clear_media_cache] Failed to clear cache: {e}")
def clear(txt_input, file_upload, chat_bot, app_session):
    # Clear the previous session's media cache
    if 'session_id' in app_session:
        clear_media_cache(app_session['session_id'])
    chat_bot = copy.deepcopy(init_conversation)
    app_session['sts'] = None
    app_session['ctx'] = []
    app_session['images_cnt'] = 0
    app_session['videos_cnt'] = 0
    app_session['stop_streaming'] = False
    app_session['is_streaming'] = False
    app_session['media_cache'] = []  # empty the media cache info
    app_session['last_thinking_mode'] = False  # reset the thinking-mode state
    app_session['session_id'] = uuid.uuid4().hex[:16]
    print(f"[session] Generated new session id: {app_session['session_id']}")
    return "", None, gr.update(value=[], visible=False), gr.update(value=[], visible=False), chat_bot, app_session, None, '', ''
def select_chat_type(_tab, _app_cfg):
    _app_cfg["chat_type"] = _tab
    return _app_cfg
# UI configuration
form_radio = {
    'choices': ['Beam Search', 'Sampling'],
    'value': 'Sampling',
    'interactive': True,
    'label': 'Decode Type'
}
thinking_checkbox = {
    'value': False,
    'interactive': True,
    'label': 'Enable Thinking Mode',
}
streaming_checkbox = {
    'value': True,
    'interactive': True,
    'label': 'Enable Streaming Mode',
}
fps_slider = {
    'minimum': 1,
    'maximum': 20,
    'value': 3,
    'step': 1,
    'interactive': True,
    'label': 'Custom FPS for Video Processing'
}
init_conversation = [
    ["", "You can talk to me now"]
]
| css = """ | |
| video { height: auto !important; } | |
| .example label { font-size: 16px;} | |
| /* Current Media Gallery 滚动条样式 - 使用class选择器更安全 */ | |
| .current-media-gallery { | |
| overflow-y: auto !important; | |
| max-height: 600px !important; | |
| position: relative !important; | |
| } | |
| /* 确保只影响特定的Gallery容器内部 */ | |
| .current-media-gallery > div, | |
| .current-media-gallery .gallery-container { | |
| overflow-y: auto !important; | |
| max-height: 580px !important; | |
| } | |
| .current-media-gallery .gallery-item { | |
| margin-bottom: 10px !important; | |
| } | |
| /* 只为Current Media Gallery自定义滚动条样式 */ | |
| .current-media-gallery::-webkit-scrollbar, | |
| .current-media-gallery > div::-webkit-scrollbar, | |
| .current-media-gallery .gallery-container::-webkit-scrollbar { | |
| width: 8px !important; | |
| } | |
| .current-media-gallery::-webkit-scrollbar-track, | |
| .current-media-gallery > div::-webkit-scrollbar-track, | |
| .current-media-gallery .gallery-container::-webkit-scrollbar-track { | |
| background: #f1f1f1 !important; | |
| border-radius: 4px !important; | |
| } | |
| .current-media-gallery::-webkit-scrollbar-thumb, | |
| .current-media-gallery > div::-webkit-scrollbar-thumb, | |
| .current-media-gallery .gallery-container::-webkit-scrollbar-thumb { | |
| background: #c1c1c1 !important; | |
| border-radius: 4px !important; | |
| } | |
| .current-media-gallery::-webkit-scrollbar-thumb:hover, | |
| .current-media-gallery > div::-webkit-scrollbar-thumb:hover, | |
| .current-media-gallery .gallery-container::-webkit-scrollbar-thumb:hover { | |
| background: #a8a8a8 !important; | |
| } | |
| /* 隐藏Current Media的不必要元素 */ | |
| .current-media-gallery .upload-container, | |
| .current-media-gallery .drop-zone, | |
| .current-media-gallery .file-upload, | |
| .current-media-gallery .upload-text, | |
| .current-media-gallery .drop-text { | |
| display: none !important; | |
| } | |
| .current-media-gallery .clear-button, | |
| .current-media-gallery .delete-button, | |
| .current-media-gallery .remove-button { | |
| display: none !important; | |
| } | |
| /* 当Gallery为空时隐藏标签和占位文本 */ | |
| .current-media-gallery:not([style*="display: none"]) .gallery-container:empty::after { | |
| content: ""; | |
| display: none; | |
| } | |
| .current-media-gallery .empty-gallery-text, | |
| .current-media-gallery .placeholder-text { | |
| display: none !important; | |
| } | |
| /* 确保滚动条不会影响到其他组件 */ | |
| .current-media-gallery { | |
| isolation: isolate !important; | |
| } | |
| /* 重置其他Gallery组件的滚动条样式,防止被污染 */ | |
| .gradio-gallery:not(.current-media-gallery)::-webkit-scrollbar { | |
| width: initial !important; | |
| } | |
| .gradio-gallery:not(.current-media-gallery)::-webkit-scrollbar-track { | |
| background: initial !important; | |
| border-radius: initial !important; | |
| } | |
| .gradio-gallery:not(.current-media-gallery)::-webkit-scrollbar-thumb { | |
| background: initial !important; | |
| border-radius: initial !important; | |
| } | |
| /* 确保chatbot不受影响 */ | |
| .thinking-chatbot::-webkit-scrollbar { | |
| width: initial !important; | |
| } | |
| .thinking-chatbot::-webkit-scrollbar-track { | |
| background: initial !important; | |
| } | |
| .thinking-chatbot::-webkit-scrollbar-thumb { | |
| background: initial !important; | |
| } | |
| /* 思考过程和正式回答的样式 */ | |
| .response-container { | |
| margin: 10px 0; | |
| } | |
| .thinking-section { | |
| background: linear-gradient(135deg, #f8f9ff 0%, #f0f4ff 100%); | |
| border: 1px solid #d1d9ff; | |
| border-radius: 12px; | |
| padding: 16px; | |
| margin-bottom: 0px; | |
| box-shadow: 0 2px 8px rgba(67, 90, 235, 0.1); | |
| } | |
| .thinking-header { | |
| font-weight: 600; | |
| color: #4c5aa3; | |
| font-size: 14px; | |
| margin-bottom: 12px; | |
| display: flex; | |
| align-items: center; | |
| gap: 8px; | |
| } | |
| .thinking-content { | |
| color: #5a6ba8; | |
| font-size: 13px; | |
| line-height: 1; | |
| font-style: italic; | |
| background: rgba(255, 255, 255, 0.6); | |
| padding: 12px; | |
| border-radius: 8px; | |
| border-left: 3px solid #4c5aa3; | |
| white-space: pre-wrap; | |
| } | |
| .formal-section { | |
| background: linear-gradient(135deg, #ffffff 0%, #f8f9fa 100%); | |
| border: 1px solid #e9ecef; | |
| border-radius: 12px; | |
| padding: 16px; | |
| box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05); | |
| } | |
| .formal-header { | |
| font-weight: 600; | |
| color: #28a745; | |
| font-size: 14px; | |
| margin-bottom: 12px; | |
| display: flex; | |
| align-items: center; | |
| gap: 8px; | |
| } | |
| .formal-content { | |
| color: #333; | |
| font-size: 14px; | |
| line-height: 1.6; | |
| white-space: pre-wrap; | |
| } | |
| /* Chatbot container styles */ | |
| .thinking-chatbot .message { | |
| border-radius: 12px; | |
| overflow: visible; | |
| margin-top: 0 !important; | |
| margin-bottom: 0 !important; | |
| } | |
| .thinking-chatbot .message-wrap { | |
| margin-top: 0 !important; | |
| margin-bottom: 0 !important; | |
| } | |
| .thinking-chatbot .message.bot { | |
| background: transparent !important; | |
| border: none !important; | |
| padding: 8px !important; | |
| } | |
| .thinking-chatbot .message.bot .content { | |
| background: transparent !important; | |
| } | |
| """ | |
| introduction = """ | |
| ## Features: | |
| 1. Chat with single image | |
| 2. Chat with multiple images | |
| 3. Chat with video | |
| 4. Streaming Mode: Real-time response streaming | |
| 5. Thinking Mode: Show model reasoning process | |
| Click `How to use` tab to see examples. | |
| """ | |
| # Main application | |
| def create_app(): | |
| with gr.Blocks(css=css) as demo: | |
| with gr.Tab(model_name): | |
| with gr.Row(): | |
| with gr.Column(scale=1, min_width=300): | |
| gr.Markdown(value=introduction) | |
| params_form = create_component(form_radio, comp='Radio') | |
| thinking_mode = create_component(thinking_checkbox, comp='Checkbox') | |
| streaming_mode = create_component(streaming_checkbox, comp='Checkbox') | |
| fps_setting = create_component(fps_slider, comp='Slider') | |
| regenerate = create_component({'value': 'Regenerate'}, comp='Button') | |
| clear_button = create_component({'value': 'Clear History'}, comp='Button') | |
| stop_button = gr.Button("Stop", visible=False) | |
| with gr.Column(scale=3, min_width=500): | |
| initial_session_id = uuid.uuid4().hex[:16] | |
| print(f"[会话] 初始化会话,生成session_id: {initial_session_id}") | |
| app_session = gr.State({ | |
| 'sts': None, 'ctx': [], 'images_cnt': 0, 'videos_cnt': 0, | |
| 'chat_type': 'Chat', 'stop_streaming': False, 'is_streaming': False, | |
| 'session_id': initial_session_id, 'media_cache': [], 'last_thinking_mode': False | |
| }) | |
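| # Session-state keys: 'sts'/'ctx' hold the model conversation state and | |
| # context, 'images_cnt'/'videos_cnt' count attached media, 'stop_streaming'/ | |
| # 'is_streaming' coordinate the Stop button, and 'media_cache' appears to | |
| # back the Current Media gallery. | |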
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| chat_bot = gr.Chatbot( | |
| label=f"Chat with {model_name}", | |
| value=copy.deepcopy(init_conversation), | |
| height=600, | |
| elem_classes="thinking-chatbot" | |
| ) | |
| with gr.Column(scale=1, min_width=200): | |
| current_images = gr.Gallery( | |
| label="Current Media", | |
| show_label=True, | |
| elem_id="current_media", | |
| elem_classes="current-media-gallery", | |
| columns=1, | |
| rows=1, # a single row so the content scrolls vertically | |
| height=600, | |
| visible=False, | |
| container=True, # enable container mode | |
| allow_preview=True, # allow click-to-preview | |
| show_download_button=False, # hide the download button | |
| interactive=False, # display-only: prevent user upload/delete | |
| show_share_button=False # hide the share button | |
| ) | |
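| # This gallery is display-only (interactive=False); handle_submit refreshes | |
| # it through update_media_gallery rather than through direct user uploads. | |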
| with gr.Tab("Chat") as chat_tab: | |
| chat_tab_label = gr.Textbox(value="Chat", interactive=False, visible=False) | |
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| txt_input = gr.Textbox( | |
| placeholder="Type your message here...", | |
| label="Message", | |
| lines=2 | |
| ) | |
| with gr.Column(scale=1): | |
| submit_btn = gr.Button("Submit", variant="primary") | |
| with gr.Row(): | |
| with gr.Column(): | |
| file_upload = create_multimodal_input() | |
| # Image preview gallery for uploaded files | |
| file_preview = gr.Gallery( | |
| label="Uploaded Files Preview", | |
| show_label=True, | |
| elem_id="file_preview", | |
| columns=3, | |
| rows=2, | |
| height="auto", | |
| visible=False | |
| ) | |
| # Refresh the preview whenever files are uploaded | |
| def update_file_preview(files): | |
| if files: | |
| # Keep only image files for the preview | |
| image_files = [] | |
| for file in files: | |
| if hasattr(file, 'name'): | |
| file_path = file.name | |
| else: | |
| file_path = str(file) | |
| # Check whether it is an image file (reusing the global IMAGE_EXTENSIONS set) | |
| if os.path.splitext(file_path)[1].lower() in IMAGE_EXTENSIONS: | |
| image_files.append(file_path) | |
| if image_files: | |
| return gr.update(value=image_files, visible=True) | |
| return gr.update(value=[], visible=False) | |
| file_upload.change( | |
| update_file_preview, | |
| inputs=[file_upload], | |
| outputs=[file_preview] | |
| ) | |
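| # Illustrative behavior of update_file_preview (the paths are hypothetical): | |
| #   update_file_preview(["/tmp/cat.png", "/tmp/clip.mp4"]) | |
| #   -> gr.update(value=["/tmp/cat.png"], visible=True)  # non-images filtered out | |
| #   update_file_preview([]) | |
| #   -> gr.update(value=[], visible=False) | |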
| # Wrapper that adapts the (textbox + file upload) inputs to the respond API | |
| def handle_submit(message, files, chat_bot, current_images_gallery, app_session, params_form, thinking_mode, streaming_mode, fps_setting): | |
| print(f"[handle_submit] 收到输入: message='{message}', files={files}, chat_bot长度={len(chat_bot)}") | |
| # 如果消息为空且没有文件,直接返回 | |
| if not message and not files: | |
| print("[handle_submit] 消息和文件都为空,直接返回") | |
| return message, files, chat_bot, current_images_gallery, app_session, gr.update(visible=False) | |
| # 模拟原来的 MultimodalInput 格式 | |
| class MockInput: | |
| def __init__(self, text, files): | |
| self.text = text | |
| self.files = files if files else [] | |
| mock_question = MockInput(message, files) | |
| print(f"[handle_submit] 创建MockInput: text='{mock_question.text}', files={len(mock_question.files)}") | |
| # respond 函数返回生成器,我们需要逐步yield结果 | |
| result_generator = respond(mock_question, chat_bot, app_session, params_form, thinking_mode, streaming_mode, fps_setting) | |
| # 如果是生成器,逐步yield | |
| if hasattr(result_generator, '__iter__') and not isinstance(result_generator, (str, bytes, tuple)): | |
| print("[handle_submit] 使用生成器模式") | |
| for result in result_generator: | |
| new_file_input, updated_chat_bot, updated_app_session, stop_btn_update = result | |
| print(f"[handle_submit] yield结果: chat_bot长度={len(updated_chat_bot)}") | |
| # 更新媒体显示 | |
| media_gallery_update = update_media_gallery(updated_app_session) | |
| # 返回正确的输出格式 | |
| yield "", None, updated_chat_bot, media_gallery_update, updated_app_session, stop_btn_update | |
| else: | |
| print("[handle_submit] using non-generator mode") | |
| # Not a generator: unpack and return the single result directly | |
| new_file_input, updated_chat_bot, updated_app_session, stop_btn_update = result_generator | |
| print(f"[handle_submit] returning result directly: chat_bot length={len(updated_chat_bot)}") | |
| # Refresh the media display (unified on update_media_gallery, matching the generator path above) | |
| media_gallery_update = update_media_gallery(updated_app_session) | |
| yield "", None, updated_chat_bot, media_gallery_update, updated_app_session, stop_btn_update | |
| submit_btn.click( | |
| handle_submit, | |
| [txt_input, file_upload, chat_bot, current_images, app_session, params_form, thinking_mode, streaming_mode, fps_setting], | |
| [txt_input, file_upload, chat_bot, current_images, app_session, stop_button] | |
| ) | |
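| # Optional sketch (assumption: Enter-to-send is wanted). gr.Textbox exposes a | |
| # .submit event with the same fn/inputs/outputs signature as Button.click, | |
| # but with lines=2 the browser inserts a newline on Enter, so this is mainly | |
| # useful if txt_input is switched to a single-line textbox: | |
| # txt_input.submit( | |
| #     handle_submit, | |
| #     [txt_input, file_upload, chat_bot, current_images, app_session, params_form, thinking_mode, streaming_mode, fps_setting], | |
| #     [txt_input, file_upload, chat_bot, current_images, app_session, stop_button] | |
| # ) | |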
| with gr.Tab("Few Shot", visible=False) as fewshot_tab: | |
| fewshot_tab_label = gr.Textbox(value="Few Shot", interactive=False, visible=False) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| image_input = gr.Image(type="filepath", sources=["upload"]) | |
| with gr.Column(scale=3): | |
| user_message = gr.Textbox(label="User") | |
| assistant_message = gr.Textbox(label="Assistant") | |
| with gr.Row(): | |
| add_demonstration_button = gr.Button("Add Example") | |
| generate_button = gr.Button(value="Generate", variant="primary") | |
| add_demonstration_button.click( | |
| fewshot_add_demonstration, | |
| [image_input, user_message, assistant_message, chat_bot, app_session], | |
| [image_input, user_message, assistant_message, chat_bot, app_session] | |
| ) | |
| generate_button.click( | |
| fewshot_respond, | |
| [image_input, user_message, chat_bot, app_session, params_form, thinking_mode, streaming_mode, fps_setting], | |
| [image_input, user_message, assistant_message, chat_bot, app_session] | |
| ) | |
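| # Note: the Few Shot tab is created hidden (visible=False); its handlers | |
| # share the same chat_bot and app_session state as the Chat tab. | |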
| chat_tab.select( | |
| select_chat_type, | |
| [chat_tab_label, app_session], | |
| [app_session] | |
| ) | |
| chat_tab.select( | |
| clear, | |
| [txt_input, file_upload, chat_bot, app_session], | |
| [txt_input, file_upload, file_preview, chat_bot, app_session, image_input, user_message, assistant_message] | |
| ) | |
| fewshot_tab.select( | |
| select_chat_type, | |
| [fewshot_tab_label, app_session], | |
| [app_session] | |
| ) | |
| fewshot_tab.select( | |
| clear, | |
| [txt_input, file_upload, chat_bot, app_session], | |
| [txt_input, file_upload, file_preview, chat_bot, app_session, image_input, user_message, assistant_message] | |
| ) | |
| # chat_bot.flushed(flushed, outputs=[txt_input])  # the standard Chatbot may not support a flushed event | |
| params_form.change( | |
| update_streaming_mode_state, | |
| inputs=[params_form], | |
| outputs=[streaming_mode] | |
| ) | |
| regenerate.click( | |
| regenerate_button_clicked, | |
| [txt_input, image_input, user_message, assistant_message, chat_bot, app_session, params_form, thinking_mode, streaming_mode, fps_setting], | |
| [txt_input, image_input, user_message, assistant_message, chat_bot, app_session] | |
| ) | |
| clear_button.click( | |
| clear, | |
| [txt_input, file_upload, chat_bot, app_session], | |
| [txt_input, file_upload, file_preview, current_images, chat_bot, app_session, image_input, user_message, assistant_message] | |
| ) | |
| stop_button.click( | |
| stop_button_clicked, | |
| [app_session], | |
| [app_session, stop_button] | |
| ) | |
| return demo | |
| if __name__ == "__main__": | |
| # Parse command-line arguments | |
| parser = argparse.ArgumentParser(description='Web Demo for MiniCPM-V 4.5') | |
| parser.add_argument('--port', type=int, default=7860, help='Port to run the web demo on') | |
| parser.add_argument('--no-parallel-encoding', action='store_true', help='Disable parallel image encoding') | |
| parser.add_argument('--parallel-processes', type=int, default=None, help='Number of parallel processes for image encoding') | |
| args = parser.parse_args() | |
| # Configure parallel encoding | |
| if args.no_parallel_encoding: | |
| ENABLE_PARALLEL_ENCODING = False | |
| print("[Perf] Parallel image encoding disabled") | |
| else: | |
| ENABLE_PARALLEL_ENCODING = True | |
| print("[Perf] Parallel image encoding enabled") | |
| if args.parallel_processes: | |
| PARALLEL_PROCESSES = args.parallel_processes | |
| print(f"[Perf] Parallel process count set to: {PARALLEL_PROCESSES}") | |
| else: | |
| print(f"[Perf] Auto-detecting parallel process count; CPU cores: {mp.cpu_count()}") | |
| # Initialize the model | |
| initialize_model() | |
| # Create and launch the app | |
| demo = create_app() | |
| demo.launch( | |
| share=False, | |
| debug=True, | |
| show_api=False, | |
| server_port=args.port, | |
| server_name="0.0.0.0" | |
| ) |
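| # Example invocations (script name hypothetical): | |
| #   python app.py --port 7860 --parallel-processes 4 | |
| #   python app.py --no-parallel-encoding | |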