feat: Implement the 'max_tool_calls' parameter for the Responses API (#4062)

# Problem
Responses API uses max_tool_calls parameter to limit the number of tool
calls that can be generated in a response. Currently, LLS implementation
of the Responses API does not support this parameter.

# What does this PR do?
This pull request adds the max_tool_calls field to the response object
definition and updates the inline provider. it also ensures that:

- the total number of calls to built-in and mcp tools do not exceed
max_tool_calls
- an error is thrown if max_tool_calls < 1 (behavior seen with the
OpenAI Responses API, but we can change this if needed)

Closes #[3563](https://github.com/llamastack/llama-stack/issues/3563)

## Test Plan
- Tested manually for change in model response w.r.t supplied
max_tool_calls field.
- Added integration tests to test invalid max_tool_calls parameter.
- Added integration tests to check max_tool_calls parameter with
built-in and function tools.
- Added integration tests to check max_tool_calls parameter in the
returned response object.
- Recorded OpenAI Responses API behavior using a sample script:
https://github.com/s-akhtar-baig/llama-stack-examples/blob/main/responses/src/max_tool_calls.py

Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
Shabana Baig 2025-11-10 16:21:27 -05:00 committed by GitHub
parent 209a78b618
commit 433438cfc0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 240 additions and 2 deletions

View file

@ -87,6 +87,7 @@ class Agents(Protocol):
"List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
),
] = None,
max_tool_calls: int | None = None,
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
"""Create a model response.
@ -97,6 +98,7 @@ class Agents(Protocol):
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
:param include: (Optional) Additional fields to include in the response.
:param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
:param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response.
:returns: An OpenAIResponseObject.
"""
...

View file

@ -594,6 +594,7 @@ class OpenAIResponseObject(BaseModel):
:param truncation: (Optional) Truncation strategy applied to the response
:param usage: (Optional) Token usage information for the response
:param instructions: (Optional) System message inserted into the model's context
:param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response
"""
created_at: int
@ -615,6 +616,7 @@ class OpenAIResponseObject(BaseModel):
truncation: str | None = None
usage: OpenAIResponseUsage | None = None
instructions: str | None = None
max_tool_calls: int | None = None
@json_schema_type

View file

@ -102,6 +102,7 @@ class MetaReferenceAgentsImpl(Agents):
include: list[str] | None = None,
max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrail] | None = None,
max_tool_calls: int | None = None,
) -> OpenAIResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
result = await self.openai_responses_impl.create_openai_response(
@ -119,6 +120,7 @@ class MetaReferenceAgentsImpl(Agents):
include,
max_infer_iters,
guardrails,
max_tool_calls,
)
return result # type: ignore[no-any-return]

View file

@ -255,6 +255,7 @@ class OpenAIResponsesImpl:
include: list[str] | None = None,
max_infer_iters: int | None = 10,
guardrails: list[str | ResponseGuardrailSpec] | None = None,
max_tool_calls: int | None = None,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
@ -270,6 +271,9 @@ class OpenAIResponsesImpl:
if not conversation.startswith("conv_"):
raise InvalidConversationIdError(conversation)
if max_tool_calls is not None and max_tool_calls < 1:
raise ValueError(f"Invalid {max_tool_calls=}; should be >= 1")
stream_gen = self._create_streaming_response(
input=input,
conversation=conversation,
@ -282,6 +286,7 @@ class OpenAIResponsesImpl:
tools=tools,
max_infer_iters=max_infer_iters,
guardrail_ids=guardrail_ids,
max_tool_calls=max_tool_calls,
)
if stream:
@ -331,6 +336,7 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
guardrail_ids: list[str] | None = None,
max_tool_calls: int | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# These should never be None when called from create_openai_response (which sets defaults)
# but we assert here to help mypy understand the types
@ -373,6 +379,7 @@ class OpenAIResponsesImpl:
safety_api=self.safety_api,
guardrail_ids=guardrail_ids,
instructions=instructions,
max_tool_calls=max_tool_calls,
)
# Stream the response

View file

@ -115,6 +115,7 @@ class StreamingResponseOrchestrator:
safety_api,
guardrail_ids: list[str] | None = None,
prompt: OpenAIResponsePrompt | None = None,
max_tool_calls: int | None = None,
):
self.inference_api = inference_api
self.ctx = ctx
@ -126,6 +127,10 @@ class StreamingResponseOrchestrator:
self.safety_api = safety_api
self.guardrail_ids = guardrail_ids or []
self.prompt = prompt
# System message that is inserted into the model's context
self.instructions = instructions
# Max number of total calls to built-in tools that can be processed in a response
self.max_tool_calls = max_tool_calls
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
@ -139,8 +144,8 @@ class StreamingResponseOrchestrator:
self.accumulated_usage: OpenAIResponseUsage | None = None
# Track if we've sent a refusal response
self.violation_detected = False
# system message that is inserted into the model's context
self.instructions = instructions
# Track total calls made to built-in tools
self.accumulated_builtin_tool_calls = 0
async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
"""Create a refusal response to replace streaming content."""
@ -186,6 +191,7 @@ class StreamingResponseOrchestrator:
usage=self.accumulated_usage,
instructions=self.instructions,
prompt=self.prompt,
max_tool_calls=self.max_tool_calls,
)
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
@ -894,6 +900,11 @@ class StreamingResponseOrchestrator:
"""Coordinate execution of both function and non-function tool calls."""
# Execute non-function tool calls
for tool_call in non_function_tool_calls:
# Check if total calls made to built-in and mcp tools exceed max_tool_calls
if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls:
logger.info(f"Ignoring built-in and mcp tool call since reached the limit of {self.max_tool_calls=}.")
break
# Find the item_id for this tool call
matching_item_id = None
for index, item_id in completion_result_data.tool_call_item_ids.items():
@ -974,6 +985,9 @@ class StreamingResponseOrchestrator:
if tool_response_message:
next_turn_messages.append(tool_response_message)
# Track number of calls made to built-in and mcp tools
self.accumulated_builtin_tool_calls += 1
# Execute function tool calls (client-side)
for tool_call in function_tool_calls:
# Find the item_id for this tool call from our tracking dictionary