File size: 4,857 Bytes
a9fb7e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb6eafc
a9fb7e9
 
cb6eafc
 
 
 
 
 
 
 
 
 
 
 
a9fb7e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb6eafc
 
a9fb7e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb6eafc
 
 
 
 
a9fb7e9
cb6eafc
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import httpx
import json
import logging
import html
from config import ANTCHAT_BASE_URL, ANTCHAT_API_KEY, CHAT_MODEL_SPECS

_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"

# Configure the root logger for the process and obtain this module's logger.
# NOTE(review): setting the root logger to DEBUG at import time affects every
# module in the process; consider moving this to the application entry point.
logging.basicConfig(format=_LOG_FORMAT, level=logging.DEBUG)
logger = logging.getLogger(__name__)

def _resolve_api_model_id(model_id):
    """Map a UI-level model key to the model ID sent to the API.

    The default comes from ``CHAT_MODEL_SPECS[model_id]["model_id"]``. If
    ``local.get_local_model_id_map()`` is importable and its result contains
    ``model_id``, that entry overrides the default (and may supply an ID for a
    key that is absent from ``CHAT_MODEL_SPECS``).

    Raises:
        KeyError: if ``model_id`` is neither in ``CHAT_MODEL_SPECS`` nor in the
            local override map.
    """
    try:
        api_model_id = CHAT_MODEL_SPECS[model_id]["model_id"]
    except KeyError:
        api_model_id = None

    # Optional per-deployment override supplied by an untracked local.py.
    try:
        from local import get_local_model_id_map
        local_map = get_local_model_id_map()
        if model_id in local_map:
            api_model_id = local_map[model_id]
            logger.info(f"使用本地模型 ID 映射: {model_id} -> {api_model_id}")
    except ImportError:
        # NOTE(review): this also fires if local.py exists but lacks
        # get_local_model_id_map or one of its own imports fails.
        logger.info("local.py 未找到,使用默认模型 ID 映射。")
    except Exception as e:
        logger.error(f"获取本地模型 ID 映射时出错: {e}")

    if api_model_id is None:
        raise KeyError(model_id)
    return api_model_id


def _build_messages(history, system_prompt):
    """Turn ``[(user_msg, assistant_msg), ...]`` history into chat messages.

    The result always starts with the system message. Turns without a user
    message are skipped (this filters out the UI's initial welcome message), and
    an assistant reply is appended only after the user message of its turn.
    A ``None`` history is treated as empty.
    """
    messages = [{"role": "system", "content": system_prompt}]
    for user_msg, assistant_msg in history or []:
        if not user_msg:
            continue
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages


def _extract_stream_text(data):
    """Return the text fragment carried by one parsed stream payload, or None.

    Prefers ``choices[0].delta.content``; otherwise falls back to the first
    tool call's ``function.arguments``. Payloads that are not dicts, or that
    carry ``null`` for ``delta``/``function``, yield None instead of raising.
    """
    if not isinstance(data, dict):
        return None
    choices = data.get("choices")
    if not choices or not isinstance(choices[0], dict):
        return None

    # Servers may send an explicit null delta (e.g. on the final chunk).
    delta = choices[0].get("delta") or {}
    content = delta.get("content")
    if content:
        return content

    tool_calls = delta.get("tool_calls")
    if tool_calls and isinstance(tool_calls[0], dict):
        function = tool_calls[0].get("function") or {}
        return function.get("arguments") or None
    return None


def get_model_response(model_id, history, system_prompt, temperature, escape_html=True):
    """Stream a chat completion from the AntChat API as text fragments.

    This is a generator: nothing (including model-ID resolution) runs until
    the first item is requested.

    Args:
        model_id: Key into ``CHAT_MODEL_SPECS`` (or the optional local.py
            override map) identifying the model to call.
        history: Iterable of ``(user_msg, assistant_msg)`` pairs; may be
            ``None`` or empty.
        system_prompt: Content of the leading system message.
        temperature: Sampling temperature forwarded in the request body.
        escape_html: If True, each yielded fragment is passed through
            ``html.escape``.

    Yields:
        Non-empty text fragments taken from ``delta.content`` or, when that is
        empty, from the first tool call's ``function.arguments``.

    Raises:
        KeyError: if ``model_id`` cannot be resolved to an API model ID.

    Request/HTTP errors and other failures while streaming are logged (with
    traceback) and end the stream without raising, so UI callers simply see
    the output stop.
    """
    logger.debug("[get_model_response] History: %s", history)
    api_model_id = _resolve_api_model_id(model_id)

    messages = _build_messages(history, system_prompt)
    logger.debug("[get_model_response] Messages: %s", messages)

    url = f"{ANTCHAT_BASE_URL}/chat/completions"
    headers = {
        "Authorization": f"Bearer {ANTCHAT_API_KEY}",
        "Content-Type": "application/json",
    }
    json_data = {
        "model": api_model_id,
        "messages": messages,
        "stream": True,
        "temperature": temperature,
    }

    # Never write the API key to the logs.
    redacted_headers = {**headers, "Authorization": "Bearer ***"}
    logger.debug("请求 URL: %s", url)
    logger.debug("请求头: %s", redacted_headers)
    logger.debug("请求体: %s", json.dumps(json_data, ensure_ascii=False))

    try:
        with httpx.stream(
            "POST",
            url,
            headers=headers,
            json=json_data,
            timeout=120,
        ) as response:
            logger.debug("响应状态码: %s", response.status_code)
            response.raise_for_status()
            for line in response.iter_lines():
                # Only "data:" lines carry payloads; other lines are ignored.
                if not line.startswith("data:"):
                    continue
                payload = line[5:]
                if payload.strip() == "[DONE]":
                    break
                try:
                    data = json.loads(payload)
                except json.JSONDecodeError as e:
                    logger.error("JSON 解析错误: %s, 数据: %s", e, payload)
                    continue
                text = _extract_stream_text(data)
                if text:
                    yield html.escape(text) if escape_html else text
    except Exception as e:
        # Deliberate best-effort: log with traceback and end the stream.
        logger.exception("请求异常: %s", e)

def perform_web_search(query):
    """Placeholder web search.

    ``query`` is currently ignored, and a fixed summary string is always returned.
    TODO: call the Tavily or Serper API and summarize the results.
    """
    placeholder_summary = "搜索结果摘要"
    return placeholder_summary

def generate_code_for_tab(system_prompt, user_prompt, code_type, model_choice, *, temperature=0.7):
    """Stream generated code for the code-generation tab.

    The model is chosen by looking for known label substrings in
    ``model_choice`` (the UI option text). A label containing "Ling-1T" maps to
    ``ling-1t``, and one containing "Ring-flash-2.0" maps to ``ring-flash-2.0``.
    Any other value, including ``None`` or an empty string, falls back to
    ``inclusionai/ling-1t`` with a warning.

    Args:
        system_prompt: System message sent to the model.
        user_prompt: The user's request; sent as the single history turn.
        code_type: Kind of code requested; used only for logging here.
        model_choice: UI label of the selected model (may be None).
        temperature: Sampling temperature (keyword-only; defaults to 0.7,
            the value previously hard-coded).

    Yields:
        Raw (not HTML-escaped) text fragments from ``get_model_response``.
    """
    logger.info(f"为 '{code_type}' 类型生成代码,Prompt: '{user_prompt}', Model: '{model_choice}'")

    # Checked in order; the first label substring found in model_choice wins.
    label_to_model = (
        ("Ling-1T", "ling-1t"),
        ("Ring-flash-2.0", "ring-flash-2.0"),
    )
    choice = model_choice or ""
    model_name = next(
        (name for label, name in label_to_model if label in choice),
        None,
    )
    if model_name is None:
        # NOTE(review): this fallback uses a vendor-prefixed ID while the
        # mapped keys above are bare; confirm it exists in CHAT_MODEL_SPECS
        # or the local.py override map.
        model_name = "inclusionai/ling-1t"
        logger.warning(f"未知的模型选项 '{model_choice}', 回退到默认模型 'inclusionai/ling-1t'")

    history = [[user_prompt, None]]
    # Code output is shown verbatim, so HTML entities must not be escaped.
    yield from get_model_response(model_name, history, system_prompt, temperature, escape_html=False)