mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
Fix max_tool_calls for openai and add integration tests for the feat
This commit is contained in:
parent
a3580e6bc0
commit
67d10a9c7c
3 changed files with 161 additions and 168 deletions
|
|
@ -66,6 +66,7 @@ from llama_stack_api import (
|
|||
OpenAIResponseUsage,
|
||||
OpenAIResponseUsageInputTokensDetails,
|
||||
OpenAIResponseUsageOutputTokensDetails,
|
||||
OpenAIToolMessageParam,
|
||||
WebSearchToolTypes,
|
||||
)
|
||||
|
||||
|
|
@ -905,10 +906,16 @@ class StreamingResponseOrchestrator:
|
|||
"""Coordinate execution of both function and non-function tool calls."""
|
||||
# Execute non-function tool calls
|
||||
for tool_call in non_function_tool_calls:
|
||||
# Check if total calls made to built-in and mcp tools exceed max_tool_calls
|
||||
# if total calls made to built-in and mcp tools exceed max_tool_calls
|
||||
# then create a tool response message indicating the call was skipped
|
||||
if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls:
|
||||
logger.info(f"Ignoring built-in and mcp tool call since reached the limit of {self.max_tool_calls=}.")
|
||||
break
|
||||
skipped_call_message = OpenAIToolMessageParam(
|
||||
content=f"Tool call skipped: maximum tool calls limit ({self.max_tool_calls}) reached.",
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
next_turn_messages.append(skipped_call_message)
|
||||
continue
|
||||
|
||||
# Find the item_id for this tool call
|
||||
matching_item_id = None
|
||||
|
|
|
|||
|
|
@ -516,169 +516,3 @@ def test_response_with_instructions(openai_client, client_with_models, text_mode
|
|||
|
||||
# Verify instructions from previous response was not carried over to the next response
|
||||
assert response_with_instructions2.instructions == instructions2
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Tool calling is not reliable.")
|
||||
def test_max_tool_calls_with_function_tools(openai_client, client_with_models, text_model_id):
|
||||
"""Test handling of max_tool_calls with function tools in responses."""
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||
pytest.skip("OpenAI responses are not supported when testing with library client yet.")
|
||||
|
||||
client = openai_client
|
||||
max_tool_calls = 1
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get weather information for a specified location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city name (e.g., 'New York', 'London')",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_time",
|
||||
"description": "Get current time for a specified location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city name (e.g., 'New York', 'London')",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# First create a response that triggers function tools
|
||||
response = client.responses.create(
|
||||
model=text_model_id,
|
||||
input="Can you tell me the weather in Paris and the current time?",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
max_tool_calls=max_tool_calls,
|
||||
)
|
||||
|
||||
# Verify we got two function calls and that the max_tool_calls do not affect function tools
|
||||
assert len(response.output) == 2
|
||||
assert response.output[0].type == "function_call"
|
||||
assert response.output[0].name == "get_weather"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[1].type == "function_call"
|
||||
assert response.output[1].name == "get_time"
|
||||
assert response.output[0].status == "completed"
|
||||
|
||||
# Verify we have a valid max_tool_calls field
|
||||
assert response.max_tool_calls == max_tool_calls
|
||||
|
||||
|
||||
def test_max_tool_calls_invalid(openai_client, client_with_models, text_model_id):
|
||||
"""Test handling of invalid max_tool_calls in responses."""
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||
pytest.skip("OpenAI responses are not supported when testing with library client yet.")
|
||||
|
||||
client = openai_client
|
||||
|
||||
input = "Search for today's top technology news."
|
||||
invalid_max_tool_calls = 0
|
||||
tools = [
|
||||
{"type": "web_search"},
|
||||
]
|
||||
|
||||
# Create a response with an invalid max_tool_calls value i.e. 0
|
||||
# Handle ValueError from LLS and BadRequestError from OpenAI client
|
||||
with pytest.raises((ValueError, BadRequestError)) as excinfo:
|
||||
client.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
max_tool_calls=invalid_max_tool_calls,
|
||||
)
|
||||
|
||||
error_message = str(excinfo.value)
|
||||
assert f"Invalid max_tool_calls={invalid_max_tool_calls}; should be >= 1" in error_message, (
|
||||
f"Expected error message about invalid max_tool_calls, got: {error_message}"
|
||||
)
|
||||
|
||||
|
||||
def test_max_tool_calls_with_builtin_tools(openai_client, client_with_models, text_model_id):
|
||||
"""Test handling of max_tool_calls with built-in tools in responses."""
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||
pytest.skip("OpenAI responses are not supported when testing with library client yet.")
|
||||
|
||||
client = openai_client
|
||||
|
||||
input = "Search for today's top technology and a positive news story. You MUST make exactly two separate web search calls."
|
||||
max_tool_calls = [1, 5]
|
||||
tools = [
|
||||
{"type": "web_search"},
|
||||
]
|
||||
|
||||
# First create a response that triggers web_search tools without max_tool_calls
|
||||
response = client.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Verify we got two web search calls followed by a message
|
||||
assert len(response.output) == 3
|
||||
assert response.output[0].type == "web_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[1].type == "web_search_call"
|
||||
assert response.output[1].status == "completed"
|
||||
assert response.output[2].type == "message"
|
||||
assert response.output[2].status == "completed"
|
||||
assert response.output[2].role == "assistant"
|
||||
|
||||
# Next create a response that triggers web_search tools with max_tool_calls set to 1
|
||||
response_2 = client.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
max_tool_calls=max_tool_calls[0],
|
||||
)
|
||||
|
||||
# Verify we got one web search tool call followed by a message
|
||||
assert len(response_2.output) == 2
|
||||
assert response_2.output[0].type == "web_search_call"
|
||||
assert response_2.output[0].status == "completed"
|
||||
assert response_2.output[1].type == "message"
|
||||
assert response_2.output[1].status == "completed"
|
||||
assert response_2.output[1].role == "assistant"
|
||||
|
||||
# Verify we have a valid max_tool_calls field
|
||||
assert response_2.max_tool_calls == max_tool_calls[0]
|
||||
|
||||
# Finally create a response that triggers web_search tools with max_tool_calls set to 5
|
||||
response_3 = client.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
max_tool_calls=max_tool_calls[1],
|
||||
)
|
||||
|
||||
# Verify we got two web search calls followed by a message
|
||||
assert len(response_3.output) == 3
|
||||
assert response_3.output[0].type == "web_search_call"
|
||||
assert response_3.output[0].status == "completed"
|
||||
assert response_3.output[1].type == "web_search_call"
|
||||
assert response_3.output[1].status == "completed"
|
||||
assert response_3.output[2].type == "message"
|
||||
assert response_3.output[2].status == "completed"
|
||||
assert response_3.output[2].role == "assistant"
|
||||
|
||||
# Verify we have a valid max_tool_calls field
|
||||
assert response_3.max_tool_calls == max_tool_calls[1]
|
||||
|
|
|
|||
|
|
@ -600,3 +600,155 @@ def test_response_streaming_multi_turn_tool_execution(responses_client, text_mod
|
|||
assert expected_output.lower() in final_response.output_text.lower(), (
|
||||
f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
|
||||
)
|
||||
|
||||
|
||||
def test_max_tool_calls_with_function_tools(responses_client, text_model_id):
|
||||
"""Test handling of max_tool_calls with function tools in responses."""
|
||||
|
||||
max_tool_calls = 1
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get weather information for a specified location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city name (e.g., 'New York', 'London')",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_time",
|
||||
"description": "Get current time for a specified location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city name (e.g., 'New York', 'London')",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
response = responses_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="Can you tell me the weather in Paris and the current time?",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
max_tool_calls=max_tool_calls,
|
||||
)
|
||||
|
||||
# Verify we got two function calls and that the max_tool_calls does not affect function tools
|
||||
assert len(response.output) == 2
|
||||
assert response.output[0].type == "function_call"
|
||||
assert response.output[0].name == "get_weather"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[1].type == "function_call"
|
||||
assert response.output[1].name == "get_time"
|
||||
assert response.output[1].status == "completed"
|
||||
|
||||
# Verify we have a valid max_tool_calls field
|
||||
assert response.max_tool_calls == max_tool_calls
|
||||
|
||||
|
||||
def test_max_tool_calls_invalid(responses_client, text_model_id):
|
||||
"""Test handling of invalid max_tool_calls in responses."""
|
||||
|
||||
input = "Search for today's top technology news."
|
||||
invalid_max_tool_calls = 0
|
||||
tools = [
|
||||
{"type": "web_search"},
|
||||
]
|
||||
|
||||
# Create a response with an invalid max_tool_calls value i.e. 0
|
||||
# Handle ValueError from LLS and BadRequestError from OpenAI client
|
||||
with pytest.raises((ValueError, openai.BadRequestError)) as excinfo:
|
||||
responses_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
max_tool_calls=invalid_max_tool_calls,
|
||||
)
|
||||
|
||||
error_message = str(excinfo.value)
|
||||
assert f"Invalid max_tool_calls={invalid_max_tool_calls}; should be >= 1" in error_message, (
|
||||
f"Expected error message about invalid max_tool_calls, got: {error_message}"
|
||||
)
|
||||
|
||||
|
||||
def test_max_tool_calls_with_mcp_tools(responses_client, text_model_id):
|
||||
"""Test handling of max_tool_calls with mcp tools in responses."""
|
||||
|
||||
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
|
||||
input = "Get the experiment ID for 'boiling_point' and get the user ID for 'charlie'"
|
||||
max_tool_calls = [1, 5]
|
||||
tools = [
|
||||
{"type": "mcp", "server_label": "localmcp", "server_url": mcp_server_info["server_url"]},
|
||||
]
|
||||
|
||||
# First create a response that triggers mcp tools without max_tool_calls
|
||||
response = responses_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Verify we got two mcp tool calls followed by a message
|
||||
assert len(response.output) == 4
|
||||
mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"]
|
||||
mcp_calls = [output for output in response.output if output.type == "mcp_call"]
|
||||
message_outputs = [output for output in response.output if output.type == "message"]
|
||||
assert len(mcp_list_tools) == 1
|
||||
assert len(mcp_calls) == 2, f"Expected two mcp calls, got {len(mcp_calls)}"
|
||||
assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
|
||||
|
||||
# Next create a response that triggers mcp tools with max_tool_calls set to 1
|
||||
response_2 = responses_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
max_tool_calls=max_tool_calls[0],
|
||||
)
|
||||
|
||||
# Verify we got one mcp tool call followed by a message
|
||||
assert len(response_2.output) == 3
|
||||
mcp_list_tools = [output for output in response_2.output if output.type == "mcp_list_tools"]
|
||||
mcp_calls = [output for output in response_2.output if output.type == "mcp_call"]
|
||||
message_outputs = [output for output in response_2.output if output.type == "message"]
|
||||
assert len(mcp_list_tools) == 1
|
||||
assert len(mcp_calls) == 1, f"Expected one mcp call, got {len(mcp_calls)}"
|
||||
assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
|
||||
|
||||
# Verify we have a valid max_tool_calls field
|
||||
assert response_2.max_tool_calls == max_tool_calls[0]
|
||||
|
||||
# Finally create a response that triggers mcp tools with max_tool_calls set to 5
|
||||
response_3 = responses_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
max_tool_calls=max_tool_calls[1],
|
||||
)
|
||||
|
||||
# Verify we got two mcp tool calls followed by a message
|
||||
assert len(response_3.output) == 4
|
||||
mcp_list_tools = [output for output in response_3.output if output.type == "mcp_list_tools"]
|
||||
mcp_calls = [output for output in response_3.output if output.type == "mcp_call"]
|
||||
message_outputs = [output for output in response_3.output if output.type == "message"]
|
||||
assert len(mcp_list_tools) == 1
|
||||
assert len(mcp_calls) == 2, f"Expected two mcp calls, got {len(mcp_calls)}"
|
||||
assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
|
||||
|
||||
# Verify we have a valid max_tool_calls field
|
||||
assert response_3.max_tool_calls == max_tool_calls[1]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue