multi-turn tool call test

This commit is contained in:
Hardik Shah 2025-04-05 20:26:22 -07:00
parent 3021c87271
commit eafbde4e17
2 changed files with 60 additions and 0 deletions

View file

@ -338,6 +338,10 @@ class MetaReferenceInferenceImpl(
stop_reason = None stop_reason = None
for token_result in self.generator.chat_completion(request): for token_result in self.generator.chat_completion(request):
from termcolor import cprint
cprint(token_result.text, "cyan", end="")
tokens.append(token_result.token) tokens.append(token_result.token)
if token_result.token == tokenizer.eot_id: if token_result.token == tokenizer.eot_id:
@ -386,6 +390,10 @@ class MetaReferenceInferenceImpl(
ipython = False ipython = False
for token_result in self.generator.chat_completion(request): for token_result in self.generator.chat_completion(request):
from termcolor import cprint
cprint(token_result.text, "cyan", end="")
tokens.append(token_result.token) tokens.append(token_result.token)
if not ipython and token_result.text.startswith("<|python_tag|>"): if not ipython and token_result.text.startswith("<|python_tag|>"):

View file

@ -491,3 +491,55 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(
else: else:
for tc in response.completion_message.tool_calls: for tc in response.completion_message.tool_calls:
assert tc.tool_name == "get_object_namespace_list" assert tc.tool_name == "get_object_namespace_list"
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_calling_loop_non_streaming(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
messages = tc["messages"]
messages[0]["content"] += """
Once you make one or more function calls, try to answer the question using the response if you can.
NEVER invoke the same function with the same argumennts twice. Use the response of the first call instead."""
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=messages,
tools=tc["tools"],
stream=False,
)
# some models can return content for the response in addition to the tool call
assert response.completion_message.role == "assistant"
assert len(response.completion_message.tool_calls) == 1
assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"]
assert response.completion_message.tool_calls[0].arguments == tc["expected"]
messages.append(response.completion_message)
messages.append(
# Tool Response Message
{
"role": "tool",
"call_id": response.completion_message.tool_calls[0].call_id,
"content": "70 degrees and foggy",
}
)
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=messages,
tools=tc["tools"],
tool_choice="auto",
stream=False,
)
from rich.pretty import pprint
pprint(response.completion_message)
# you would expect no tool call but a text completion message
assert len(response.completion_message.tool_calls) == 0