File size: 3,393 Bytes
21757cb
 
 
 
77ac6c2
 
 
 
 
21757cb
 
 
 
 
 
89acc5c
21757cb
77ac6c2
d3ba21b
ed8b0c6
d3ba21b
 
982d6e8
89acc5c
d3ba21b
e584a13
ed8b0c6
e584a13
982d6e8
ed8b0c6
 
89acc5c
21757cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3ba21b
e584a13
 
21757cb
 
 
89acc5c
 
21757cb
89acc5c
21757cb
89acc5c
 
 
 
21757cb
 
 
982d6e8
 
89acc5c
21757cb
 
 
 
 
89acc5c
21757cb
77ac6c2
89acc5c
21757cb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
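Inspect how Hugging Face chat templates render tool use for different models:
the prompt for a tool-augmented weather query, the JSON schema derived from the
tool function, and the follow-up prompt after a tool response is appended.

## Source


## References

- [Llama-3.1 tool calling | 2024/08](https://huggingface.co/meta-llama/Llama-3.1-405B-Instruct)
- [Hugging Face unified tool use | 2024/08/12](https://huggingface.co/blog/unified-tool-use)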


"""

import os
import json
from transformers import AutoTokenizer
from transformers.utils import get_json_schema


# Hugging Face Hub id (or local directory) of the tokenizer whose chat template is
# inspected. Either uncomment one of the alternatives below or set the MODEL_PATH
# environment variable to pick a model without editing this file; the default is
# used when the variable is unset. Trailing notes record what was observed for
# each model's template.
# MODEL_PATH = "meta-llama/Llama-3.1-405B-Instruct"
# MODEL_PATH = "NousResearch/Hermes-3-Llama-3.1-405B"  # assistant messages with tool_calls must also carry a content field
# MODEL_PATH = "../../test/Llama-4-Maverick-17B-128E-Instruct/"
# MODEL_PATH = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
# MODEL_PATH = "Qwen/Qwen3-235B-A22B-Instruct-2507"
# MODEL_PATH = "mistralai/Mistral-7B-Instruct-v0.1"  # no tool_calls in messages, no role=tool, no tools
# MODEL_PATH = "mistralai/Ministral-8B-Instruct-2410"  # supports tools and tool_calls (tool_calls need an id); unusual format
# MODEL_PATH = "deepseek-ai/DeepSeek-R1"  # no tools support; tool_calls also problematic
# MODEL_PATH = "deepseek-ai/DeepSeek-V3.1"
# MODEL_PATH = "google/gemma-3-27b-it"  # no tool support at all
# MODEL_PATH = "moonshotai/Kimi-K2-Instruct"
MODEL_PATH = os.environ.get("MODEL_PATH", "xai-org/grok-2")

# Load the tokenizer for MODEL_PATH; its chat template drives every prompt rendered below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# First, define a tool. transformers' get_json_schema() builds the tool schema from
# this signature and its Google-style docstring, so the docstring text is shown to
# the model and must describe only what the function actually accepts and returns.
def get_current_temperature(location: str) -> float:
    """
    Get the current temperature at a location.

    Args:
        location: The location to get the temperature for, in the format "City, Country"
    Returns:
        The current temperature at the specified location, as a float.
    """
    return 22.  # Stub: a real tool would look the temperature up instead of returning a constant.

# Next, create a chat and apply the chat template
messages = [
    {"role": "system", "content": "You are a bot that responds to weather queries."},
    {"role": "user", "content": "Hey, what's the temperature in Paris right now?"},
    # {"role": "assistant", "content": "test1"},
    # {"role": "user", "content": "test2"},
]

# Both apply_chat_template calls below receive the same tool list: tool definitions
# are re-sent with every request, so the step-4 prompt should contain them exactly
# like the step-1 prompt does.
tools = [get_current_temperature]

# step1: render the initial prompt, including the tool definitions.
print("###" * 10, "llm with tools", MODEL_PATH)
print(json.dumps(messages, ensure_ascii=False, indent=2))
inputs = tokenizer.apply_chat_template(messages, tools=tools, add_generation_prompt=True, tokenize=False)
print("###" * 5, "prompt")
print(inputs)
# The schema transformers derives from the function's signature and docstring.
json_schema = get_json_schema(get_current_temperature)
print("###" * 5, "json_schema")
print(json.dumps(json_schema, ensure_ascii=False, indent=2))


# step2: call the LLM; the dict below stands in for the tool call it returned.
tool_call = {"name": "get_current_temperature", "arguments": {"location": "Paris, France"}}
# The required shape of the assistant's tool-call turn depends on the model's
# template, so the variants stay commented out; uncomment the one matching MODEL_PATH.
# NOTE: while all of them are disabled, the step-4 conversation contains a tool
# response with no preceding assistant tool call.
# messages.append({"role": "assistant", "tool_calls": [{"type": "function", "function": tool_call}]})
# messages.append({"role": "assistant", "tool_calls": [{"type": "function", "function": tool_call}], "content": ""})  # Hermes-3-Llama-3.1-405B: content must be present
# messages.append({"role": "assistant", "tool_calls": [{"id": "123456789", "type": "function", "function": tool_call}]})  # Ministral-8B-Instruct-2410: still raises an error

# step3: run the requested tool and append its result as a tool message.
tool_result = get_current_temperature(**tool_call["arguments"])
messages.append({"role": "tool", "name": tool_call["name"], "content": str(tool_result)})

# step4: render the follow-up prompt the LLM would receive after the tool ran.
print("###" * 10, "tool_calls & tool_response")
print(json.dumps(messages, ensure_ascii=False, indent=2))
inputs = tokenizer.apply_chat_template(messages, tools=tools, add_generation_prompt=True, tokenize=False)
print("###" * 5, 'prompt')
print(inputs)