diff --git a/tests/integration/inference/test_text_inference.py b/tests/integration/inference/test_text_inference.py index a204511d0..2da189fa4 100644 --- a/tests/integration/inference/test_text_inference.py +++ b/tests/integration/inference/test_text_inference.py @@ -496,48 +496,53 @@ def test_text_chat_completion_tool_calling_tools_not_in_request( @pytest.mark.parametrize( "test_case", [ + # Tests if the model can handle simple messages like "Hi" or + # a message unrelated to one of the tool calls "inference:chat_completion:multi_turn_tool_calling_01", + # Tests if the model can do full tool call with responses correctly "inference:chat_completion:multi_turn_tool_calling_02", + # Tests if model can generate multiple params and + # read outputs correctly "inference:chat_completion:multi_turn_tool_calling_03", + # Tests if model can do different tool calls in a seqeunce + # and use the information between appropriately "inference:chat_completion:multi_turn_tool_calling_04", + # Tests if model can use current date and run multiple tool calls + # sequentially and infer using both "inference:chat_completion:multi_turn_tool_calling_05", ], ) -def test_text_chat_completion_with_tool_calling_loop_non_streaming(client_with_models, text_model_id, test_case): - from rich.pretty import pprint - +def test_text_chat_completion_with_multi_turn_tool_calling(client_with_models, text_model_id, test_case): + """This test tests the model's tool calling loop in various scenarios""" tc = TestCase(test_case) messages = [] + # keep going until either # 1. we have messages to test in multi-turn # 2. no messages bust last message is tool response while len(tc["messages"]) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"): # do not take new messages if last message is tool response - if ( - len(messages) == 0 - or (isinstance(messages[-1], dict) and messages[-1]["role"] != "tool") - or (not isinstance(messages[-1], dict) and messages[-1].role != "tool") - ): + if len(messages) == 0 or messages[-1]["role"] != "tool": new_messages = tc["messages"].pop(0) messages += new_messages - pprint(messages) + # pprint(messages) response = client_with_models.inference.chat_completion( model_id=text_model_id, messages=messages, tools=tc["tools"], stream=False, - # sampling_params={ - # "strategy": { - # "type": "top_p", - # "top_p": 0.9, - # "temperature": 0.6, - # } - # }, + sampling_params={ + "strategy": { + "type": "top_p", + "top_p": 0.9, + "temperature": 0.6, + } + }, ) op_msg = response.completion_message messages.append(op_msg.model_dump()) - pprint(op_msg) + # pprint(op_msg) assert op_msg.role == "assistant" expected = tc["expected"].pop(0) @@ -558,5 +563,5 @@ def test_text_chat_completion_with_tool_calling_loop_non_streaming(client_with_m ) else: actual_answer = op_msg.content.lower() - pprint(actual_answer) + # pprint(actual_answer) assert expected["answer"] in actual_answer diff --git a/tests/integration/test_cases/inference/chat_completion.json b/tests/integration/test_cases/inference/chat_completion.json index 648c78ce1..f842bca9a 100644 --- a/tests/integration/test_cases/inference/chat_completion.json +++ b/tests/integration/test_cases/inference/chat_completion.json @@ -94,8 +94,7 @@ } ], "expected": { - "num_tool_calls": 1, - "expected": "San Francisco, CA" + "location": "San Francisco, CA" } } },