multiturn inference

This commit is contained in:
Xi Yan 2025-04-05 18:22:27 -07:00
parent 3021c87271
commit 5039888762

View file

@ -491,3 +491,80 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(
else:
for tc in response.completion_message.tool_calls:
assert tc.tool_name == "get_object_namespace_list"
def test_multi_turn_chat_completion(client_with_models, text_model_id):
SYSTEM_PROMPT = """
# Tools
You have access to the following tools. You might need to use one or more function/tool calls to fulfill the task.
If none are needed, then proceed to the response.
## Tool Call Syntax
You can call tools using the following syntax:
[{
"name": <tool_name>,
"parameters": "{\"param1\": value1, \"param2\": value2}"
}]
where `parameters` is a JSON string of the parameters for the tool call.
Do not include anything else when calling the tools with the syntax above.
Only respond with the valid tool call JSON.
## Available Tools
[{
"name": "get_weather",
"description": "Retrieve the current temperature for a specified location",
"parameters": {
"type": "dict",
"properties": {
"location": {
"type": "string",
"description": "The city, state, or country for which to fetch the temperature",
"required": true
}
},
"required": ["location"]
}
}]
## Example Tool Calls
1. Single tool call - Get weather:
[{
"name": "get_weather",
"parameters": "{\"location\": \"San Francisco, CA\"}"
}]
2. Multiple tool calls - Weather and web search:
[{
"name": "web_search",
"parameters": "{\"query\": \"What is the capital of France?\"}"
}, {
"name": "get_weather",
"parameters": "{\"location\": \"Paris, France\"}"
}]
"""
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": "What's the weather in Tokyo?"},
]
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=messages,
)
messages.append(response.completion_message)
messages.append({"role": "tool", "content": "raining", "call_id": "1"})
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=messages,
)
print(response.completion_message.tool_calls)
print(response)
assert response.completion_message.content == "raining"