update test to try multi-turn scenarios

Hardik Shah 2025-04-06 12:13:59 -07:00
parent eafbde4e17
commit cd618e9ad0
3 changed files with 198 additions and 64 deletions

@@ -496,50 +496,55 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:tool_calling",
# "inference:chat_completion:multi_turn_tool_calling_01",
"inference:chat_completion:multi_turn_tool_calling_02",
],
)
def test_text_chat_completion_with_tool_calling_loop_non_streaming(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
messages = tc["messages"]
messages[0]["content"] += """
Once you make one or more function calls, try to answer the question using the response if you can.
NEVER invoke the same function with the same arguments twice. Use the response of the first call instead."""
    response = client_with_models.inference.chat_completion(
        model_id=text_model_id,
        messages=messages,
        tools=tc["tools"],
        stream=False,
    )
    # some models can return content for the response in addition to the tool call
    assert response.completion_message.role == "assistant"
    assert len(response.completion_message.tool_calls) == 1
    assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"]
    assert response.completion_message.tool_calls[0].arguments == tc["expected"]
    messages.append(response.completion_message)
    messages.append(
        # Tool Response Message
        {
            "role": "tool",
            "call_id": response.completion_message.tool_calls[0].call_id,
            "content": "70 degrees and foggy",
        }
    )
    response = client_with_models.inference.chat_completion(
        model_id=text_model_id,
        messages=messages,
        tools=tc["tools"],
        tool_choice="auto",
        stream=False,
    )
    from rich.pretty import pprint
    pprint(response.completion_message)
    # you would expect no tool call but a text completion message
    assert len(response.completion_message.tool_calls) == 0
    tc = TestCase(test_case)
    messages = []
    # keep going until either
    # 1. there are multi-turn test messages left to feed in, or
    # 2. no test messages remain but the last message is a tool response
    while len(tc["messages"]) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"):
        # do not take new messages if the last message is a tool response
        if len(messages) == 0 or messages[-1]["role"] != "tool":
            new_messages = tc["messages"].pop(0)
            messages += new_messages
        pprint(messages)
        response = client_with_models.inference.chat_completion(
            model_id=text_model_id,
            messages=messages,
            tools=tc["tools"],
            stream=False,
        )
        op_msg = response.completion_message
        messages.append(op_msg)
        pprint(op_msg)
        assert op_msg.role == "assistant"
        expected = tc["expected"].pop(0)
        assert len(op_msg.tool_calls) == expected["num_tool_calls"]
        if expected["num_tool_calls"] > 0:
            assert op_msg.tool_calls[0].tool_name == expected["tool_name"]
            assert op_msg.tool_calls[0].arguments == expected["tool_arguments"]
            # messages.append(op_msg)
            tool_response = tc["tool_responses"].pop(0)
            messages.append(
                # Tool Response Message
                {
                    "role": "tool",
                    "call_id": op_msg.tool_calls[0].call_id,
                    "content": tool_response["response"],
                }
            )
        else:
            actual_answer = op_msg.content.lower()
            pprint(actual_answer)
            assert expected["answer"] in actual_answer