"""
## 来源
## 参考
- [Llama-3.1的工具调用 | 2024/08](https://huggingface.co/meta-llama/Llama-3.1-405B-Instruct)
- [huggingface统一工具调用 | 2024/08/12](https://huggingface.co/blog/unified-tool-use)
"""
import os
import json
from transformers import AutoTokenizer
from transformers.utils import get_json_schema
# MODEL_PATH = "meta-llama/Llama-3.1-405B-Instruct"
# MODEL_PATH = "NousResearch/Hermes-3-Llama-3.1-405B" # messages里的tool_calls必须要有content字段
# MODEL_PATH = "../../test/Llama-4-Maverick-17B-128E-Instruct/"
# MODEL_PATH = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
# MODEL_PATH = "Qwen/Qwen3-235B-A22B-Instruct-2507"
# MODEL_PATH = "mistralai/Mistral-7B-Instruct-v0.1" # messages里不支持tool_calls,不支持 role=tool,不支持 tools
# MODEL_PATH = "mistralai/Ministral-8B-Instruct-2410" # 支持 tools, 支持tool_calls(必须要有id), 格式非主流
# MODEL_PATH = "deepseek-ai/DeepSeek-R1" # 不支持tools,tool_calls也有问题
# MODEL_PATH = "deepseek-ai/DeepSeek-V3.1"
# MODEL_PATH = "google/gemma-3-27b-it" # 不支持任何tool
# MODEL_PATH = "moonshotai/Kimi-K2-Instruct"
MODEL_PATH = "xai-org/grok-2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
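# Some of the models above ship custom tokenizer code (moonshotai/Kimi-K2-Instruct
# is a likely example) and may need trust_remote_code=True; verify per model:
# tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)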
# First, define a tool
def get_current_temperature(location: str) -> float:
    """
    Get the current temperature at a location.

    Args:
        location: The location to get the temperature for, in the format "City, Country"
    Returns:
        The current temperature at the specified location, as a float.
    """
    return 22.0  # A real function should probably actually get the temperature!
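# A hedged sketch of a second, hypothetical tool: get_json_schema parses the
# signature and the Google-style docstring, so every argument needs a type hint
# and an entry under "Args:". get_forecast is illustrative only and is not used below.
def get_forecast(location: str, days: int = 1) -> str:
    """
    Get a weather forecast for a location.

    Args:
        location: The location to get the forecast for, in the format "City, Country"
        days: The number of days to forecast

    Returns:
        A short text forecast.
    """
    return "sunny"  # a stub; a real tool would query a weather API
# print(get_json_schema(get_forecast))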
# Next, create a chat and apply the chat template
messages = [
{"role": "system", "content": "You are a bot that responds to weather queries."},
{"role": "user", "content": "Hey, what's the temperature in Paris right now?"},
# {"role": "assitant", "content": "test1"},
# {"role": "user", "content": "test2"},
]
# step 1: render the prompt with the tool definitions attached
print("###" * 10, "llm with tools", MODEL_PATH)
print(json.dumps(messages, ensure_ascii=False, indent=2))
inputs = tokenizer.apply_chat_template(messages, tools=[get_current_temperature], add_generation_prompt=True, tokenize=False)
print("###" * 5, "prompt")
print(inputs)
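# The same call can return model-ready tensors instead of a string (a sketch,
# not executed here because no model is loaded):
# model_inputs = tokenizer.apply_chat_template(
#     messages, tools=[get_current_temperature], add_generation_prompt=True,
#     tokenize=True, return_dict=True, return_tensors="pt",
# )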
json_schema = get_json_schema(get_current_temperature)
print("###" * 5, "json_schema")
print(json.dumps(json_schema, ensure_ascii=False, indent=2))
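# tools= also accepts plain JSON schema dicts, which helps when the tool is not
# implemented in Python; this should render the same prompt as passing the function:
# inputs = tokenizer.apply_chat_template(messages, tools=[json_schema], add_generation_prompt=True, tokenize=False)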
# step 2: call the LLM; the tool_call below stands in for the model's output
tool_call = {"name": "get_current_temperature", "arguments": {"location": "Paris, France"}}
messages.append({"role": "assistant", "tool_calls": [{"type": "function", "function": tool_call}]})
# messages.append({"role": "assistant", "tool_calls": [{"type": "function", "function": tool_call}], "content": ""})  # Hermes-3-Llama-3.1-405B: content must not be empty
# messages.append({"role": "assistant", "tool_calls": [{"id": "123456789", "type": "function", "function": tool_call}]})  # Ministral-8B-Instruct-2410 still raises an error
# step 3: call the tool; the message below stands in for the tool's return value
messages.append({"role": "tool", "name": "get_current_temperature", "content": "22.0"})
# step 4: call the LLM again with the tool result
print("###" * 10, "tool_calls & tool_response")
print(json.dumps(messages, ensure_ascii=False, indent=2))
inputs = tokenizer.apply_chat_template(messages, tools=[get_current_temperature], add_generation_prompt=True, tokenize=False)
print("###" * 5, 'prompt')
print(inputs)
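# step 5 (sketch): generate the final answer from the tool result. Not executed
# here; the models above are too large for casual runs, so this assumes a
# causal LM that fits on the available hardware:
# from transformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto")
# model_inputs = tokenizer.apply_chat_template(
#     messages, add_generation_prompt=True, tokenize=True,
#     return_dict=True, return_tensors="pt",
# ).to(model.device)
# output_ids = model.generate(**model_inputs, max_new_tokens=128)
# print(tokenizer.decode(output_ids[0][model_inputs["input_ids"].shape[1]:], skip_special_tokens=True))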