mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
add unit tests
Signed-off-by: Jaideep Rao <jrao@redhat.com>
This commit is contained in:
parent
d2d2c88921
commit
36d7abd4d5
11 changed files with 392 additions and 85 deletions
|
|
@ -5,9 +5,22 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
|
||||
_process_tool_choice,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
|
||||
from llama_stack_api.openai_responses import (
|
||||
from llama_stack_api import (
|
||||
MCPListToolsTool,
|
||||
OpenAIChatCompletionToolChoiceAllowedTools,
|
||||
OpenAIChatCompletionToolChoiceCustomTool,
|
||||
OpenAIChatCompletionToolChoiceFunctionTool,
|
||||
OpenAIResponseInputToolChoiceAllowedTools,
|
||||
OpenAIResponseInputToolChoiceCustomTool,
|
||||
OpenAIResponseInputToolChoiceFileSearch,
|
||||
OpenAIResponseInputToolChoiceFunctionTool,
|
||||
OpenAIResponseInputToolChoiceMCPTool,
|
||||
OpenAIResponseInputToolChoiceMode,
|
||||
OpenAIResponseInputToolChoiceWebSearch,
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
|
|
@ -181,3 +194,326 @@ class TestToolContext:
|
|||
assert len(context.previous_tool_listings) == 1
|
||||
assert len(context.previous_tool_listings[0].tools) == 1
|
||||
assert context.previous_tool_listings[0].server_label == "anotherlabel"
|
||||
|
||||
|
||||
class TestProcessToolChoice:
|
||||
"""Comprehensive test suite for _process_tool_choice function."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up common test fixtures."""
|
||||
self.chat_tools = [
|
||||
{"type": "function", "function": {"name": "get_weather"}},
|
||||
{"type": "function", "function": {"name": "calculate"}},
|
||||
{"type": "function", "function": {"name": "file_search"}},
|
||||
{"type": "function", "function": {"name": "web_search"}},
|
||||
]
|
||||
self.server_label_to_tools = {
|
||||
"mcp_server_1": ["mcp_tool_1", "mcp_tool_2"],
|
||||
"mcp_server_2": ["mcp_tool_3"],
|
||||
}
|
||||
|
||||
async def test_mode_auto(self):
|
||||
"""Test auto mode - should return 'auto' string."""
|
||||
tool_choice = OpenAIResponseInputToolChoiceMode.auto
|
||||
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||
assert result == "auto"
|
||||
|
||||
async def test_mode_none(self):
|
||||
"""Test none mode - should return 'none' string."""
|
||||
tool_choice = OpenAIResponseInputToolChoiceMode.none
|
||||
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||
assert result == "none"
|
||||
|
||||
async def test_mode_required_with_tools(self):
|
||||
"""Test required mode with available tools - should return AllowedTools with all function tools."""
|
||||
tool_choice = OpenAIResponseInputToolChoiceMode.required
|
||||
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||
|
||||
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||
assert result.allowed_tools.mode == "required"
|
||||
assert len(result.allowed_tools.tools) == 4
|
||||
tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools]
|
||||
assert "get_weather" in tool_names
|
||||
assert "calculate" in tool_names
|
||||
assert "file_search" in tool_names
|
||||
assert "web_search" in tool_names
|
||||
|
||||
async def test_mode_required_without_tools(self):
|
||||
"""Test required mode without available tools - should return None."""
|
||||
tool_choice = OpenAIResponseInputToolChoiceMode.required
|
||||
result = await _process_tool_choice([], tool_choice, self.server_label_to_tools)
|
||||
assert result is None
|
||||
|
||||
async def test_allowed_tools_function(self):
|
||||
"""Test allowed_tools with function tool types."""
|
||||
tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
|
||||
mode="required",
|
||||
tools=[
|
||||
{"type": "function", "name": "get_weather"},
|
||||
{"type": "function", "name": "calculate"},
|
||||
],
|
||||
)
|
||||
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||
|
||||
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||
assert result.allowed_tools.mode == "required"
|
||||
assert len(result.allowed_tools.tools) == 2
|
||||
assert result.allowed_tools.tools[0]["function"]["name"] == "get_weather"
|
||||
assert result.allowed_tools.tools[1]["function"]["name"] == "calculate"
|
||||
|
||||
async def test_allowed_tools_custom(self):
|
||||
"""Test allowed_tools with custom tool types."""
|
||||
chat_tools = [{"type": "function", "function": {"name": "custom_tool_1"}}]
|
||||
tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
|
||||
mode="auto",
|
||||
tools=[{"type": "custom", "name": "custom_tool_1"}],
|
||||
)
|
||||
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||
|
||||
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||
assert result.allowed_tools.mode == "auto"
|
||||
assert len(result.allowed_tools.tools) == 1
|
||||
assert result.allowed_tools.tools[0]["type"] == "custom"
|
||||
assert result.allowed_tools.tools[0]["custom"]["name"] == "custom_tool_1"
|
||||
|
||||
async def test_allowed_tools_file_search(self):
|
||||
"""Test allowed_tools with file_search."""
|
||||
tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
|
||||
mode="required",
|
||||
tools=[{"type": "file_search"}],
|
||||
)
|
||||
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||
|
||||
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||
assert len(result.allowed_tools.tools) == 1
|
||||
assert result.allowed_tools.tools[0]["function"]["name"] == "file_search"
|
||||
|
||||
async def test_allowed_tools_web_search(self):
|
||||
"""Test allowed_tools with web_search."""
|
||||
tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
|
||||
mode="required",
|
||||
tools=[
|
||||
{"type": "web_search_preview_2025_03_11"},
|
||||
{"type": "web_search_2025_08_26"},
|
||||
{"type": "web_search_preview"},
|
||||
],
|
||||
)
|
||||
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||
|
||||
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||
assert len(result.allowed_tools.tools) == 3
|
||||
assert result.allowed_tools.tools[0]["function"]["name"] == "web_search"
|
||||
assert result.allowed_tools.tools[0]["type"] == "function"
|
||||
assert result.allowed_tools.tools[1]["function"]["name"] == "web_search"
|
||||
assert result.allowed_tools.tools[1]["type"] == "function"
|
||||
assert result.allowed_tools.tools[2]["function"]["name"] == "web_search"
|
||||
assert result.allowed_tools.tools[2]["type"] == "function"
|
||||
|
||||
async def test_allowed_tools_mcp_server_label(self):
|
||||
"""Test allowed_tools with MCP server label (no specific tool name)."""
|
||||
chat_tools = [
|
||||
{"type": "function", "function": {"name": "mcp_tool_1"}},
|
||||
{"type": "function", "function": {"name": "mcp_tool_2"}},
|
||||
]
|
||||
tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
|
||||
mode="required",
|
||||
tools=[{"type": "mcp", "server_label": "mcp_server_1"}],
|
||||
)
|
||||
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||
|
||||
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||
assert len(result.allowed_tools.tools) == 2
|
||||
tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools]
|
||||
assert "mcp_tool_1" in tool_names
|
||||
assert "mcp_tool_2" in tool_names
|
||||
|
||||
async def test_allowed_tools_mixed_types(self):
|
||||
"""Test allowed_tools with mixed tool types."""
|
||||
chat_tools = [
|
||||
{"type": "function", "function": {"name": "get_weather"}},
|
||||
{"type": "function", "function": {"name": "file_search"}},
|
||||
{"type": "function", "function": {"name": "web_search"}},
|
||||
{"type": "function", "function": {"name": "mcp_tool_1"}},
|
||||
]
|
||||
tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
|
||||
mode="auto",
|
||||
tools=[
|
||||
{"type": "function", "name": "get_weather"},
|
||||
{"type": "file_search"},
|
||||
{"type": "web_search"},
|
||||
{"type": "mcp", "server_label": "mcp_server_1"},
|
||||
],
|
||||
)
|
||||
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||
|
||||
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||
# Should have: get_weather, file_search, web_search, mcp_tool_1, mcp_tool_2
|
||||
assert len(result.allowed_tools.tools) >= 3
|
||||
|
||||
async def test_allowed_tools_invalid_type(self):
|
||||
"""Test allowed_tools with an unsupported tool type - should skip it."""
|
||||
tool_choice = OpenAIResponseInputToolChoiceAllowedTools(
|
||||
mode="required",
|
||||
tools=[
|
||||
{"type": "function", "name": "get_weather"},
|
||||
{"type": "unsupported_type", "name": "bad_tool"},
|
||||
],
|
||||
)
|
||||
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||
|
||||
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||
# Should only include the valid function tool
|
||||
assert len(result.allowed_tools.tools) == 1
|
||||
assert result.allowed_tools.tools[0]["function"]["name"] == "get_weather"
|
||||
|
||||
async def test_specific_custom_tool_valid(self):
|
||||
"""Test specific custom tool choice when tool exists."""
|
||||
chat_tools = [{"type": "function", "function": {"name": "custom_tool"}}]
|
||||
tool_choice = OpenAIResponseInputToolChoiceCustomTool(name="custom_tool")
|
||||
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||
|
||||
assert isinstance(result, OpenAIChatCompletionToolChoiceCustomTool)
|
||||
assert result.custom.name == "custom_tool"
|
||||
|
||||
async def test_specific_custom_tool_invalid(self):
|
||||
"""Test specific custom tool choice when tool doesn't exist - should return None."""
|
||||
tool_choice = OpenAIResponseInputToolChoiceCustomTool(name="nonexistent_tool")
|
||||
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||
assert result is None
|
||||
|
||||
async def test_specific_function_tool_valid(self):
|
||||
"""Test specific function tool choice when tool exists."""
|
||||
tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="get_weather")
|
||||
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||
|
||||
assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool)
|
||||
assert result.function.name == "get_weather"
|
||||
|
||||
async def test_specific_function_tool_invalid(self):
|
||||
"""Test specific function tool choice when tool doesn't exist - should return None."""
|
||||
tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="nonexistent_function")
|
||||
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||
assert result is None
|
||||
|
||||
async def test_specific_file_search_valid(self):
|
||||
"""Test file_search tool choice when available."""
|
||||
tool_choice = OpenAIResponseInputToolChoiceFileSearch()
|
||||
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||
|
||||
assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool)
|
||||
assert result.function.name == "file_search"
|
||||
|
||||
async def test_specific_file_search_invalid(self):
|
||||
"""Test file_search tool choice when not available - should return None."""
|
||||
chat_tools = [{"type": "function", "function": {"name": "get_weather"}}]
|
||||
tool_choice = OpenAIResponseInputToolChoiceFileSearch()
|
||||
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||
assert result is None
|
||||
|
||||
async def test_specific_web_search_valid(self):
|
||||
"""Test web_search tool choice when available."""
|
||||
tool_choice = OpenAIResponseInputToolChoiceWebSearch()
|
||||
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||
|
||||
assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool)
|
||||
assert result.function.name == "web_search"
|
||||
|
||||
async def test_specific_web_search_invalid(self):
|
||||
"""Test web_search tool choice when not available - should return None."""
|
||||
chat_tools = [{"type": "function", "function": {"name": "get_weather"}}]
|
||||
tool_choice = OpenAIResponseInputToolChoiceWebSearch()
|
||||
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||
assert result is None
|
||||
|
||||
async def test_specific_mcp_tool_with_name(self):
|
||||
"""Test MCP tool choice with specific tool name."""
|
||||
chat_tools = [{"type": "function", "function": {"name": "mcp_tool_1"}}]
|
||||
tool_choice = OpenAIResponseInputToolChoiceMCPTool(
|
||||
server_label="mcp_server_1",
|
||||
name="mcp_tool_1",
|
||||
)
|
||||
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||
|
||||
assert isinstance(result, OpenAIChatCompletionToolChoiceFunctionTool)
|
||||
assert result.function.name == "mcp_tool_1"
|
||||
|
||||
async def test_specific_mcp_tool_with_name_not_in_chat_tools(self):
|
||||
"""Test MCP tool choice with specific tool name that doesn't exist in chat_tools."""
|
||||
chat_tools = [{"type": "function", "function": {"name": "other_tool"}}]
|
||||
tool_choice = OpenAIResponseInputToolChoiceMCPTool(
|
||||
server_label="mcp_server_1",
|
||||
name="mcp_tool_1",
|
||||
)
|
||||
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||
assert result is None
|
||||
|
||||
async def test_specific_mcp_tool_server_label_only(self):
|
||||
"""Test MCP tool choice with only server label (no specific tool name)."""
|
||||
chat_tools = [
|
||||
{"type": "function", "function": {"name": "mcp_tool_1"}},
|
||||
{"type": "function", "function": {"name": "mcp_tool_2"}},
|
||||
]
|
||||
tool_choice = OpenAIResponseInputToolChoiceMCPTool(server_label="mcp_server_1")
|
||||
result = await _process_tool_choice(chat_tools, tool_choice, self.server_label_to_tools)
|
||||
|
||||
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||
assert result.allowed_tools.mode == "required"
|
||||
assert len(result.allowed_tools.tools) == 2
|
||||
tool_names = [tool["function"]["name"] for tool in result.allowed_tools.tools]
|
||||
assert "mcp_tool_1" in tool_names
|
||||
assert "mcp_tool_2" in tool_names
|
||||
|
||||
async def test_specific_mcp_tool_unknown_server(self):
|
||||
"""Test MCP tool choice with unknown server label."""
|
||||
tool_choice = OpenAIResponseInputToolChoiceMCPTool(
|
||||
server_label="unknown_server",
|
||||
name="some_tool",
|
||||
)
|
||||
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||
# Should return None because server not found
|
||||
assert result is None
|
||||
|
||||
async def test_empty_chat_tools(self):
|
||||
"""Test with empty chat_tools list."""
|
||||
tool_choice = OpenAIResponseInputToolChoiceFunctionTool(name="get_weather")
|
||||
result = await _process_tool_choice([], tool_choice, self.server_label_to_tools)
|
||||
assert result is None
|
||||
|
||||
async def test_empty_server_label_to_tools(self):
|
||||
"""Test with empty server_label_to_tools mapping."""
|
||||
tool_choice = OpenAIResponseInputToolChoiceMCPTool(server_label="mcp_server_1")
|
||||
result = await _process_tool_choice(self.chat_tools, tool_choice, {})
|
||||
# Should handle gracefully
|
||||
assert result is None or isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||
|
||||
async def test_allowed_tools_empty_list(self):
|
||||
"""Test allowed_tools with empty tools list."""
|
||||
tool_choice = OpenAIResponseInputToolChoiceAllowedTools(mode="auto", tools=[])
|
||||
result = await _process_tool_choice(self.chat_tools, tool_choice, self.server_label_to_tools)
|
||||
|
||||
assert isinstance(result, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||
assert len(result.allowed_tools.tools) == 0
|
||||
|
||||
async def test_mcp_tool_multiple_servers(self):
|
||||
"""Test MCP tool choice with multiple server labels."""
|
||||
chat_tools = [
|
||||
{"type": "function", "function": {"name": "mcp_tool_1"}},
|
||||
{"type": "function", "function": {"name": "mcp_tool_2"}},
|
||||
{"type": "function", "function": {"name": "mcp_tool_3"}},
|
||||
]
|
||||
server_label_to_tools = {
|
||||
"server_a": ["mcp_tool_1"],
|
||||
"server_b": ["mcp_tool_2", "mcp_tool_3"],
|
||||
}
|
||||
|
||||
# Test server_a
|
||||
tool_choice_a = OpenAIResponseInputToolChoiceMCPTool(server_label="server_a")
|
||||
result_a = await _process_tool_choice(chat_tools, tool_choice_a, server_label_to_tools)
|
||||
assert isinstance(result_a, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||
assert len(result_a.allowed_tools.tools) == 1
|
||||
|
||||
# Test server_b
|
||||
tool_choice_b = OpenAIResponseInputToolChoiceMCPTool(server_label="server_b")
|
||||
result_b = await _process_tool_choice(chat_tools, tool_choice_b, server_label_to_tools)
|
||||
assert isinstance(result_b, OpenAIChatCompletionToolChoiceAllowedTools)
|
||||
assert len(result_b.allowed_tools.tools) == 2
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue