From 67d10a9c7c1d0c5761e451c30a248b2d6ad6d1dc Mon Sep 17 00:00:00 2001
From: Shabana Baig <43451943+s-akhtar-baig@users.noreply.github.com>
Date: Tue, 18 Nov 2025 16:12:31 -0500
Subject: [PATCH] Fix max_tool_calls for OpenAI and add integration tests
 for the feature

---
 .../meta_reference/responses/streaming.py     |  11 +-
 .../agents/test_openai_responses.py           | 166 ------------------
 .../responses/test_tool_responses.py          | 152 ++++++++++++++++
 3 files changed, 161 insertions(+), 168 deletions(-)

diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
index cdbd87244..43ef6c460 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
@@ -66,6 +66,7 @@ from llama_stack_api import (
     OpenAIResponseUsage,
     OpenAIResponseUsageInputTokensDetails,
     OpenAIResponseUsageOutputTokensDetails,
+    OpenAIToolMessageParam,
     WebSearchToolTypes,
 )
 
@@ -905,10 +906,16 @@ class StreamingResponseOrchestrator:
         """Coordinate execution of both function and non-function tool calls."""
         # Execute non-function tool calls
        for tool_call in non_function_tool_calls:
-            # Check if total calls made to built-in and mcp tools exceed max_tool_calls
+            # If the calls made to built-in and MCP tools have reached max_tool_calls,
+            # create a tool response message indicating the call was skipped
             if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls:
                 logger.info(f"Ignoring built-in and mcp tool call since reached the limit of {self.max_tool_calls=}.")
-                break
+                skipped_call_message = OpenAIToolMessageParam(
+                    content=f"Tool call skipped: maximum tool calls limit ({self.max_tool_calls}) reached.",
+                    tool_call_id=tool_call.id,
+                )
+                next_turn_messages.append(skipped_call_message)
+                continue
 
             # Find the item_id for this tool call
             matching_item_id = None
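The behavioral change above is easiest to see in isolation. Below is a minimal, self-contained sketch of the capping loop, using a plain dict in place of OpenAIToolMessageParam and a hypothetical execute_builtin stub; neither is the real llama-stack API.

    from dataclasses import dataclass, field

    @dataclass
    class ToolCall:
        id: str
        name: str

    @dataclass
    class Orchestrator:
        max_tool_calls: int | None = None
        accumulated_builtin_tool_calls: int = 0
        next_turn_messages: list[dict] = field(default_factory=list)

        def execute_builtin(self, tool_call: ToolCall) -> dict:
            # Hypothetical stand-in for real built-in/MCP tool execution.
            return {"role": "tool", "tool_call_id": tool_call.id, "content": f"result of {tool_call.name}"}

        def run(self, non_function_tool_calls: list[ToolCall]) -> None:
            for tool_call in non_function_tool_calls:
                # Once the cap is hit, keep iterating (continue, not break) so that
                # every remaining tool_call id still receives a tool response message.
                if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls:
                    self.next_turn_messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": tool_call.id,
                            "content": f"Tool call skipped: maximum tool calls limit ({self.max_tool_calls}) reached.",
                        }
                    )
                    continue
                self.next_turn_messages.append(self.execute_builtin(tool_call))
                self.accumulated_builtin_tool_calls += 1

    orch = Orchestrator(max_tool_calls=1)
    orch.run([ToolCall("call_1", "web_search"), ToolCall("call_2", "web_search")])
    assert orch.accumulated_builtin_tool_calls == 1
    assert "skipped" in orch.next_turn_messages[1]["content"]

The switch from break to continue matters because chat-completion backends generally reject an assistant turn whose tool_calls lack matching tool responses; synthesizing a skipped-call message keeps the next turn's history well formed.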
diff --git a/tests/integration/agents/test_openai_responses.py b/tests/integration/agents/test_openai_responses.py
index 057cee774..d413d5201 100644
--- a/tests/integration/agents/test_openai_responses.py
+++ b/tests/integration/agents/test_openai_responses.py
@@ -516,169 +516,3 @@ def test_response_with_instructions(openai_client, client_with_models, text_mode
 
     # Verify instructions from previous response was not carried over to the next response
     assert response_with_instructions2.instructions == instructions2
-
-
-@pytest.mark.skip(reason="Tool calling is not reliable.")
-def test_max_tool_calls_with_function_tools(openai_client, client_with_models, text_model_id):
-    """Test handling of max_tool_calls with function tools in responses."""
-    if isinstance(client_with_models, LlamaStackAsLibraryClient):
-        pytest.skip("OpenAI responses are not supported when testing with library client yet.")
-
-    client = openai_client
-    max_tool_calls = 1
-
-    tools = [
-        {
-            "type": "function",
-            "name": "get_weather",
-            "description": "Get weather information for a specified location",
-            "parameters": {
-                "type": "object",
-                "properties": {
-                    "location": {
-                        "type": "string",
-                        "description": "The city name (e.g., 'New York', 'London')",
-                    },
-                },
-            },
-        },
-        {
-            "type": "function",
-            "name": "get_time",
-            "description": "Get current time for a specified location",
-            "parameters": {
-                "type": "object",
-                "properties": {
-                    "location": {
-                        "type": "string",
-                        "description": "The city name (e.g., 'New York', 'London')",
-                    },
-                },
-            },
-        },
-    ]
-
-    # First create a response that triggers function tools
-    response = client.responses.create(
-        model=text_model_id,
-        input="Can you tell me the weather in Paris and the current time?",
-        tools=tools,
-        stream=False,
-        max_tool_calls=max_tool_calls,
-    )
-
-    # Verify we got two function calls and that the max_tool_calls do not affect function tools
-    assert len(response.output) == 2
-    assert response.output[0].type == "function_call"
-    assert response.output[0].name == "get_weather"
-    assert response.output[0].status == "completed"
-    assert response.output[1].type == "function_call"
-    assert response.output[1].name == "get_time"
-    assert response.output[0].status == "completed"
-
-    # Verify we have a valid max_tool_calls field
-    assert response.max_tool_calls == max_tool_calls
-
-
-def test_max_tool_calls_invalid(openai_client, client_with_models, text_model_id):
-    """Test handling of invalid max_tool_calls in responses."""
-    if isinstance(client_with_models, LlamaStackAsLibraryClient):
-        pytest.skip("OpenAI responses are not supported when testing with library client yet.")
-
-    client = openai_client
-
-    input = "Search for today's top technology news."
-    invalid_max_tool_calls = 0
-    tools = [
-        {"type": "web_search"},
-    ]
-
-    # Create a response with an invalid max_tool_calls value i.e. 0
-    # Handle ValueError from LLS and BadRequestError from OpenAI client
-    with pytest.raises((ValueError, BadRequestError)) as excinfo:
-        client.responses.create(
-            model=text_model_id,
-            input=input,
-            tools=tools,
-            stream=False,
-            max_tool_calls=invalid_max_tool_calls,
-        )
-
-    error_message = str(excinfo.value)
-    assert f"Invalid max_tool_calls={invalid_max_tool_calls}; should be >= 1" in error_message, (
-        f"Expected error message about invalid max_tool_calls, got: {error_message}"
-    )
-
-
-def test_max_tool_calls_with_builtin_tools(openai_client, client_with_models, text_model_id):
-    """Test handling of max_tool_calls with built-in tools in responses."""
-    if isinstance(client_with_models, LlamaStackAsLibraryClient):
-        pytest.skip("OpenAI responses are not supported when testing with library client yet.")
-
-    client = openai_client
-
-    input = "Search for today's top technology and a positive news story. You MUST make exactly two separate web search calls."
-    max_tool_calls = [1, 5]
-    tools = [
-        {"type": "web_search"},
-    ]
-
-    # First create a response that triggers web_search tools without max_tool_calls
-    response = client.responses.create(
-        model=text_model_id,
-        input=input,
-        tools=tools,
-        stream=False,
-    )
-
-    # Verify we got two web search calls followed by a message
-    assert len(response.output) == 3
-    assert response.output[0].type == "web_search_call"
-    assert response.output[0].status == "completed"
-    assert response.output[1].type == "web_search_call"
-    assert response.output[1].status == "completed"
-    assert response.output[2].type == "message"
-    assert response.output[2].status == "completed"
-    assert response.output[2].role == "assistant"
-
-    # Next create a response that triggers web_search tools with max_tool_calls set to 1
-    response_2 = client.responses.create(
-        model=text_model_id,
-        input=input,
-        tools=tools,
-        stream=False,
-        max_tool_calls=max_tool_calls[0],
-    )
-
-    # Verify we got one web search tool call followed by a message
-    assert len(response_2.output) == 2
-    assert response_2.output[0].type == "web_search_call"
-    assert response_2.output[0].status == "completed"
-    assert response_2.output[1].type == "message"
-    assert response_2.output[1].status == "completed"
-    assert response_2.output[1].role == "assistant"
-
-    # Verify we have a valid max_tool_calls field
-    assert response_2.max_tool_calls == max_tool_calls[0]
-
-    # Finally create a response that triggers web_search tools with max_tool_calls set to 5
-    response_3 = client.responses.create(
-        model=text_model_id,
-        input=input,
-        tools=tools,
-        stream=False,
-        max_tool_calls=max_tool_calls[1],
-    )
-
-    # Verify we got two web search calls followed by a message
-    assert len(response_3.output) == 3
-    assert response_3.output[0].type == "web_search_call"
-    assert response_3.output[0].status == "completed"
-    assert response_3.output[1].type == "web_search_call"
-    assert response_3.output[1].status == "completed"
-    assert response_3.output[2].type == "message"
-    assert response_3.output[2].status == "completed"
-    assert response_3.output[2].role == "assistant"
-
-    # Verify we have a valid max_tool_calls field
-    assert response_3.max_tool_calls == max_tool_calls[1]
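Both the removed test_max_tool_calls_invalid above and its replacement below assert on the message "Invalid max_tool_calls=...; should be >= 1". The server-side guard presumably amounts to something like the sketch below; the function name and placement are assumptions for illustration, not actual llama-stack code.

    def validate_max_tool_calls(max_tool_calls: int | None) -> None:
        # Assumed semantics: None means "no cap"; an explicit value must be >= 1.
        if max_tool_calls is not None and max_tool_calls < 1:
            raise ValueError(f"Invalid max_tool_calls={max_tool_calls}; should be >= 1")

When the server runs behind HTTP, that ValueError surfaces as a 400, which is why the test accepts either ValueError (library client) or openai.BadRequestError.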
diff --git a/tests/integration/responses/test_tool_responses.py b/tests/integration/responses/test_tool_responses.py
index 742d45f8b..e7087dcd0 100644
--- a/tests/integration/responses/test_tool_responses.py
+++ b/tests/integration/responses/test_tool_responses.py
@@ -600,3 +600,155 @@ def test_response_streaming_multi_turn_tool_execution(responses_client, text_mod
     assert expected_output.lower() in final_response.output_text.lower(), (
         f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
     )
+
+
+def test_max_tool_calls_with_function_tools(responses_client, text_model_id):
+    """Test handling of max_tool_calls with function tools in responses."""
+
+    max_tool_calls = 1
+    tools = [
+        {
+            "type": "function",
+            "name": "get_weather",
+            "description": "Get weather information for a specified location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city name (e.g., 'New York', 'London')",
+                    },
+                },
+            },
+        },
+        {
+            "type": "function",
+            "name": "get_time",
+            "description": "Get current time for a specified location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city name (e.g., 'New York', 'London')",
+                    },
+                },
+            },
+        },
+    ]
+
+    response = responses_client.responses.create(
+        model=text_model_id,
+        input="Can you tell me the weather in Paris and the current time?",
+        tools=tools,
+        stream=False,
+        max_tool_calls=max_tool_calls,
+    )
+
+    # Verify we got two function calls and that max_tool_calls does not affect function tools
+    assert len(response.output) == 2
+    assert response.output[0].type == "function_call"
+    assert response.output[0].name == "get_weather"
+    assert response.output[0].status == "completed"
+    assert response.output[1].type == "function_call"
+    assert response.output[1].name == "get_time"
+    assert response.output[1].status == "completed"
+
+    # Verify we have a valid max_tool_calls field
+    assert response.max_tool_calls == max_tool_calls
+
+
+def test_max_tool_calls_invalid(responses_client, text_model_id):
+    """Test handling of invalid max_tool_calls in responses."""
+
+    input = "Search for today's top technology news."
+    invalid_max_tool_calls = 0
+    tools = [
+        {"type": "web_search"},
+    ]
+
+    # Create a response with an invalid max_tool_calls value, i.e. 0
+    # Handle ValueError from LLS and BadRequestError from OpenAI client
+    with pytest.raises((ValueError, openai.BadRequestError)) as excinfo:
+        responses_client.responses.create(
+            model=text_model_id,
+            input=input,
+            tools=tools,
+            stream=False,
+            max_tool_calls=invalid_max_tool_calls,
+        )
+
+    error_message = str(excinfo.value)
+    assert f"Invalid max_tool_calls={invalid_max_tool_calls}; should be >= 1" in error_message, (
+        f"Expected error message about invalid max_tool_calls, got: {error_message}"
+    )
+
+
+def test_max_tool_calls_with_mcp_tools(responses_client, text_model_id):
+    """Test handling of max_tool_calls with mcp tools in responses."""
+
+    with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
+        input = "Get the experiment ID for 'boiling_point' and get the user ID for 'charlie'"
+        max_tool_calls = [1, 5]
+        tools = [
+            {"type": "mcp", "server_label": "localmcp", "server_url": mcp_server_info["server_url"]},
+        ]
+
+        # First create a response that triggers mcp tools without max_tool_calls
+        response = responses_client.responses.create(
+            model=text_model_id,
+            input=input,
+            tools=tools,
+            stream=False,
+        )
+
+        # Verify we got an mcp_list_tools entry and two mcp calls followed by a message
+        assert len(response.output) == 4
+        mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"]
+        mcp_calls = [output for output in response.output if output.type == "mcp_call"]
+        message_outputs = [output for output in response.output if output.type == "message"]
+        assert len(mcp_list_tools) == 1
+        assert len(mcp_calls) == 2, f"Expected two mcp calls, got {len(mcp_calls)}"
+        assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
+
+        # Next create a response that triggers mcp tools with max_tool_calls set to 1
+        response_2 = responses_client.responses.create(
+            model=text_model_id,
+            input=input,
+            tools=tools,
+            stream=False,
+            max_tool_calls=max_tool_calls[0],
+        )
+
+        # Verify we got an mcp_list_tools entry and one mcp call followed by a message
+        assert len(response_2.output) == 3
+        mcp_list_tools = [output for output in response_2.output if output.type == "mcp_list_tools"]
+        mcp_calls = [output for output in response_2.output if output.type == "mcp_call"]
+        message_outputs = [output for output in response_2.output if output.type == "message"]
+        assert len(mcp_list_tools) == 1
+        assert len(mcp_calls) == 1, f"Expected one mcp call, got {len(mcp_calls)}"
+        assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
got {len(message_outputs)}" + + # Verify we have a valid max_tool_calls field + assert response_2.max_tool_calls == max_tool_calls[0] + + # Finally create a response that triggers mcp tools with max_tool_calls set to 5 + response_3 = responses_client.responses.create( + model=text_model_id, + input=input, + tools=tools, + stream=False, + max_tool_calls=max_tool_calls[1], + ) + + # Verify we got two mcp tool calls followed by a message + assert len(response_3.output) == 4 + mcp_list_tools = [output for output in response_3.output if output.type == "mcp_list_tools"] + mcp_calls = [output for output in response_3.output if output.type == "mcp_call"] + message_outputs = [output for output in response_3.output if output.type == "message"] + assert len(mcp_list_tools) == 1 + assert len(mcp_calls) == 2, f"Expected two mcp calls, got {len(mcp_calls)}" + assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}" + + # Verify we have a valid max_tool_calls field + assert response_3.max_tool_calls == max_tool_calls[1]