Merge branch 'main' into add-mcp-authentication-param

This commit is contained in:
Omar Abdelwahab 2025-11-10 15:13:45 -08:00 committed by GitHub
commit 5c6f713354
9 changed files with 240 additions and 2 deletions

View file

@@ -6626,6 +6626,11 @@ components:
         type: string
         description: >-
           (Optional) System message inserted into the model's context
+      max_tool_calls:
+        type: integer
+        description: >-
+          (Optional) Max number of total calls to built-in tools that can be processed
+          in a response
       input:
         type: array
         items:
@@ -6988,6 +6993,11 @@ components:
           (Optional) Additional fields to include in the response.
       max_infer_iters:
         type: integer
+      max_tool_calls:
+        type: integer
+        description: >-
+          (Optional) Max number of total calls to built-in tools that can be processed
+          in a response.
     additionalProperties: false
     required:
       - input
@@ -7069,6 +7079,11 @@ components:
         type: string
         description: >-
           (Optional) System message inserted into the model's context
+      max_tool_calls:
+        type: integer
+        description: >-
+          (Optional) Max number of total calls to built-in tools that can be processed
+          in a response
     additionalProperties: false
     required:
       - created_at

View file

@@ -5910,6 +5910,11 @@ components:
         type: string
         description: >-
           (Optional) System message inserted into the model's context
+      max_tool_calls:
+        type: integer
+        description: >-
+          (Optional) Max number of total calls to built-in tools that can be processed
+          in a response
       input:
         type: array
         items:
@@ -6272,6 +6277,11 @@ components:
           (Optional) Additional fields to include in the response.
       max_infer_iters:
         type: integer
+      max_tool_calls:
+        type: integer
+        description: >-
+          (Optional) Max number of total calls to built-in tools that can be processed
+          in a response.
     additionalProperties: false
     required:
       - input
@@ -6353,6 +6363,11 @@ components:
         type: string
        description: >-
           (Optional) System message inserted into the model's context
+      max_tool_calls:
+        type: integer
+        description: >-
+          (Optional) Max number of total calls to built-in tools that can be processed
+          in a response
     additionalProperties: false
     required:
       - created_at

View file

@@ -6626,6 +6626,11 @@ components:
         type: string
         description: >-
           (Optional) System message inserted into the model's context
+      max_tool_calls:
+        type: integer
+        description: >-
+          (Optional) Max number of total calls to built-in tools that can be processed
+          in a response
       input:
         type: array
         items:
@@ -6988,6 +6993,11 @@ components:
           (Optional) Additional fields to include in the response.
       max_infer_iters:
         type: integer
+      max_tool_calls:
+        type: integer
+        description: >-
+          (Optional) Max number of total calls to built-in tools that can be processed
+          in a response.
     additionalProperties: false
     required:
       - input
@@ -7069,6 +7079,11 @@ components:
         type: string
         description: >-
           (Optional) System message inserted into the model's context
+      max_tool_calls:
+        type: integer
+        description: >-
+          (Optional) Max number of total calls to built-in tools that can be processed
+          in a response
     additionalProperties: false
     required:
       - created_at
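
Taken together, the three spec files now advertise the same optional integer on both the request and response schemas. As a rough illustration of where the field sits in a request body (a hypothetical payload: every key except max_tool_calls predates this change, and the model id is made up):

# Hypothetical JSON body for a Responses API request; only the
# "max_tool_calls" key is introduced by this diff.
request_body = {
    "model": "llama-3.1-8b-instruct",  # made-up model id
    "input": "Search for today's top technology news.",
    "tools": [{"type": "web_search"}],
    "max_tool_calls": 2,  # optional; must be >= 1 (validated server-side below)
}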

View file

@@ -87,6 +87,7 @@ class Agents(Protocol):
                 "List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
             ),
         ] = None,
+        max_tool_calls: int | None = None,
     ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
         """Create a model response.
@@ -97,6 +98,7 @@ class Agents(Protocol):
         :param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
         :param include: (Optional) Additional fields to include in the response.
         :param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
+        :param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response.
         :returns: An OpenAIResponseObject.
         """
         ...
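
For callers, the parameter rides through the OpenAI-compatible client exactly as the integration tests at the bottom of this commit exercise it; a minimal sketch (the base URL and model id are assumptions, not part of this diff):

from openai import OpenAI

# Assumed local Llama Stack endpoint; adjust for your deployment.
client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")
response = client.responses.create(
    model="llama-3.1-8b-instruct",  # made-up model id
    input="Search for today's top technology news.",
    tools=[{"type": "web_search"}],
    max_tool_calls=2,  # cap on built-in/MCP tool calls for this response
)
print(response.max_tool_calls)  # the cap is echoed back on the response object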

View file

@@ -596,6 +596,7 @@ class OpenAIResponseObject(BaseModel):
     :param truncation: (Optional) Truncation strategy applied to the response
     :param usage: (Optional) Token usage information for the response
     :param instructions: (Optional) System message inserted into the model's context
+    :param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response
     """

     created_at: int
@@ -617,6 +618,7 @@ class OpenAIResponseObject(BaseModel):
     truncation: str | None = None
     usage: OpenAIResponseUsage | None = None
     instructions: str | None = None
+    max_tool_calls: int | None = None


 @json_schema_type

View file

@@ -102,6 +102,7 @@ class MetaReferenceAgentsImpl(Agents):
         include: list[str] | None = None,
         max_infer_iters: int | None = 10,
         guardrails: list[ResponseGuardrail] | None = None,
+        max_tool_calls: int | None = None,
     ) -> OpenAIResponseObject:
         assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
         result = await self.openai_responses_impl.create_openai_response(
@@ -119,6 +120,7 @@ class MetaReferenceAgentsImpl(Agents):
             include,
             max_infer_iters,
             guardrails,
+            max_tool_calls,
         )
         return result  # type: ignore[no-any-return]

View file

@@ -255,6 +255,7 @@ class OpenAIResponsesImpl:
         include: list[str] | None = None,
         max_infer_iters: int | None = 10,
         guardrails: list[str | ResponseGuardrailSpec] | None = None,
+        max_tool_calls: int | None = None,
     ):
         stream = bool(stream)
         text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
@@ -283,6 +284,9 @@ class OpenAIResponsesImpl:
             if not conversation.startswith("conv_"):
                 raise InvalidConversationIdError(conversation)

+        if max_tool_calls is not None and max_tool_calls < 1:
+            raise ValueError(f"Invalid {max_tool_calls=}; should be >= 1")
+
         stream_gen = self._create_streaming_response(
             input=input,
             conversation=conversation,
@@ -295,6 +299,7 @@ class OpenAIResponsesImpl:
             tools=tools,
             max_infer_iters=max_infer_iters,
             guardrail_ids=guardrail_ids,
+            max_tool_calls=max_tool_calls,
         )

         if stream:
@@ -344,6 +349,7 @@ class OpenAIResponsesImpl:
         tools: list[OpenAIResponseInputTool] | None = None,
         max_infer_iters: int | None = 10,
         guardrail_ids: list[str] | None = None,
+        max_tool_calls: int | None = None,
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         # These should never be None when called from create_openai_response (which sets defaults)
         # but we assert here to help mypy understand the types
@@ -386,6 +392,7 @@ class OpenAIResponsesImpl:
             safety_api=self.safety_api,
             guardrail_ids=guardrail_ids,
             instructions=instructions,
+            max_tool_calls=max_tool_calls,
         )

         # Stream the response
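
The new guard fails fast on non-positive values before any streaming begins; the expected behavior, reduced to a self-contained sketch (the helper name is illustrative, not part of the module):

# Illustrative stand-in for the guard in create_openai_response.
def validate_max_tool_calls(max_tool_calls: int | None) -> None:
    if max_tool_calls is not None and max_tool_calls < 1:
        raise ValueError(f"Invalid {max_tool_calls=}; should be >= 1")

validate_max_tool_calls(None)  # fine: no cap requested
validate_max_tool_calls(2)     # fine: positive cap
try:
    validate_max_tool_calls(0)
except ValueError as e:
    print(e)  # Invalid max_tool_calls=0; should be >= 1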

View file

@@ -115,6 +115,7 @@ class StreamingResponseOrchestrator:
         safety_api,
         guardrail_ids: list[str] | None = None,
         prompt: OpenAIResponsePrompt | None = None,
+        max_tool_calls: int | None = None,
     ):
         self.inference_api = inference_api
         self.ctx = ctx
@@ -126,6 +127,10 @@ class StreamingResponseOrchestrator:
         self.safety_api = safety_api
         self.guardrail_ids = guardrail_ids or []
         self.prompt = prompt
+        # System message that is inserted into the model's context
+        self.instructions = instructions
+        # Max number of total calls to built-in tools that can be processed in a response
+        self.max_tool_calls = max_tool_calls
         self.sequence_number = 0
         # Store MCP tool mapping that gets built during tool processing
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
@@ -139,8 +144,8 @@ class StreamingResponseOrchestrator:
         self.accumulated_usage: OpenAIResponseUsage | None = None
         # Track if we've sent a refusal response
         self.violation_detected = False
-        # system message that is inserted into the model's context
-        self.instructions = instructions
+        # Track total calls made to built-in tools
+        self.accumulated_builtin_tool_calls = 0

     async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
         """Create a refusal response to replace streaming content."""
@@ -186,6 +191,7 @@ class StreamingResponseOrchestrator:
             usage=self.accumulated_usage,
             instructions=self.instructions,
             prompt=self.prompt,
+            max_tool_calls=self.max_tool_calls,
         )

     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
@@ -894,6 +900,11 @@ class StreamingResponseOrchestrator:
         """Coordinate execution of both function and non-function tool calls."""
         # Execute non-function tool calls
         for tool_call in non_function_tool_calls:
+            # Stop if the total number of built-in and MCP tool calls has reached max_tool_calls
+            if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls:
+                logger.info(f"Ignoring built-in and MCP tool call since the limit of {self.max_tool_calls=} has been reached.")
+                break
+
             # Find the item_id for this tool call
             matching_item_id = None
             for index, item_id in completion_result_data.tool_call_item_ids.items():
@@ -974,6 +985,9 @@ class StreamingResponseOrchestrator:
             if tool_response_message:
                 next_turn_messages.append(tool_response_message)

+            # Track number of calls made to built-in and MCP tools
+            self.accumulated_builtin_tool_calls += 1
+
         # Execute function tool calls (client-side)
         for tool_call in function_tool_calls:
             # Find the item_id for this tool call from our tracking dictionary
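
Stripped of streaming concerns, the cap amounts to counting built-in/MCP executions and skipping whatever remains once the count reaches max_tool_calls, while client-side function calls are never counted; a self-contained sketch under those assumptions (names are illustrative, not the orchestrator's API):

# Illustrative reduction of the orchestrator's loop; not the real class.
def execute_builtin_calls(tool_calls: list[str], max_tool_calls: int | None) -> list[str]:
    executed: list[str] = []
    count = 0
    for call in tool_calls:
        if max_tool_calls is not None and count >= max_tool_calls:
            break  # remaining built-in/MCP calls are ignored
        executed.append(call)  # stand-in for the real tool dispatch
        count += 1
    return executed

assert execute_builtin_calls(["search_a", "search_b"], max_tool_calls=1) == ["search_a"]
assert execute_builtin_calls(["search_a", "search_b"], max_tool_calls=None) == ["search_a", "search_b"]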

View file

@@ -516,3 +516,169 @@ def test_response_with_instructions(openai_client, client_with_models, text_model_id):
     # Verify instructions from previous response was not carried over to the next response
     assert response_with_instructions2.instructions == instructions2
+
+
+@pytest.mark.skip(reason="Tool calling is not reliable.")
+def test_max_tool_calls_with_function_tools(openai_client, client_with_models, text_model_id):
+    """Test handling of max_tool_calls with function tools in responses."""
+    if isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI responses are not supported when testing with library client yet.")
+
+    client = openai_client
+    max_tool_calls = 1
+    tools = [
+        {
+            "type": "function",
+            "name": "get_weather",
+            "description": "Get weather information for a specified location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city name (e.g., 'New York', 'London')",
+                    },
+                },
+            },
+        },
+        {
+            "type": "function",
+            "name": "get_time",
+            "description": "Get current time for a specified location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city name (e.g., 'New York', 'London')",
+                    },
+                },
+            },
+        },
+    ]
+
+    # First create a response that triggers function tools
+    response = client.responses.create(
+        model=text_model_id,
+        input="Can you tell me the weather in Paris and the current time?",
+        tools=tools,
+        stream=False,
+        max_tool_calls=max_tool_calls,
+    )
+
+    # Verify we got two function calls and that max_tool_calls does not affect function tools
+    assert len(response.output) == 2
+    assert response.output[0].type == "function_call"
+    assert response.output[0].name == "get_weather"
+    assert response.output[0].status == "completed"
+    assert response.output[1].type == "function_call"
+    assert response.output[1].name == "get_time"
+    assert response.output[1].status == "completed"
+
+    # Verify we have a valid max_tool_calls field
+    assert response.max_tool_calls == max_tool_calls
+
+
+def test_max_tool_calls_invalid(openai_client, client_with_models, text_model_id):
+    """Test handling of invalid max_tool_calls in responses."""
+    if isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI responses are not supported when testing with library client yet.")
+
+    client = openai_client
+    input = "Search for today's top technology news."
+    invalid_max_tool_calls = 0
+    tools = [
+        {"type": "web_search"},
+    ]
+
+    # Create a response with an invalid max_tool_calls value, i.e. 0.
+    # Handle ValueError from LLS and BadRequestError from the OpenAI client.
+    with pytest.raises((ValueError, BadRequestError)) as excinfo:
+        client.responses.create(
+            model=text_model_id,
+            input=input,
+            tools=tools,
+            stream=False,
+            max_tool_calls=invalid_max_tool_calls,
+        )
+
+    error_message = str(excinfo.value)
+    assert f"Invalid max_tool_calls={invalid_max_tool_calls}; should be >= 1" in error_message, (
+        f"Expected error message about invalid max_tool_calls, got: {error_message}"
+    )
+
+
+def test_max_tool_calls_with_builtin_tools(openai_client, client_with_models, text_model_id):
+    """Test handling of max_tool_calls with built-in tools in responses."""
+    if isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI responses are not supported when testing with library client yet.")
+
+    client = openai_client
+    input = "Search for today's top technology and a positive news story. You MUST make exactly two separate web search calls."
+    max_tool_calls = [1, 5]
+    tools = [
+        {"type": "web_search"},
+    ]
+
+    # First create a response that triggers web_search tools without max_tool_calls
+    response = client.responses.create(
+        model=text_model_id,
+        input=input,
+        tools=tools,
+        stream=False,
+    )
+
+    # Verify we got two web search calls followed by a message
+    assert len(response.output) == 3
+    assert response.output[0].type == "web_search_call"
+    assert response.output[0].status == "completed"
+    assert response.output[1].type == "web_search_call"
+    assert response.output[1].status == "completed"
+    assert response.output[2].type == "message"
+    assert response.output[2].status == "completed"
+    assert response.output[2].role == "assistant"
+
+    # Next create a response that triggers web_search tools with max_tool_calls set to 1
+    response_2 = client.responses.create(
+        model=text_model_id,
+        input=input,
+        tools=tools,
+        stream=False,
+        max_tool_calls=max_tool_calls[0],
+    )
+
+    # Verify we got one web search tool call followed by a message
+    assert len(response_2.output) == 2
+    assert response_2.output[0].type == "web_search_call"
+    assert response_2.output[0].status == "completed"
+    assert response_2.output[1].type == "message"
+    assert response_2.output[1].status == "completed"
+    assert response_2.output[1].role == "assistant"
+
+    # Verify we have a valid max_tool_calls field
+    assert response_2.max_tool_calls == max_tool_calls[0]
+
+    # Finally create a response that triggers web_search tools with max_tool_calls set to 5
+    response_3 = client.responses.create(
+        model=text_model_id,
+        input=input,
+        tools=tools,
+        stream=False,
+        max_tool_calls=max_tool_calls[1],
+    )
+
+    # Verify we got two web search calls followed by a message
+    assert len(response_3.output) == 3
+    assert response_3.output[0].type == "web_search_call"
+    assert response_3.output[0].status == "completed"
+    assert response_3.output[1].type == "web_search_call"
+    assert response_3.output[1].status == "completed"
+    assert response_3.output[2].type == "message"
+    assert response_3.output[2].status == "completed"
+    assert response_3.output[2].role == "assistant"
+
+    # Verify we have a valid max_tool_calls field
+    assert response_3.max_tool_calls == max_tool_calls[1]