minor fix

commit 541d0c6f1a
parent 31453f3f79
Author: Hardik Shah
Date:   2025-04-06 17:36:42 -07:00

2 changed files with 24 additions and 20 deletions


@@ -496,48 +496,53 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(
 @pytest.mark.parametrize(
     "test_case",
     [
+        # Tests if the model can handle simple messages like "Hi" or
+        # a message unrelated to one of the tool calls
         "inference:chat_completion:multi_turn_tool_calling_01",
+        # Tests if the model can do a full tool call with responses correctly
         "inference:chat_completion:multi_turn_tool_calling_02",
+        # Tests if the model can generate multiple params and
+        # read outputs correctly
         "inference:chat_completion:multi_turn_tool_calling_03",
+        # Tests if the model can do different tool calls in a sequence
+        # and use the information between them appropriately
         "inference:chat_completion:multi_turn_tool_calling_04",
+        # Tests if the model can use the current date and run multiple tool calls
+        # sequentially and infer using both
         "inference:chat_completion:multi_turn_tool_calling_05",
     ],
 )
-def test_text_chat_completion_with_tool_calling_loop_non_streaming(client_with_models, text_model_id, test_case):
-    from rich.pretty import pprint
+def test_text_chat_completion_with_multi_turn_tool_calling(client_with_models, text_model_id, test_case):
+    """Tests the model's tool-calling loop in various multi-turn scenarios."""
     tc = TestCase(test_case)
     messages = []
     # keep going until either
     # 1. we have messages to test in multi-turn
     # 2. no messages but last message is tool response
     while len(tc["messages"]) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"):
         # do not take new messages if last message is tool response
-        if (
-            len(messages) == 0
-            or (isinstance(messages[-1], dict) and messages[-1]["role"] != "tool")
-            or (not isinstance(messages[-1], dict) and messages[-1].role != "tool")
-        ):
+        if len(messages) == 0 or messages[-1]["role"] != "tool":
             new_messages = tc["messages"].pop(0)
             messages += new_messages
-        pprint(messages)
+        # pprint(messages)
         response = client_with_models.inference.chat_completion(
             model_id=text_model_id,
             messages=messages,
             tools=tc["tools"],
             stream=False,
-            # sampling_params={
-            #     "strategy": {
-            #         "type": "top_p",
-            #         "top_p": 0.9,
-            #         "temperature": 0.6,
-            #     }
-            # },
+            sampling_params={
+                "strategy": {
+                    "type": "top_p",
+                    "top_p": 0.9,
+                    "temperature": 0.6,
+                }
+            },
         )
         op_msg = response.completion_message
         messages.append(op_msg.model_dump())
-        pprint(op_msg)
+        # pprint(op_msg)
         assert op_msg.role == "assistant"
         expected = tc["expected"].pop(0)
@@ -558,5 +563,5 @@ def test_text_chat_completion_with_tool_calling_loop_non_streaming(client_with_m
             )
         else:
             actual_answer = op_msg.content.lower()
-            pprint(actual_answer)
+            # pprint(actual_answer)
             assert expected["answer"] in actual_answer
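
For reference, the control flow the renamed test drives reduces to the sketch below. This is a minimal distillation of the diff above, assuming a client object with the same inference.chat_completion API and a tc dict holding scripted "messages" batches, "tools", and per-turn "expected" entries; run_multi_turn_tool_loop is an illustrative name, not a helper from the repository.

def run_multi_turn_tool_loop(client, model_id, tc):
    """Drive a scripted multi-turn tool-calling conversation (sketch)."""
    messages = []
    # Keep going while scripted message batches remain, or the last
    # message is a tool response the model still needs to consume.
    while len(tc["messages"]) > 0 or (messages and messages[-1]["role"] == "tool"):
        # Only feed the next scripted batch when the model is not mid
        # tool call (i.e., the last message is not a tool response).
        if not messages or messages[-1]["role"] != "tool":
            messages += tc["messages"].pop(0)
        response = client.inference.chat_completion(
            model_id=model_id,
            messages=messages,
            tools=tc["tools"],
            stream=False,
            sampling_params={
                "strategy": {"type": "top_p", "top_p": 0.9, "temperature": 0.6},
            },
        )
        # Store each assistant turn as a plain dict via model_dump().
        messages.append(response.completion_message.model_dump())
    return messages

Appending each assistant turn with model_dump() is what makes the simplified role check safe: every element of messages is then a dict, so the removed isinstance branches were unnecessary.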


@@ -94,8 +94,7 @@
             }
         ],
         "expected": {
-            "num_tool_calls": 1,
-            "expected": "San Francisco, CA"
+            "location": "San Francisco, CA"
         }
     }
 },
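
The JSON change above makes the expected block name the tool-call argument directly ("location") rather than hiding it behind a generic "expected" key. A minimal sketch of how such an entry could be consumed, assuming the model's tool call exposes its arguments as a dict; check_tool_call_arguments is an illustrative name, not the repository's actual helper:

def check_tool_call_arguments(arguments: dict, expected_args: dict) -> None:
    # Compare each expected argument (e.g. "location") against what the
    # model actually passed to the tool.
    for name, value in expected_args.items():
        assert arguments.get(name) == value, (
            f"expected {name}={value!r}, got {arguments.get(name)!r}"
        )

# Hypothetical usage with the entry above:
# check_tool_call_arguments(op_msg.tool_calls[0].arguments,
#                           {"location": "San Francisco, CA"})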