mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-03 01:48:05 +00:00

commit 36d7abd4d5 (parent d2d2c88921)

add unit tests

Signed-off-by: Jaideep Rao <jrao@redhat.com>

11 changed files with 392 additions and 85 deletions
@@ -12002,18 +12002,9 @@ components:
             - web_search_preview
             - web_search_preview_2025_03_11
             - web_search_2025_08_26
-        search_context_size:
-          anyOf:
-            - type: string
-              pattern: ^low|medium|high$
-            - type: 'null'
-          default: medium
       type: object
       title: OpenAIResponseInputToolChoiceWebSearch
-      description: |-
-        Indicates that the model should use web search to generate a response.
-
-        This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context.
+      description: Indicates that the model should use web search to generate a response
     OpenAIResponseMessage-Input:
       properties:
         content:
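For reference, a minimal request fragment that satisfies the slimmed-down schema (illustrative value only; any of the enum variants above is accepted, and the search_context_size property from the removed alias form no longer applies):

    # Illustrative tool_choice payload for the schema above.
    tool_choice = {"type": "web_search_preview"}
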
@@ -12420,6 +12411,7 @@ components:
       title: AllowedToolsConfig
       type: object
     CustomToolConfig:
+      description: Custom tool configuration for OpenAI-compatible chat completion requests.
       properties:
         name:
           title: Name

docs/static/deprecated-llama-stack-spec.yaml (vendored): 12 changes

@@ -8996,18 +8996,9 @@ components:
             - web_search_preview
             - web_search_preview_2025_03_11
             - web_search_2025_08_26
-        search_context_size:
-          anyOf:
-            - type: string
-              pattern: ^low|medium|high$
-            - type: 'null'
-          default: medium
       type: object
       title: OpenAIResponseInputToolChoiceWebSearch
-      description: |-
-        Indicates that the model should use web search to generate a response.
-
-        This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context.
+      description: Indicates that the model should use web search to generate a response
     OpenAIResponseMessage-Input:
       properties:
         content:

@@ -9414,6 +9405,7 @@ components:
       title: AllowedToolsConfig
       type: object
     CustomToolConfig:
+      description: Custom tool configuration for OpenAI-compatible chat completion requests.
       properties:
         name:
           title: Name

docs/static/experimental-llama-stack-spec.yaml (vendored): 12 changes

@@ -7799,18 +7799,9 @@ components:
             - web_search_preview
             - web_search_preview_2025_03_11
             - web_search_2025_08_26
-        search_context_size:
-          anyOf:
-            - type: string
-              pattern: ^low|medium|high$
-            - type: 'null'
-          default: medium
       type: object
       title: OpenAIResponseInputToolChoiceWebSearch
-      description: |-
-        Indicates that the model should use web search to generate a response.
-
-        This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context.
+      description: Indicates that the model should use web search to generate a response
     OpenAIResponseMessage-Output:
       properties:
         content:

@@ -8148,6 +8139,7 @@ components:
       title: AllowedToolsConfig
       type: object
     CustomToolConfig:
+      description: Custom tool configuration for OpenAI-compatible chat completion requests.
      properties:
         name:
           title: Name

docs/static/llama-stack-spec.yaml (vendored): 12 changes

@@ -10438,18 +10438,9 @@ components:
             - web_search_preview
             - web_search_preview_2025_03_11
             - web_search_2025_08_26
-        search_context_size:
-          anyOf:
-            - type: string
-              pattern: ^low|medium|high$
-            - type: 'null'
-          default: medium
       type: object
       title: OpenAIResponseInputToolChoiceWebSearch
-      description: |-
-        Indicates that the model should use web search to generate a response.
-
-        This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context.
+      description: Indicates that the model should use web search to generate a response
     OpenAIResponseMessage-Input:
       properties:
         content:

@@ -10856,6 +10847,7 @@ components:
       title: AllowedToolsConfig
       type: object
     CustomToolConfig:
+      description: Custom tool configuration for OpenAI-compatible chat completion requests.
       properties:
         name:
           title: Name

docs/static/stainless-llama-stack-spec.yaml (vendored): 12 changes

@@ -12002,18 +12002,9 @@ components:
             - web_search_preview
             - web_search_preview_2025_03_11
             - web_search_2025_08_26
-        search_context_size:
-          anyOf:
-            - type: string
-              pattern: ^low|medium|high$
-            - type: 'null'
-          default: medium
       type: object
       title: OpenAIResponseInputToolChoiceWebSearch
-      description: |-
-        Indicates that the model should use web search to generate a response.
-
-        This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context.
+      description: Indicates that the model should use web search to generate a response
     OpenAIResponseMessage-Input:
       properties:
         content:

@@ -12420,6 +12411,7 @@ components:
       title: AllowedToolsConfig
       type: object
     CustomToolConfig:
+      description: Custom tool configuration for OpenAI-compatible chat completion requests.
       properties:
         name:
           title: Name

@@ -473,7 +473,7 @@ class OpenAIResponsesImpl:
             model=model,
             messages=messages,
             response_tools=tools,
-            responses_tool_choice=tool_choice,
+            tool_choice=tool_choice,
             temperature=temperature,
             response_format=response_format,
             tool_context=tool_context,

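For orientation, a hedged sketch of the caller-facing side of this parameter: on an OpenAI-compatible Responses request, tool_choice is what flows into the context object renamed above. The endpoint, key, and model name below are placeholders, not values from this commit.

    # Hedged sketch of an OpenAI-compatible Responses call; placeholders throughout.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8321/v1", api_key="placeholder")
    response = client.responses.create(
        model="example-model",  # placeholder
        input="What's the weather in Paris?",
        tools=[{"type": "web_search"}],
        tool_choice={"type": "web_search"},  # forwarded internally as `tool_choice`
    )
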
@@ -8,8 +8,8 @@ import uuid
 from collections.abc import AsyncIterator
 from typing import Any

-from opentelemetry import trace
+from openai.types.chat import ChatCompletionToolParam
+from opentelemetry import trace

 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str

@@ -217,7 +217,7 @@ class StreamingResponseOrchestrator:
             output=self._clone_outputs(outputs),
             text=self.text,
             tools=self.ctx.available_tools(),
-            tool_choice=self.ctx.responses_tool_choice,
+            tool_choice=self.ctx.tool_choice,
             error=error,
             usage=self.accumulated_usage,
             instructions=self.instructions,

@@ -253,17 +253,18 @@ class StreamingResponseOrchestrator:
         async for stream_event in self._process_tools(output_messages):
             yield stream_event

-        if self.ctx.responses_tool_choice and len(self.ctx.chat_tools) > 0:
-            chat_tool_choice = await _process_tool_choice(
+        chat_tool_choice = None
+        if self.ctx.tool_choice and len(self.ctx.chat_tools) > 0:
+            processed_tool_choice = await _process_tool_choice(
                 self.ctx.chat_tools,
-                self.ctx.responses_tool_choice,
+                self.ctx.tool_choice,
                 self.server_label_to_tools,
             )
             # chat_tool_choice can be str, dict-like object, or None
-            if isinstance(chat_tool_choice, str):
-                self.ctx.chat_tool_choice = chat_tool_choice
+            if isinstance(processed_tool_choice, str | type(None)):
+                chat_tool_choice = processed_tool_choice
             else:
-                self.ctx.chat_tool_choice = chat_tool_choice.model_dump()
+                chat_tool_choice = processed_tool_choice.model_dump()

         n_iter = 0
         messages = self.ctx.messages.copy()

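To restate the normalization introduced above as a standalone rule: _process_tool_choice can return a mode string, a Pydantic tool-choice model, or None, and the orchestrator flattens that into the str-or-dict shape the chat-completions call accepts. A minimal sketch of just that step:

    # Sketch of the normalization above; `processed` stands for whatever
    # _process_tool_choice returned (a str mode, a Pydantic model, or None).
    def normalize_tool_choice(processed):
        if isinstance(processed, str | type(None)):
            return processed  # "auto" / "none" / "required" / None pass through unchanged
        return processed.model_dump()  # Pydantic models become plain dicts
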
@@ -284,7 +285,7 @@ class StreamingResponseOrchestrator:
                 messages=messages,
                 # Pydantic models are dict-compatible but mypy treats them as distinct types
                 tools=self.ctx.chat_tools,  # type: ignore[arg-type]
-                tool_choice=self.ctx.chat_tool_choice,
+                tool_choice=chat_tool_choice,
                 stream=True,
                 temperature=self.ctx.temperature,
                 response_format=response_format,

@@ -363,8 +364,8 @@ class StreamingResponseOrchestrator:
                 n_iter += 1
                 # After first iteration, reset tool_choice to "auto" to let model decide freely
                 # based on tool results (prevents infinite loops when forcing specific tools)
-                if n_iter == 1 and self.ctx.chat_tool_choice:
-                    self.ctx.chat_tool_choice = "auto"
+                if n_iter == 1 and chat_tool_choice:
+                    chat_tool_choice = "auto"
                 if n_iter >= self.max_infer_iters:
                     logger.info(
                         f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}"

@@ -1332,13 +1333,13 @@ class StreamingResponseOrchestrator:

 async def _process_tool_choice(
     chat_tools: list[ChatCompletionToolParam],
-    responses_tool_choice: OpenAIResponseInputToolChoice,
+    tool_choice: OpenAIResponseInputToolChoice,
     server_label_to_tools: dict[str, list[str]],
 ) -> str | OpenAIChatCompletionToolChoice | None:
     """Process and validate the OpenAI Responses tool choice and return the appropriate chat completion tool choice object.

     :param chat_tools: The list of chat tools to enforce tool choice against.
-    :param responses_tool_choice: The OpenAI Responses tool choice to process.
+    :param tool_choice: The OpenAI Responses tool choice to process.
     :param server_label_to_tools: A dictionary mapping server labels to the list of tools available on that server.
     :return: The appropriate chat completion tool choice object.
     """

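A usage sketch of the renamed parameter, mirroring the unit tests added later in this commit (run inside an async context; the chat_tools value is illustrative):

    # Mode choices collapse to plain strings ("auto", "none"), while a specific
    # function choice is validated against chat_tools before being returned.
    chat_tools = [{"type": "function", "function": {"name": "get_weather"}}]

    result = await _process_tool_choice(chat_tools, OpenAIResponseInputToolChoiceMode.auto, {})
    assert result == "auto"

    result = await _process_tool_choice(
        chat_tools, OpenAIResponseInputToolChoiceFunctionTool(name="get_weather"), {}
    )
    assert result.function.name == "get_weather"
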
@@ -1347,8 +1348,8 @@ async def _process_tool_choice(
     # Note: chat_tools contains dicts, not objects
     chat_tool_names = [tool["function"]["name"] for tool in chat_tools if tool["type"] == "function"]

-    if isinstance(responses_tool_choice, OpenAIResponseInputToolChoiceMode):
-        if responses_tool_choice.value == "required":
+    if isinstance(tool_choice, OpenAIResponseInputToolChoiceMode):
+        if tool_choice.value == "required":
             if len(chat_tool_names) == 0:
                 return None

@@ -1358,18 +1359,17 @@ async def _process_tool_choice(
                 mode="required",
             )
         # return other modes as is
-        return responses_tool_choice.value
+        return tool_choice.value

-    elif isinstance(responses_tool_choice, OpenAIResponseInputToolChoiceAllowedTools):
+    elif isinstance(tool_choice, OpenAIResponseInputToolChoiceAllowedTools):
         # ensure that specified tool choices are available in the chat tools, if not, remove them from the list
         final_tools = []
-        for tool in responses_tool_choice.tools:
-            tool_name = tool.get("name")
+        for tool in tool_choice.tools:
             match tool.get("type"):
                 case "function":
-                    final_tools.append({"type": "function", "function": {"name": tool_name}})
+                    final_tools.append({"type": "function", "function": {"name": tool.get("name")}})
                 case "custom":
-                    final_tools.append({"type": "custom", "custom": {"name": tool_name}})
+                    final_tools.append({"type": "custom", "custom": {"name": tool.get("name")}})
                 case "mcp":
                     mcp_tools = convert_mcp_tool_choice(
                         chat_tool_names, tool.get("server_label"), server_label_to_tools, None

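The match statement above maps each allowed-tools entry onto the dict shape the chat-completions API expects; a standalone sketch of just the per-entry mapping, covering only the cases visible in this hunk:

    # Sketch of the two cases shown above; the real function also handles
    # file_search, web_search, and mcp entries (see the tests below).
    def map_allowed_tool(tool: dict) -> dict | None:
        match tool.get("type"):
            case "function":
                return {"type": "function", "function": {"name": tool.get("name")}}
            case "custom":
                return {"type": "custom", "custom": {"name": tool.get("name")}}
            case _:
                return None
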
@@ -1390,14 +1390,14 @@ async def _process_tool_choice(

         return OpenAIChatCompletionToolChoiceAllowedTools(
             tools=final_tools,
-            mode=responses_tool_choice.mode,
+            mode=tool_choice.mode,
         )

     else:
         # Handle specific tool choice by type
         # Each case validates the tool exists in chat_tools before returning
-        tool_name = responses_tool_choice.name if responses_tool_choice.name else None
-        match responses_tool_choice:
+        tool_name = getattr(tool_choice, "name", None)
+        match tool_choice:
             case OpenAIResponseInputToolChoiceCustomTool():
                 if tool_name and tool_name not in chat_tool_names:
                     logger.warning(f"Tool {tool_name} not found in chat tools")

@@ -1425,7 +1425,7 @@ async def _process_tool_choice(
             case OpenAIResponseInputToolChoiceMCPTool():
                 tool_choice = convert_mcp_tool_choice(
                     chat_tool_names,
-                    responses_tool_choice.server_label,
+                    tool_choice.server_label,
                     server_label_to_tools,
                     tool_name,
                 )

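The MCP branch, as pinned down by the tests below: a choice naming a specific tool resolves to a single function choice when that tool is present in chat_tools, while a bare server label expands to a required allowed-tools choice covering that server's tools. A usage sketch mirroring one of those tests (async context assumed):

    # Mirrors test_specific_mcp_tool_with_name from the new test suite.
    choice = OpenAIResponseInputToolChoiceMCPTool(server_label="mcp_server_1", name="mcp_tool_1")
    result = await _process_tool_choice(
        [{"type": "function", "function": {"name": "mcp_tool_1"}}],
        choice,
        {"mcp_server_1": ["mcp_tool_1", "mcp_tool_2"]},
    )
    assert result.function.name == "mcp_tool_1"
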
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from dataclasses import dataclass
-from typing import Any, cast
+from typing import cast

 from openai.types.chat import ChatCompletionToolParam
 from pydantic import BaseModel

@@ -161,8 +161,7 @@ class ChatCompletionContext(BaseModel):
     temperature: float | None
     response_format: OpenAIResponseFormatParam
     tool_context: ToolContext | None
-    responses_tool_choice: OpenAIResponseInputToolChoice | None = None
-    chat_tool_choice: str | dict[str, Any] | None = None
+    tool_choice: OpenAIResponseInputToolChoice | None = None
     approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
     approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}

@@ -175,7 +174,7 @@ class ChatCompletionContext(BaseModel):
         response_format: OpenAIResponseFormatParam,
         tool_context: ToolContext,
         inputs: list[OpenAIResponseInput] | str,
-        responses_tool_choice: OpenAIResponseInputToolChoice | None = None,
+        tool_choice: OpenAIResponseInputToolChoice | None = None,
     ):
         super().__init__(
             model=model,

@@ -184,7 +183,7 @@ class ChatCompletionContext(BaseModel):
             temperature=temperature,
             response_format=response_format,
             tool_context=tool_context,
-            responses_tool_choice=responses_tool_choice,
+            tool_choice=tool_choice,
         )
         if not isinstance(inputs, str):
             self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]

@@ -577,6 +577,11 @@ class OpenAIChatCompletionToolChoiceFunctionTool(BaseModel):

 @json_schema_type
 class CustomToolConfig(BaseModel):
+    """Custom tool configuration for OpenAI-compatible chat completion requests.
+
+    :param name: Name of the custom tool
+    """
+
     name: str

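CustomToolConfig is a plain Pydantic model with a single required field, so constructing and serializing one is direct (the tool name below is a placeholder):

    config = CustomToolConfig(name="my_custom_tool")  # placeholder name
    assert config.model_dump() == {"name": "my_custom_tool"}
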
@@ -565,12 +565,19 @@ class OpenAIResponseInputToolChoiceFileSearch(BaseModel):

 @json_schema_type
-class OpenAIResponseInputToolChoiceWebSearch(OpenAIResponseInputToolWebSearch):
-    """Indicates that the model should use web search to generate a response.
+class OpenAIResponseInputToolChoiceWebSearch(BaseModel):
+    """Indicates that the model should use web search to generate a response

-    This is an alias for OpenAIResponseInputToolWebSearch used in tool_choice context.
+    :param type: Web search tool type variant to use
     """

+    type: (
+        Literal["web_search"]
+        | Literal["web_search_preview"]
+        | Literal["web_search_preview_2025_03_11"]
+        | Literal["web_search_2025_08_26"]
+    ) = "web_search"
+

 @json_schema_type
 class OpenAIResponseInputToolChoiceFunctionTool(BaseModel):

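A quick check of the reworked model, per the class definition above: type defaults to "web_search", the preview/dated variants are accepted, and the search_context_size field from the old alias form is gone.

    default_choice = OpenAIResponseInputToolChoiceWebSearch()
    assert default_choice.type == "web_search"

    dated_choice = OpenAIResponseInputToolChoiceWebSearch(type="web_search_2025_08_26")
    assert dated_choice.type == "web_search_2025_08_26"
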
@@ -5,9 +5,22 @@
 # the root directory of this source tree.


+from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
+    _process_tool_choice,
+)
 from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
-from llama_stack_api.openai_responses import (
+from llama_stack_api import (
     MCPListToolsTool,
+    OpenAIChatCompletionToolChoiceAllowedTools,
+    OpenAIChatCompletionToolChoiceCustomTool,
+    OpenAIChatCompletionToolChoiceFunctionTool,
+    OpenAIResponseInputToolChoiceAllowedTools,
+    OpenAIResponseInputToolChoiceCustomTool,
+    OpenAIResponseInputToolChoiceFileSearch,
+    OpenAIResponseInputToolChoiceFunctionTool,
+    OpenAIResponseInputToolChoiceMCPTool,
+    OpenAIResponseInputToolChoiceMode,
+    OpenAIResponseInputToolChoiceWebSearch,
     OpenAIResponseInputToolFileSearch,
     OpenAIResponseInputToolFunction,
     OpenAIResponseInputToolMCP,

@@ -181,3 +194,326 @@ class TestToolContext:
         assert len(context.previous_tool_listings) == 1
         assert len(context.previous_tool_listings[0].tools) == 1
         assert context.previous_tool_listings[0].server_label == "anotherlabel"
+
+
+class TestProcessToolChoice:
+    """Comprehensive test suite for _process_tool_choice function."""
+
+    def setup_method(self):
+        """Set up common test fixtures."""
+        self.chat_tools = [
+            {"type": "function", "function": {"name": "get_weather"}},
+            {"type": "function", "function": {"name": "calculate"}},
+            {"type": "function", "function": {"name": "file_search"}},
+            {"type": "function", "function": {"name": "web_search"}},
+        ]
+        self.server_label_to_tools = {
+            "mcp_server_1": ["mcp_tool_1", "mcp_tool_2"],
+            "mcp_server_2": ["mcp_tool_3"],
+        }
+
+    async def test_mode_auto(self):
+        """Test auto mode - should return 'auto' string."""
+        tool_choice = OpenAIResponseInputToolChoiceMode.auto
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+        assert result == "auto"
+
+    async def test_mode_none(self):
+        """Test none mode - should return 'none' string."""
+        tool_choice = OpenAIResponseInputToolChoiceMode.none
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+        assert result == "none"
+
+    async def test_mode_required_with_tools(self):
+        """Test required mode with available tools - should return AllowedTools with all function tools."""
+        tool_choice = OpenAIResponseInputToolChoiceMode.required
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert result.allowed_tools.mode == "required"
+        assert len(result.allowed_tools.tools) == 4
+        tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools]
+        assert "get_weather" in tool_names
+        assert "calculate" in tool_names
+        assert "file_search" in tool_names
+        assert "web_search" in tool_names
+
+    async def test_mode_required_without_tools(self):
+        """Test required mode without available tools - should return None."""
+        tool_choice = OpenAIResponseInputToolChoiceMode.required
+        result = await _process_tool_choice([], tool_choice, self.server_label_to_tools)
+        assert result is None
+
+    async def test_allowed_tools_function(self):
+        """Test allowed_tools with function tool types."""
+        tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
+            mode="required",
+            tools=[
+                {"type": "function", "name": "get_weather"},
+                {"type": "function", "name": "calculate"},
+            ],
+        )
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert result.allowed_tools.mode == "required"
+        assert len(result.allowed_tools.tools) == 2
+        assert result.allowed_tools.tools[0]["function"]["name"] == "get_weather"
+        assert result.allowed_tools.tools[1]["function"]["name"] == "calculate"
+
+    async def test_allowed_tools_custom(self):
+        """Test allowed_tools with custom tool types."""
+        chat_tools = [{"type": "function", "function": {"name": "custom_tool_1"}}]
+        tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
+            mode="auto",
+            tools=[{"type": "custom", "name": "custom_tool_1"}],
+        )
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert result.allowed_tools.mode == "auto"
+        assert len(result.allowed_tools.tools) == 1
+        assert result.allowed_tools.tools[0]["type"] == "custom"
+        assert result.allowed_tools.tools[0]["custom"]["name"] == "custom_tool_1"
+
+    async def test_allowed_tools_file_search(self):
+        """Test allowed_tools with file_search."""
+        tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
+            mode="required",
+            tools=[{"type": "file_search"}],
+        )
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert len(result.allowed_tools.tools) == 1
+        assert result.allowed_tools.tools[0]["function"]["name"] == "file_search"
+
+    async def test_allowed_tools_web_search(self):
+        """Test allowed_tools with web_search."""
+        tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
+            mode="required",
+            tools=[
+                {"type": "web_search_preview_2025_03_11"},
+                {"type": "web_search_2025_08_26"},
+                {"type": "web_search_preview"},
+            ],
+        )
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert len(result.allowed_tools.tools) == 3
+        assert result.allowed_tools.tools[0]["function"]["name"] == "web_search"
+        assert result.allowed_tools.tools[0]["type"] == "function"
+        assert result.allowed_tools.tools[1]["function"]["name"] == "web_search"
+        assert result.allowed_tools.tools[1]["type"] == "function"
+        assert result.allowed_tools.tools[2]["function"]["name"] == "web_search"
+        assert result.allowed_tools.tools[2]["type"] == "function"
+
+    async def test_allowed_tools_mcp_server_label(self):
+        """Test allowed_tools with MCP server label (no specific tool name)."""
+        chat_tools = [
+            {"type": "function", "function": {"name": "mcp_tool_1"}},
+            {"type": "function", "function": {"name": "mcp_tool_2"}},
+        ]
+        tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
+            mode="required",
+            tools=[{"type": "mcp", "server_label": "mcp_server_1"}],
+        )
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert len(result.allowed_tools.tools) == 2
+        tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools]
+        assert "mcp_tool_1" in tool_names
+        assert "mcp_tool_2" in tool_names
+
+    async def test_allowed_tools_mixed_types(self):
+        """Test allowed_tools with mixed tool types."""
+        chat_tools = [
+            {"type": "function", "function": {"name": "get_weather"}},
+            {"type": "function", "function": {"name": "file_search"}},
+            {"type": "function", "function": {"name": "web_search"}},
+            {"type": "function", "function": {"name": "mcp_tool_1"}},
+        ]
+        tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
+            mode="auto",
+            tools=[
+                {"type": "function", "name": "get_weather"},
+                {"type": "file_search"},
+                {"type": "web_search"},
+                {"type": "mcp", "server_label": "mcp_server_1"},
+            ],
+        )
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        # Should have: get_weather, file_search, web_search, mcp_tool_1, mcp_tool_2
+        assert len(result.allowed_tools.tools) >= 3
+
+    async def test_allowed_tools_invalid_type(self):
+        """Test allowed_tools with an unsupported tool type - should skip it."""
+        tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
+            mode="required",
+            tools=[
+                {"type": "function", "name": "get_weather"},
+                {"type": "unsupported_type", "name": "bad_tool"},
+            ],
+        )
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        # Should only include the valid function tool
+        assert len(result.allowed_tools.tools) == 1
+        assert result.allowed_tools.tools[0]["function"]["name"] == "get_weather"
+
+    async def test_specific_custom_tool_valid(self):
+        """Test specific custom tool choice when tool exists."""
+        chat_tools = [{"type": "function", "function": {"name": "custom_tool"}}]
+        tool_choice = OpenAIResponseInputToolChoiceCustomTool(name="custom_tool")
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceCustomTool)
+        assert result.custom.name == "custom_tool"
+
+    async def test_specific_custom_tool_invalid(self):
+        """Test specific custom tool choice when tool doesn't exist - should return None."""
+        tool_choice = OpenAIResponseInputToolChoiceCustomTool(name="nonexistent_tool")
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+        assert result is None
+
+    async def test_specific_function_tool_valid(self):
+        """Test specific function tool choice when tool exists."""
+        tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="get_weather")
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool)
+        assert result.function.name == "get_weather"
+
+    async def test_specific_function_tool_invalid(self):
+        """Test specific function tool choice when tool doesn't exist - should return None."""
+        tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="nonexistent_function")
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+        assert result is None
+
+    async def test_specific_file_search_valid(self):
+        """Test file_search tool choice when available."""
+        tool_choice = OpenAIResponseInputToolChoiceFileSearch()
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool)
+        assert result.function.name == "file_search"
+
+    async def test_specific_file_search_invalid(self):
+        """Test file_search tool choice when not available - should return None."""
+        chat_tools = [{"type": "function", "function": {"name": "get_weather"}}]
+        tool_choice = OpenAIResponseInputToolChoiceFileSearch()
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+        assert result is None
+
+    async def test_specific_web_search_valid(self):
+        """Test web_search tool choice when available."""
+        tool_choice = OpenAIResponseInputToolChoiceWebSearch()
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool)
+        assert result.function.name == "web_search"
+
+    async def test_specific_web_search_invalid(self):
+        """Test web_search tool choice when not available - should return None."""
+        chat_tools = [{"type": "function", "function": {"name": "get_weather"}}]
+        tool_choice = OpenAIResponseInputToolChoiceWebSearch()
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+        assert result is None
+
+    async def test_specific_mcp_tool_with_name(self):
+        """Test MCP tool choice with specific tool name."""
+        chat_tools = [{"type": "function", "function": {"name": "mcp_tool_1"}}]
+        tool_choice = OpenAIResponseInputToolChoiceMCPTool(
+            server_label="mcp_server_1",
+            name="mcp_tool_1",
+        )
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool)
+        assert result.function.name == "mcp_tool_1"
+
+    async def test_specific_mcp_tool_with_name_not_in_chat_tools(self):
+        """Test MCP tool choice with specific tool name that doesn't exist in chat_tools."""
+        chat_tools = [{"type": "function", "function": {"name": "other_tool"}}]
+        tool_choice = OpenAIResponseInputToolChoiceMCPTool(
+            server_label="mcp_server_1",
+            name="mcp_tool_1",
+        )
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+        assert result is None
+
+    async def test_specific_mcp_tool_server_label_only(self):
+        """Test MCP tool choice with only server label (no specific tool name)."""
+        chat_tools = [
+            {"type": "function", "function": {"name": "mcp_tool_1"}},
+            {"type": "function", "function": {"name": "mcp_tool_2"}},
+        ]
+        tool_choice = OpenAIResponseInputToolChoiceMCPTool(server_label="mcp_server_1")
+        result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert result.allowed_tools.mode == "required"
+        assert len(result.allowed_tools.tools) == 2
+        tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools]
+        assert "mcp_tool_1" in tool_names
+        assert "mcp_tool_2" in tool_names
+
+    async def test_specific_mcp_tool_unknown_server(self):
+        """Test MCP tool choice with unknown server label."""
+        tool_choice = OpenAIResponseInputToolChoiceMCPTool(
+            server_label="unknown_server",
+            name="some_tool",
+        )
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+        # Should return None because server not found
+        assert result is None
+
+    async def test_empty_chat_tools(self):
+        """Test with empty chat_tools list."""
+        tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="get_weather")
+        result = await _process_tool_choice([], tool_choice, self.server_label_to_tools)
+        assert result is None
+
+    async def test_empty_server_label_to_tools(self):
+        """Test with empty server_label_to_tools mapping."""
+        tool_choice = OpenAIResponseInputToolChoiceMCPTool(server_label="mcp_server_1")
+        result = await _process_tool_choice(self.chat_tools, tool_choice, {})
+        # Should handle gracefully
+        assert result is None or isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+
+    async def test_allowed_tools_empty_list(self):
+        """Test allowed_tools with empty tools list."""
+        tool_choice = OpenAIResponseInputToolChoiceAllowedTools(mode="auto", tools=[])
+        result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
+
+        assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert len(result.allowed_tools.tools) == 0
+
+    async def test_mcp_tool_multiple_servers(self):
+        """Test MCP tool choice with multiple server labels."""
+        chat_tools = [
+            {"type": "function", "function": {"name": "mcp_tool_1"}},
+            {"type": "function", "function": {"name": "mcp_tool_2"}},
+            {"type": "function", "function": {"name": "mcp_tool_3"}},
+        ]
+        server_label_to_tools = {
+            "server_a": ["mcp_tool_1"],
+            "server_b": ["mcp_tool_2", "mcp_tool_3"],
+        }
+
+        # Test server_a
+        tool_choice_a = OpenAIResponseInputToolChoiceMCPTool(server_label="server_a")
+        result_a = await _process_tool_choice(chat_tools, tool_choice_a, server_label_to_tools)
+        assert isinstance(result_a, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert len(result_a.allowed_tools.tools) == 1
+
+        # Test server_b
+        tool_choice_b = OpenAIResponseInputToolChoiceMCPTool(server_label="server_b")
+        result_b = await _process_tool_choice(chat_tools, tool_choice_b, server_label_to_tools)
+        assert isinstance(result_b, OpenAIChatCompletionToolChoiceAllowedTools)
+        assert len(result_b.allowed_tools.tools) == 2