From 36d7abd4d5cadb2e1266b718b0823879a1cea25e Mon Sep 17 00:00:00 2001 From: Jaideep Rao Date: Sat, 22 Nov 2025 20:06:36 +0530 Subject: [PATCH] add unit tests Signed-off-by: Jaideep Rao --- client-sdks/stainless/openapi.yml | 12 +- docs/static/deprecated-llama-stack-spec.yaml | 12 +- .../static/experimental-llama-stack-spec.yaml | 12 +- docs/static/llama-stack-spec.yaml | 12 +- docs/static/stainless-llama-stack-spec.yaml | 12 +- .../responses/openai_responses.py | 2 +- .../meta_reference/responses/streaming.py | 50 +-- .../agents/meta_reference/responses/types.py | 9 +- src/llama_stack_api/inference.py | 5 + src/llama_stack_api/openai_responses.py | 13 +- .../test_response_tool_context.py | 338 +++++++++++++++++- 11 files changed, 392 insertions(+), 85 deletions(-) diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index 632043fbc..7c819a79e 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -12002,18 +12002,9 @@ components: - web_search_preview - web_search_preview_2025_03_11 - web_search_2025_08_26 - search_context_size: - anyOf: - - type: string - pattern: ^low|medium|high$ - - type: 'null' - default: medium type: object title: OpenAIResponseInputToolChoiceWebSearch - description: |- - Indicates that the model should use web search to generate a response. - - This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context. + description: Indicates that the model should use web search to generate a response OpenAIResponseMessage-Input: properties: content: @@ -12420,6 +12411,7 @@ components: title: AllowedToolsConfig type: object CustomToolConfig: + description: Custom tool configuration for OpenAI-compatible chat completion requests. properties: name: title: Name diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index a2c70372f..8ad6b51ab 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -8996,18 +8996,9 @@ components: - web_search_preview - web_search_preview_2025_03_11 - web_search_2025_08_26 - search_context_size: - anyOf: - - type: string - pattern: ^low|medium|high$ - - type: 'null' - default: medium type: object title: OpenAIResponseInputToolChoiceWebSearch - description: |- - Indicates that the model should use web search to generate a response. - - This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context. + description: Indicates that the model should use web search to generate a response OpenAIResponseMessage-Input: properties: content: @@ -9414,6 +9405,7 @@ components: title: AllowedToolsConfig type: object CustomToolConfig: + description: Custom tool configuration for OpenAI-compatible chat completion requests. properties: name: title: Name diff --git a/docs/static/experimental-llama-stack-spec.yaml b/docs/static/experimental-llama-stack-spec.yaml index 681d81e2f..913bec2f6 100644 --- a/docs/static/experimental-llama-stack-spec.yaml +++ b/docs/static/experimental-llama-stack-spec.yaml @@ -7799,18 +7799,9 @@ components: - web_search_preview - web_search_preview_2025_03_11 - web_search_2025_08_26 - search_context_size: - anyOf: - - type: string - pattern: ^low|medium|high$ - - type: 'null' - default: medium type: object title: OpenAIResponseInputToolChoiceWebSearch - description: |- - Indicates that the model should use web search to generate a response. - - This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context. + description: Indicates that the model should use web search to generate a response OpenAIResponseMessage-Output: properties: content: @@ -8148,6 +8139,7 @@ components: title: AllowedToolsConfig type: object CustomToolConfig: + description: Custom tool configuration for OpenAI-compatible chat completion requests. properties: name: title: Name diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index 48c99ddb0..d5a3963f3 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -10438,18 +10438,9 @@ components: - web_search_preview - web_search_preview_2025_03_11 - web_search_2025_08_26 - search_context_size: - anyOf: - - type: string - pattern: ^low|medium|high$ - - type: 'null' - default: medium type: object title: OpenAIResponseInputToolChoiceWebSearch - description: |- - Indicates that the model should use web search to generate a response. - - This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context. + description: Indicates that the model should use web search to generate a response OpenAIResponseMessage-Input: properties: content: @@ -10856,6 +10847,7 @@ components: title: AllowedToolsConfig type: object CustomToolConfig: + description: Custom tool configuration for OpenAI-compatible chat completion requests. properties: name: title: Name diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 632043fbc..7c819a79e 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -12002,18 +12002,9 @@ components: - web_search_preview - web_search_preview_2025_03_11 - web_search_2025_08_26 - search_context_size: - anyOf: - - type: string - pattern: ^low|medium|high$ - - type: 'null' - default: medium type: object title: OpenAIResponseInputToolChoiceWebSearch - description: |- - Indicates that the model should use web search to generate a response. - - This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context. + description: Indicates that the model should use web search to generate a response OpenAIResponseMessage-Input: properties: content: @@ -12420,6 +12411,7 @@ components: title: AllowedToolsConfig type: object CustomToolConfig: + description: Custom tool configuration for OpenAI-compatible chat completion requests. properties: name: title: Name diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py index d84cac285..7ad8b406f 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -473,7 +473,7 @@ class OpenAIResponsesImpl: model=model, messages=messages, response_tools=tools, - responses_tool_choice=tool_choice, + tool_choice=tool_choice, temperature=temperature, response_format=response_format, tool_context=tool_context, diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index 69554ac91..a4dda8303 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -8,8 +8,8 @@ import uuid from collections.abc import AsyncIterator from typing import Any -from opentelemetry import trace from openai.types.chat import ChatCompletionToolParam +from opentelemetry import trace from llama_stack.log import get_logger from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str @@ -217,7 +217,7 @@ class StreamingResponseOrchestrator: output=self._clone_outputs(outputs), text=self.text, tools=self.ctx.available_tools(), - tool_choice=self.ctx.responses_tool_choice, + tool_choice=self.ctx.tool_choice, error=error, usage=self.accumulated_usage, instructions=self.instructions, @@ -253,17 +253,18 @@ class StreamingResponseOrchestrator: async for stream_event in self._process_tools(output_messages): yield stream_event - if self.ctx.responses_tool_choice and len(self.ctx.chat_tools) > 0: - chat_tool_choice = await _process_tool_choice( + chat_tool_choice = None + if self.ctx.tool_choice and len(self.ctx.chat_tools) > 0: + processed_tool_choice = await _process_tool_choice( self.ctx.chat_tools, - self.ctx.responses_tool_choice, + self.ctx.tool_choice, self.server_label_to_tools, ) # chat_tool_choice can be str, dict-like object, or None - if isinstance(chat_tool_choice, str): - self.ctx.chat_tool_choice = chat_tool_choice + if isinstance(processed_tool_choice, str | type(None)): + chat_tool_choice = processed_tool_choice else: - self.ctx.chat_tool_choice = chat_tool_choice.model_dump() + chat_tool_choice = processed_tool_choice.model_dump() n_iter = 0 messages = self.ctx.messages.copy() @@ -284,7 +285,7 @@ class StreamingResponseOrchestrator: messages=messages, # Pydantic models are dict-compatible but mypy treats them as distinct types tools=self.ctx.chat_tools, # type: ignore[arg-type] - tool_choice=self.ctx.chat_tool_choice, + tool_choice=chat_tool_choice, stream=True, temperature=self.ctx.temperature, response_format=response_format, @@ -363,8 +364,8 @@ class StreamingResponseOrchestrator: n_iter += 1 # After first iteration, reset tool_choice to "auto" to let model decide freely # based on tool results (prevents infinite loops when forcing specific tools) - if n_iter == 1 and self.ctx.chat_tool_choice: - self.ctx.chat_tool_choice = "auto" + if n_iter == 1 and chat_tool_choice: + chat_tool_choice = "auto" if n_iter >= self.max_infer_iters: logger.info( f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}" @@ -1332,13 +1333,13 @@ class StreamingResponseOrchestrator: async def _process_tool_choice( chat_tools: list[ChatCompletionToolParam], - responses_tool_choice: OpenAIResponseInputToolChoice, + tool_choice: OpenAIResponseInputToolChoice, server_label_to_tools: dict[str, list[str]], ) -> str | OpenAIChatCompletionToolChoice | None: """Process and validate the OpenAI Responses tool choice and return the appropriate chat completion tool choice object. :param chat_tools: The list of chat tools to enforce tool choice against. - :param responses_tool_choice: The OpenAI Responses tool choice to process. + :param tool_choice: The OpenAI Responses tool choice to process. :param server_label_to_tools: A dictionary mapping server labels to the list of tools available on that server. :return: The appropriate chat completion tool choice object. """ @@ -1347,8 +1348,8 @@ async def _process_tool_choice( # Note: chat_tools contains dicts, not objects chat_tool_names = [tool["function"]["name"] for tool in chat_tools if tool["type"] == "function"] - if isinstance(responses_tool_choice, OpenAIResponseInputToolChoiceMode): - if responses_tool_choice.value == "required": + if isinstance(tool_choice, OpenAIResponseInputToolChoiceMode): + if tool_choice.value == "required": if len(chat_tool_names) == 0: return None @@ -1358,18 +1359,17 @@ async def _process_tool_choice( mode="required", ) # return other modes as is - return responses_tool_choice.value + return tool_choice.value - elif isinstance(responses_tool_choice, OpenAIResponseInputToolChoiceAllowedTools): + elif isinstance(tool_choice, OpenAIResponseInputToolChoiceAllowedTools): # ensure that specified tool choices are available in the chat tools, if not, remove them from the list final_tools = [] - for tool in responses_tool_choice.tools: - tool_name = tool.get("name") + for tool in tool_choice.tools: match tool.get("type"): case "function": - final_tools.append({"type": "function", "function": {"name": tool_name}}) + final_tools.append({"type": "function", "function": {"name": tool.get("name")}}) case "custom": - final_tools.append({"type": "custom", "custom": {"name": tool_name}}) + final_tools.append({"type": "custom", "custom": {"name": tool.get("name")}}) case "mcp": mcp_tools = convert_mcp_tool_choice( chat_tool_names, tool.get("server_label"), server_label_to_tools, None @@ -1390,14 +1390,14 @@ async def _process_tool_choice( return OpenAIChatCompletionToolChoiceAllowedTools( tools=final_tools, - mode=responses_tool_choice.mode, + mode=tool_choice.mode, ) else: # Handle specific tool choice by type # Each case validates the tool exists in chat_tools before returning - tool_name = responses_tool_choice.name if responses_tool_choice.name else None - match responses_tool_choice: + tool_name = getattr(tool_choice, "name", None) + match tool_choice: case OpenAIResponseInputToolChoiceCustomTool(): if tool_name and tool_name not in chat_tool_names: logger.warning(f"Tool {tool_name} not found in chat tools") @@ -1425,7 +1425,7 @@ async def _process_tool_choice( case OpenAIResponseInputToolChoiceMCPTool(): tool_choice = convert_mcp_tool_choice( chat_tool_names, - responses_tool_choice.server_label, + tool_choice.server_label, server_label_to_tools, tool_name, ) diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/types.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/types.py index bc0dbbade..614bdace1 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/types.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/types.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from dataclasses import dataclass -from typing import Any, cast +from typing import cast from openai.types.chat import ChatCompletionToolParam from pydantic import BaseModel @@ -161,8 +161,7 @@ class ChatCompletionContext(BaseModel): temperature: float | None response_format: OpenAIResponseFormatParam tool_context: ToolContext | None - responses_tool_choice: OpenAIResponseInputToolChoice | None = None - chat_tool_choice: str | dict[str, Any] | None = None + tool_choice: OpenAIResponseInputToolChoice | None = None approval_requests: list[OpenAIResponseMCPApprovalRequest] = [] approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {} @@ -175,7 +174,7 @@ class ChatCompletionContext(BaseModel): response_format: OpenAIResponseFormatParam, tool_context: ToolContext, inputs: list[OpenAIResponseInput] | str, - responses_tool_choice: OpenAIResponseInputToolChoice | None = None, + tool_choice: OpenAIResponseInputToolChoice | None = None, ): super().__init__( model=model, @@ -184,7 +183,7 @@ class ChatCompletionContext(BaseModel): temperature=temperature, response_format=response_format, tool_context=tool_context, - responses_tool_choice=responses_tool_choice, + tool_choice=tool_choice, ) if not isinstance(inputs, str): self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"] diff --git a/src/llama_stack_api/inference.py b/src/llama_stack_api/inference.py index f506bce53..aeddf34c5 100644 --- a/src/llama_stack_api/inference.py +++ b/src/llama_stack_api/inference.py @@ -577,6 +577,11 @@ class OpenAIChatCompletionToolChoiceFunctionTool(BaseModel): @json_schema_type class CustomToolConfig(BaseModel): + """Custom tool configuration for OpenAI-compatible chat completion requests. + + :param name: Name of the custom tool + """ + name: str diff --git a/src/llama_stack_api/openai_responses.py b/src/llama_stack_api/openai_responses.py index 0c427c4ea..8dc372394 100644 --- a/src/llama_stack_api/openai_responses.py +++ b/src/llama_stack_api/openai_responses.py @@ -565,12 +565,19 @@ class OpenAIResponseInputToolChoiceFileSearch(BaseModel): @json_schema_type -class OpenAIResponseInputToolChoiceWebSearch(OpenAIResponseInputToolWebSearch): - """Indicates that the model should use web search to generate a response. +class OpenAIResponseInputToolChoiceWebSearch(BaseModel): + """Indicates that the model should use web search to generate a response - This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context. + :param type: Web search tool type variant to use """ + type: ( + Literal["web_search"] + | Literal["web_search_preview"] + | Literal["web_search_preview_2025_03_11"] + | Literal["web_search_2025_08_26"] + ) = "web_search" + @json_schema_type class OpenAIResponseInputToolChoiceFunctionTool(BaseModel): diff --git a/tests/unit/providers/agents/meta_reference/test_response_tool_context.py b/tests/unit/providers/agents/meta_reference/test_response_tool_context.py index 4054debd5..325844ae2 100644 --- a/tests/unit/providers/agents/meta_reference/test_response_tool_context.py +++ b/tests/unit/providers/agents/meta_reference/test_response_tool_context.py @@ -5,9 +5,22 @@ # the root directory of this source tree. +from llama_stack.providers.inline.agents.meta_reference.responses.streaming import ( + _process_tool_choice, +) from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext -from llama_stack_api.openai_responses import ( +from llama_stack_api import ( MCPListToolsTool, + OpenAIChatCompletionToolChoiceAllowedTools, + OpenAIChatCompletionToolChoiceCustomTool, + OpenAIChatCompletionToolChoiceFunctionTool, + OpenAIResponseInputToolChoiceAllowedTools, + OpenAIResponseInputToolChoiceCustomTool, + OpenAIResponseInputToolChoiceFileSearch, + OpenAIResponseInputToolChoiceFunctionTool, + OpenAIResponseInputToolChoiceMCPTool, + OpenAIResponseInputToolChoiceMode, + OpenAIResponseInputToolChoiceWebSearch, OpenAIResponseInputToolFileSearch, OpenAIResponseInputToolFunction, OpenAIResponseInputToolMCP, @@ -181,3 +194,326 @@ class TestToolContext: assert len(context.previous_tool_listings) == 1 assert len(context.previous_tool_listings[0].tools) == 1 assert context.previous_tool_listings[0].server_label == "anotherlabel" + + +class TestProcessToolChoice: + """Comprehensive test suite for _process_tool_choice function.""" + + def setup_method(self): + """Set up common test fixtures.""" + self.chat_tools = [ + {"type": "function", "function": {"name": "get_weather"}}, + {"type": "function", "function": {"name": "calculate"}}, + {"type": "function", "function": {"name": "file_search"}}, + {"type": "function", "function": {"name": "web_search"}}, + ] + self.server_label_to_tools = { + "mcp_server_1": ["mcp_tool_1", "mcp_tool_2"], + "mcp_server_2": ["mcp_tool_3"], + } + + async def test_mode_auto(self): + """Test auto mode - should return 'auto' string.""" + tool_choice = OpenAIResponseInputToolChoiceMode.auto + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) + assert result == "auto" + + async def test_mode_none(self): + """Test none mode - should return 'none' string.""" + tool_choice = OpenAIResponseInputToolChoiceMode.none + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) + assert result == "none" + + async def test_mode_required_with_tools(self): + """Test required mode with available tools - should return AllowedTools with all function tools.""" + tool_choice = OpenAIResponseInputToolChoiceMode.required + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) + + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) + assert result.allowed_tools.mode == "required" + assert len(result.allowed_tools.tools) == 4 + tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools] + assert "get_weather" in tool_names + assert "calculate" in tool_names + assert "file_search" in tool_names + assert "web_search" in tool_names + + async def test_mode_required_without_tools(self): + """Test required mode without available tools - should return None.""" + tool_choice = OpenAIResponseInputToolChoiceMode.required + result = await _process_tool_choice([], tool_choice, self.server_label_to_tools) + assert result is None + + async def test_allowed_tools_function(self): + """Test allowed_tools with function tool types.""" + tool_choice = OpenAIResponseInputToolChoiceAllowedTools( + mode="required", + tools=[ + {"type": "function", "name": "get_weather"}, + {"type": "function", "name": "calculate"}, + ], + ) + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) + + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) + assert result.allowed_tools.mode == "required" + assert len(result.allowed_tools.tools) == 2 + assert result.allowed_tools.tools[0]["function"]["name"] == "get_weather" + assert result.allowed_tools.tools[1]["function"]["name"] == "calculate" + + async def test_allowed_tools_custom(self): + """Test allowed_tools with custom tool types.""" + chat_tools = [{"type": "function", "function": {"name": "custom_tool_1"}}] + tool_choice = OpenAIResponseInputToolChoiceAllowedTools( + mode="auto", + tools=[{"type": "custom", "name": "custom_tool_1"}], + ) + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) + + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) + assert result.allowed_tools.mode == "auto" + assert len(result.allowed_tools.tools) == 1 + assert result.allowed_tools.tools[0]["type"] == "custom" + assert result.allowed_tools.tools[0]["custom"]["name"] == "custom_tool_1" + + async def test_allowed_tools_file_search(self): + """Test allowed_tools with file_search.""" + tool_choice = OpenAIResponseInputToolChoiceAllowedTools( + mode="required", + tools=[{"type": "file_search"}], + ) + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) + + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) + assert len(result.allowed_tools.tools) == 1 + assert result.allowed_tools.tools[0]["function"]["name"] == "file_search" + + async def test_allowed_tools_web_search(self): + """Test allowed_tools with web_search.""" + tool_choice = OpenAIResponseInputToolChoiceAllowedTools( + mode="required", + tools=[ + {"type": "web_search_preview_2025_03_11"}, + {"type": "web_search_2025_08_26"}, + {"type": "web_search_preview"}, + ], + ) + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) + + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) + assert len(result.allowed_tools.tools) == 3 + assert result.allowed_tools.tools[0]["function"]["name"] == "web_search" + assert result.allowed_tools.tools[0]["type"] == "function" + assert result.allowed_tools.tools[1]["function"]["name"] == "web_search" + assert result.allowed_tools.tools[1]["type"] == "function" + assert result.allowed_tools.tools[2]["function"]["name"] == "web_search" + assert result.allowed_tools.tools[2]["type"] == "function" + + async def test_allowed_tools_mcp_server_label(self): + """Test allowed_tools with MCP server label (no specific tool name).""" + chat_tools = [ + {"type": "function", "function": {"name": "mcp_tool_1"}}, + {"type": "function", "function": {"name": "mcp_tool_2"}}, + ] + tool_choice = OpenAIResponseInputToolChoiceAllowedTools( + mode="required", + tools=[{"type": "mcp", "server_label": "mcp_server_1"}], + ) + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) + + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) + assert len(result.allowed_tools.tools) == 2 + tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools] + assert "mcp_tool_1" in tool_names + assert "mcp_tool_2" in tool_names + + async def test_allowed_tools_mixed_types(self): + """Test allowed_tools with mixed tool types.""" + chat_tools = [ + {"type": "function", "function": {"name": "get_weather"}}, + {"type": "function", "function": {"name": "file_search"}}, + {"type": "function", "function": {"name": "web_search"}}, + {"type": "function", "function": {"name": "mcp_tool_1"}}, + ] + tool_choice = OpenAIResponseInputToolChoiceAllowedTools( + mode="auto", + tools=[ + {"type": "function", "name": "get_weather"}, + {"type": "file_search"}, + {"type": "web_search"}, + {"type": "mcp", "server_label": "mcp_server_1"}, + ], + ) + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) + + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) + # Should have: get_weather, file_search, web_search, mcp_tool_1, mcp_tool_2 + assert len(result.allowed_tools.tools) >= 3 + + async def test_allowed_tools_invalid_type(self): + """Test allowed_tools with an unsupported tool type - should skip it.""" + tool_choice = OpenAIResponseInputToolChoiceAllowedTools( + mode="required", + tools=[ + {"type": "function", "name": "get_weather"}, + {"type": "unsupported_type", "name": "bad_tool"}, + ], + ) + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) + + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) + # Should only include the valid function tool + assert len(result.allowed_tools.tools) == 1 + assert result.allowed_tools.tools[0]["function"]["name"] == "get_weather" + + async def test_specific_custom_tool_valid(self): + """Test specific custom tool choice when tool exists.""" + chat_tools = [{"type": "function", "function": {"name": "custom_tool"}}] + tool_choice = OpenAIResponseInputToolChoiceCustomTool(name="custom_tool") + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) + + assert isinstance(result, OpenAIChatCompletionToolChoiceCustomTool) + assert result.custom.name == "custom_tool" + + async def test_specific_custom_tool_invalid(self): + """Test specific custom tool choice when tool doesn't exist - should return None.""" + tool_choice = OpenAIResponseInputToolChoiceCustomTool(name="nonexistent_tool") + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) + assert result is None + + async def test_specific_function_tool_valid(self): + """Test specific function tool choice when tool exists.""" + tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="get_weather") + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) + + assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool) + assert result.function.name == "get_weather" + + async def test_specific_function_tool_invalid(self): + """Test specific function tool choice when tool doesn't exist - should return None.""" + tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="nonexistent_function") + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) + assert result is None + + async def test_specific_file_search_valid(self): + """Test file_search tool choice when available.""" + tool_choice = OpenAIResponseInputToolChoiceFileSearch() + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) + + assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool) + assert result.function.name == "file_search" + + async def test_specific_file_search_invalid(self): + """Test file_search tool choice when not available - should return None.""" + chat_tools = [{"type": "function", "function": {"name": "get_weather"}}] + tool_choice = OpenAIResponseInputToolChoiceFileSearch() + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) + assert result is None + + async def test_specific_web_search_valid(self): + """Test web_search tool choice when available.""" + tool_choice = OpenAIResponseInputToolChoiceWebSearch() + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) + + assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool) + assert result.function.name == "web_search" + + async def test_specific_web_search_invalid(self): + """Test web_search tool choice when not available - should return None.""" + chat_tools = [{"type": "function", "function": {"name": "get_weather"}}] + tool_choice = OpenAIResponseInputToolChoiceWebSearch() + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) + assert result is None + + async def test_specific_mcp_tool_with_name(self): + """Test MCP tool choice with specific tool name.""" + chat_tools = [{"type": "function", "function": {"name": "mcp_tool_1"}}] + tool_choice = OpenAIResponseInputToolChoiceMCPTool( + server_label="mcp_server_1", + name="mcp_tool_1", + ) + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) + + assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool) + assert result.function.name == "mcp_tool_1" + + async def test_specific_mcp_tool_with_name_not_in_chat_tools(self): + """Test MCP tool choice with specific tool name that doesn't exist in chat_tools.""" + chat_tools = [{"type": "function", "function": {"name": "other_tool"}}] + tool_choice = OpenAIResponseInputToolChoiceMCPTool( + server_label="mcp_server_1", + name="mcp_tool_1", + ) + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) + assert result is None + + async def test_specific_mcp_tool_server_label_only(self): + """Test MCP tool choice with only server label (no specific tool name).""" + chat_tools = [ + {"type": "function", "function": {"name": "mcp_tool_1"}}, + {"type": "function", "function": {"name": "mcp_tool_2"}}, + ] + tool_choice = OpenAIResponseInputToolChoiceMCPTool(server_label="mcp_server_1") + result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools) + + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) + assert result.allowed_tools.mode == "required" + assert len(result.allowed_tools.tools) == 2 + tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools] + assert "mcp_tool_1" in tool_names + assert "mcp_tool_2" in tool_names + + async def test_specific_mcp_tool_unknown_server(self): + """Test MCP tool choice with unknown server label.""" + tool_choice = OpenAIResponseInputToolChoiceMCPTool( + server_label="unknown_server", + name="some_tool", + ) + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) + # Should return None because server not found + assert result is None + + async def test_empty_chat_tools(self): + """Test with empty chat_tools list.""" + tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="get_weather") + result = await _process_tool_choice([], tool_choice, self.server_label_to_tools) + assert result is None + + async def test_empty_server_label_to_tools(self): + """Test with empty server_label_to_tools mapping.""" + tool_choice = OpenAIResponseInputToolChoiceMCPTool(server_label="mcp_server_1") + result = await _process_tool_choice(self.chat_tools, tool_choice, {}) + # Should handle gracefully + assert result is None or isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) + + async def test_allowed_tools_empty_list(self): + """Test allowed_tools with empty tools list.""" + tool_choice = OpenAIResponseInputToolChoiceAllowedTools(mode="auto", tools=[]) + result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools) + + assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools) + assert len(result.allowed_tools.tools) == 0 + + async def test_mcp_tool_multiple_servers(self): + """Test MCP tool choice with multiple server labels.""" + chat_tools = [ + {"type": "function", "function": {"name": "mcp_tool_1"}}, + {"type": "function", "function": {"name": "mcp_tool_2"}}, + {"type": "function", "function": {"name": "mcp_tool_3"}}, + ] + server_label_to_tools = { + "server_a": ["mcp_tool_1"], + "server_b": ["mcp_tool_2", "mcp_tool_3"], + } + + # Test server_a + tool_choice_a = OpenAIResponseInputToolChoiceMCPTool(server_label="server_a") + result_a = await _process_tool_choice(chat_tools, tool_choice_a, server_label_to_tools) + assert isinstance(result_a, OpenAIChatCompletionToolChoiceAllowedTools) + assert len(result_a.allowed_tools.tools) == 1 + + # Test server_b + tool_choice_b = OpenAIResponseInputToolChoiceMCPTool(server_label="server_b") + result_b = await _process_tool_choice(chat_tools, tool_choice_b, server_label_to_tools) + assert isinstance(result_b, OpenAIChatCompletionToolChoiceAllowedTools) + assert len(result_b.allowed_tools.tools) == 2