Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-03 09:53:45 +00:00
feat: Implement the 'max_tool_calls' parameter for the Responses API (#4062)
# Problem

The Responses API uses the `max_tool_calls` parameter to limit the number of tool calls that can be generated in a response. Currently, the LLS implementation of the Responses API does not support this parameter.

# What does this PR do?

This pull request adds the `max_tool_calls` field to the response object definition and updates the inline provider. It also ensures that:

- the total number of calls to built-in and MCP tools does not exceed `max_tool_calls`
- an error is thrown if `max_tool_calls` < 1 (behavior seen with the OpenAI Responses API, but we can change this if needed)

Closes #[3563](https://github.com/llamastack/llama-stack/issues/3563)

## Test Plan

- Tested manually for changes in the model response w.r.t. the supplied `max_tool_calls` field.
- Added integration tests for an invalid `max_tool_calls` parameter.
- Added integration tests for the `max_tool_calls` parameter with built-in and function tools.
- Added integration tests checking the `max_tool_calls` parameter in the returned response object.
- Recorded OpenAI Responses API behavior using a sample script: https://github.com/s-akhtar-baig/llama-stack-examples/blob/main/responses/src/max_tool_calls.py

Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
Parent: 209a78b618
Commit: 433438cfc0

9 changed files with 240 additions and 2 deletions
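For reference, a minimal usage sketch of the new parameter through an OpenAI-compatible client, mirroring the integration tests further below; the base URL, API key, and model ID are placeholders.

```python
# Minimal usage sketch, assuming a running Llama Stack server exposing the
# OpenAI-compatible Responses API; base URL, API key, and model ID are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="not-needed")

response = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",  # placeholder model ID
    input="Search for today's top technology news.",
    tools=[{"type": "web_search"}],
    stream=False,
    max_tool_calls=1,  # caps built-in/MCP tool calls; values < 1 are rejected
)

# The cap is echoed back on the returned response object.
print(response.max_tool_calls)
```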
@@ -6626,6 +6626,11 @@ components:
           type: string
           description: >-
             (Optional) System message inserted into the model's context
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response
         input:
           type: array
           items:
@@ -6984,6 +6989,11 @@ components:
             (Optional) Additional fields to include in the response.
         max_infer_iters:
           type: integer
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response.
       additionalProperties: false
       required:
         - input
@@ -7065,6 +7075,11 @@ components:
           type: string
           description: >-
             (Optional) System message inserted into the model's context
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response
       additionalProperties: false
       required:
         - created_at
docs/static/llama-stack-spec.yaml (vendored, 15 changed lines)
@@ -5910,6 +5910,11 @@ components:
           type: string
           description: >-
             (Optional) System message inserted into the model's context
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response
         input:
           type: array
           items:
@@ -6268,6 +6273,11 @@ components:
             (Optional) Additional fields to include in the response.
         max_infer_iters:
           type: integer
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response.
       additionalProperties: false
       required:
         - input
@@ -6349,6 +6359,11 @@ components:
           type: string
           description: >-
             (Optional) System message inserted into the model's context
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response
       additionalProperties: false
       required:
         - created_at
docs/static/stainless-llama-stack-spec.yaml (vendored, 15 changed lines)
@@ -6626,6 +6626,11 @@ components:
           type: string
           description: >-
             (Optional) System message inserted into the model's context
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response
         input:
           type: array
           items:
@@ -6984,6 +6989,11 @@ components:
             (Optional) Additional fields to include in the response.
         max_infer_iters:
          type: integer
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response.
       additionalProperties: false
       required:
         - input
@@ -7065,6 +7075,11 @@ components:
           type: string
           description: >-
             (Optional) System message inserted into the model's context
+        max_tool_calls:
+          type: integer
+          description: >-
+            (Optional) Max number of total calls to built-in tools that can be processed
+            in a response
       additionalProperties: false
       required:
         - created_at
@@ -87,6 +87,7 @@ class Agents(Protocol):
                 "List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
             ),
         ] = None,
+        max_tool_calls: int | None = None,
     ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
         """Create a model response.

@@ -97,6 +98,7 @@ class Agents(Protocol):
         :param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
         :param include: (Optional) Additional fields to include in the response.
         :param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
+        :param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response.
         :returns: An OpenAIResponseObject.
         """
         ...
@@ -594,6 +594,7 @@ class OpenAIResponseObject(BaseModel):
     :param truncation: (Optional) Truncation strategy applied to the response
     :param usage: (Optional) Token usage information for the response
     :param instructions: (Optional) System message inserted into the model's context
+    :param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response
     """

     created_at: int
@@ -615,6 +616,7 @@ class OpenAIResponseObject(BaseModel):
     truncation: str | None = None
     usage: OpenAIResponseUsage | None = None
     instructions: str | None = None
+    max_tool_calls: int | None = None


 @json_schema_type
@@ -102,6 +102,7 @@ class MetaReferenceAgentsImpl(Agents):
         include: list[str] | None = None,
         max_infer_iters: int | None = 10,
         guardrails: list[ResponseGuardrail] | None = None,
+        max_tool_calls: int | None = None,
     ) -> OpenAIResponseObject:
         assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
         result = await self.openai_responses_impl.create_openai_response(
@@ -119,6 +120,7 @@ class MetaReferenceAgentsImpl(Agents):
             include,
             max_infer_iters,
             guardrails,
+            max_tool_calls,
         )
         return result  # type: ignore[no-any-return]
@@ -255,6 +255,7 @@ class OpenAIResponsesImpl:
         include: list[str] | None = None,
         max_infer_iters: int | None = 10,
         guardrails: list[str | ResponseGuardrailSpec] | None = None,
+        max_tool_calls: int | None = None,
     ):
         stream = bool(stream)
         text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
@@ -270,6 +271,9 @@ class OpenAIResponsesImpl:
             if not conversation.startswith("conv_"):
                 raise InvalidConversationIdError(conversation)

+        if max_tool_calls is not None and max_tool_calls < 1:
+            raise ValueError(f"Invalid {max_tool_calls=}; should be >= 1")
+
         stream_gen = self._create_streaming_response(
             input=input,
             conversation=conversation,
@@ -282,6 +286,7 @@ class OpenAIResponsesImpl:
             tools=tools,
             max_infer_iters=max_infer_iters,
             guardrail_ids=guardrail_ids,
+            max_tool_calls=max_tool_calls,
         )

         if stream:
@@ -331,6 +336,7 @@ class OpenAIResponsesImpl:
         tools: list[OpenAIResponseInputTool] | None = None,
         max_infer_iters: int | None = 10,
         guardrail_ids: list[str] | None = None,
+        max_tool_calls: int | None = None,
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         # These should never be None when called from create_openai_response (which sets defaults)
         # but we assert here to help mypy understand the types
@@ -373,6 +379,7 @@ class OpenAIResponsesImpl:
             safety_api=self.safety_api,
             guardrail_ids=guardrail_ids,
             instructions=instructions,
+            max_tool_calls=max_tool_calls,
         )

         # Stream the response
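To make the request-time check above concrete, here is a tiny standalone illustration of the validation rule; the helper name is hypothetical, and only the error message mirrors the provider code.

```python
# Hypothetical helper mirroring the request-time validation added above.
def validate_max_tool_calls(max_tool_calls: int | None) -> None:
    if max_tool_calls is not None and max_tool_calls < 1:
        raise ValueError(f"Invalid {max_tool_calls=}; should be >= 1")

validate_max_tool_calls(None)  # allowed: the parameter is optional
validate_max_tool_calls(3)     # allowed
# validate_max_tool_calls(0)   # would raise: "Invalid max_tool_calls=0; should be >= 1"
```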
@@ -115,6 +115,7 @@ class StreamingResponseOrchestrator:
         safety_api,
         guardrail_ids: list[str] | None = None,
         prompt: OpenAIResponsePrompt | None = None,
+        max_tool_calls: int | None = None,
     ):
         self.inference_api = inference_api
         self.ctx = ctx
@@ -126,6 +127,10 @@ class StreamingResponseOrchestrator:
         self.safety_api = safety_api
         self.guardrail_ids = guardrail_ids or []
         self.prompt = prompt
+        # System message that is inserted into the model's context
+        self.instructions = instructions
+        # Max number of total calls to built-in tools that can be processed in a response
+        self.max_tool_calls = max_tool_calls
         self.sequence_number = 0
         # Store MCP tool mapping that gets built during tool processing
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
@@ -139,8 +144,8 @@ class StreamingResponseOrchestrator:
         self.accumulated_usage: OpenAIResponseUsage | None = None
         # Track if we've sent a refusal response
         self.violation_detected = False
-        # system message that is inserted into the model's context
-        self.instructions = instructions
+        # Track total calls made to built-in tools
+        self.accumulated_builtin_tool_calls = 0

     async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
         """Create a refusal response to replace streaming content."""
@@ -186,6 +191,7 @@ class StreamingResponseOrchestrator:
             usage=self.accumulated_usage,
             instructions=self.instructions,
             prompt=self.prompt,
+            max_tool_calls=self.max_tool_calls,
         )

     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
@@ -894,6 +900,11 @@ class StreamingResponseOrchestrator:
         """Coordinate execution of both function and non-function tool calls."""
         # Execute non-function tool calls
         for tool_call in non_function_tool_calls:
+            # Check if total calls made to built-in and mcp tools exceed max_tool_calls
+            if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls:
+                logger.info(f"Ignoring built-in and mcp tool call since reached the limit of {self.max_tool_calls=}.")
+                break
+
             # Find the item_id for this tool call
             matching_item_id = None
             for index, item_id in completion_result_data.tool_call_item_ids.items():
@@ -974,6 +985,9 @@ class StreamingResponseOrchestrator:
             if tool_response_message:
                 next_turn_messages.append(tool_response_message)

+            # Track number of calls made to built-in and mcp tools
+            self.accumulated_builtin_tool_calls += 1
+
         # Execute function tool calls (client-side)
         for tool_call in function_tool_calls:
             # Find the item_id for this tool call from our tracking dictionary
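A condensed, standalone sketch of the capping behavior in the orchestrator hunks above; the function and variable names here are illustrative, and only the counting rule (built-in/MCP calls counted against `max_tool_calls`, function tools unaffected) mirrors the diff.

```python
# Illustrative, self-contained sketch of the capping rule added above:
# only built-in/MCP tool calls count against max_tool_calls, and further
# built-in tool calls are skipped once the cap is reached.
def run_builtin_tool_calls(tool_calls: list[str], max_tool_calls: int | None) -> list[str]:
    executed: list[str] = []
    accumulated_builtin_tool_calls = 0  # mirrors the orchestrator's counter
    for call in tool_calls:
        if max_tool_calls is not None and accumulated_builtin_tool_calls >= max_tool_calls:
            break  # remaining built-in tool calls are ignored
        executed.append(call)  # stand-in for real tool execution
        accumulated_builtin_tool_calls += 1
    return executed

assert run_builtin_tool_calls(["web_search", "web_search"], max_tool_calls=1) == ["web_search"]
assert len(run_builtin_tool_calls(["web_search", "web_search"], max_tool_calls=None)) == 2
```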
@@ -516,3 +516,169 @@ def test_response_with_instructions(openai_client, client_with_models, text_model_id):

     # Verify instructions from previous response was not carried over to the next response
     assert response_with_instructions2.instructions == instructions2
+
+
+@pytest.mark.skip(reason="Tool calling is not reliable.")
+def test_max_tool_calls_with_function_tools(openai_client, client_with_models, text_model_id):
+    """Test handling of max_tool_calls with function tools in responses."""
+    if isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI responses are not supported when testing with library client yet.")
+
+    client = openai_client
+    max_tool_calls = 1
+
+    tools = [
+        {
+            "type": "function",
+            "name": "get_weather",
+            "description": "Get weather information for a specified location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city name (e.g., 'New York', 'London')",
+                    },
+                },
+            },
+        },
+        {
+            "type": "function",
+            "name": "get_time",
+            "description": "Get current time for a specified location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city name (e.g., 'New York', 'London')",
+                    },
+                },
+            },
+        },
+    ]
+
+    # First create a response that triggers function tools
+    response = client.responses.create(
+        model=text_model_id,
+        input="Can you tell me the weather in Paris and the current time?",
+        tools=tools,
+        stream=False,
+        max_tool_calls=max_tool_calls,
+    )
+
+    # Verify we got two function calls and that max_tool_calls does not affect function tools
+    assert len(response.output) == 2
+    assert response.output[0].type == "function_call"
+    assert response.output[0].name == "get_weather"
+    assert response.output[0].status == "completed"
+    assert response.output[1].type == "function_call"
+    assert response.output[1].name == "get_time"
+    assert response.output[1].status == "completed"
+
+    # Verify we have a valid max_tool_calls field
+    assert response.max_tool_calls == max_tool_calls
+
+
+def test_max_tool_calls_invalid(openai_client, client_with_models, text_model_id):
+    """Test handling of invalid max_tool_calls in responses."""
+    if isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI responses are not supported when testing with library client yet.")
+
+    client = openai_client
+
+    input = "Search for today's top technology news."
+    invalid_max_tool_calls = 0
+    tools = [
+        {"type": "web_search"},
+    ]
+
+    # Create a response with an invalid max_tool_calls value, i.e. 0
+    # Handle ValueError from LLS and BadRequestError from the OpenAI client
+    with pytest.raises((ValueError, BadRequestError)) as excinfo:
+        client.responses.create(
+            model=text_model_id,
+            input=input,
+            tools=tools,
+            stream=False,
+            max_tool_calls=invalid_max_tool_calls,
+        )
+
+    error_message = str(excinfo.value)
+    assert f"Invalid max_tool_calls={invalid_max_tool_calls}; should be >= 1" in error_message, (
+        f"Expected error message about invalid max_tool_calls, got: {error_message}"
+    )
+
+
+def test_max_tool_calls_with_builtin_tools(openai_client, client_with_models, text_model_id):
+    """Test handling of max_tool_calls with built-in tools in responses."""
+    if isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI responses are not supported when testing with library client yet.")
+
+    client = openai_client
+
+    input = "Search for today's top technology and a positive news story. You MUST make exactly two separate web search calls."
+    max_tool_calls = [1, 5]
+    tools = [
+        {"type": "web_search"},
+    ]
+
+    # First create a response that triggers web_search tools without max_tool_calls
+    response = client.responses.create(
+        model=text_model_id,
+        input=input,
+        tools=tools,
+        stream=False,
+    )
+
+    # Verify we got two web search calls followed by a message
+    assert len(response.output) == 3
+    assert response.output[0].type == "web_search_call"
+    assert response.output[0].status == "completed"
+    assert response.output[1].type == "web_search_call"
+    assert response.output[1].status == "completed"
+    assert response.output[2].type == "message"
+    assert response.output[2].status == "completed"
+    assert response.output[2].role == "assistant"
+
+    # Next create a response that triggers web_search tools with max_tool_calls set to 1
+    response_2 = client.responses.create(
+        model=text_model_id,
+        input=input,
+        tools=tools,
+        stream=False,
+        max_tool_calls=max_tool_calls[0],
+    )
+
+    # Verify we got one web search tool call followed by a message
+    assert len(response_2.output) == 2
+    assert response_2.output[0].type == "web_search_call"
+    assert response_2.output[0].status == "completed"
+    assert response_2.output[1].type == "message"
+    assert response_2.output[1].status == "completed"
+    assert response_2.output[1].role == "assistant"
+
+    # Verify we have a valid max_tool_calls field
+    assert response_2.max_tool_calls == max_tool_calls[0]
+
+    # Finally create a response that triggers web_search tools with max_tool_calls set to 5
+    response_3 = client.responses.create(
+        model=text_model_id,
+        input=input,
+        tools=tools,
+        stream=False,
+        max_tool_calls=max_tool_calls[1],
+    )
+
+    # Verify we got two web search calls followed by a message
+    assert len(response_3.output) == 3
+    assert response_3.output[0].type == "web_search_call"
+    assert response_3.output[0].status == "completed"
+    assert response_3.output[1].type == "web_search_call"
+    assert response_3.output[1].status == "completed"
+    assert response_3.output[2].type == "message"
+    assert response_3.output[2].status == "completed"
+    assert response_3.output[2].role == "assistant"
+
+    # Verify we have a valid max_tool_calls field
+    assert response_3.max_tool_calls == max_tool_calls[1]