llama-stack-mirror/tests/unit/providers/utils/tools/test_mcp.py
2025-09-30 09:22:58 -07:00

263 lines
9.6 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, Mock, patch
from llama_stack.apis.tools import ListToolDefsResponse
from llama_stack.providers.utils.tools.mcp import (
MCPProtol,
list_mcp_tools,
resolve_json_schema_refs,
)
class TestResolveJsonSchemaRefs:
    """Test cases for resolve_json_schema_refs function.

    Each test builds a small JSON Schema dict, passes it to
    resolve_json_schema_refs, and compares the result with the expected
    schema in which local ``#/$defs/...`` references are inlined and the
    ``$defs`` section is gone.
    """

    def test_resolve_simple_ref(self):
        """Test resolving a simple $ref reference."""
        schema = {
            "type": "object",
            "properties": {"user": {"$ref": "#/$defs/User"}},
            "$defs": {
                "User": {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}}
            },
        }

        result = resolve_json_schema_refs(schema)

        # The reference is replaced by the definition and "$defs" is dropped.
        expected = {
            "type": "object",
            "properties": {
                "user": {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}}
            },
        }
        assert result == expected

    def test_resolve_nested_refs(self):
        """Test resolving nested $ref references."""
        # Container refers to User, so resolution has to recurse into definitions.
        schema = {
            "type": "object",
            "properties": {"data": {"$ref": "#/$defs/Container"}},
            "$defs": {
                "Container": {"type": "object", "properties": {"user": {"$ref": "#/$defs/User"}}},
                "User": {"type": "object", "properties": {"name": {"type": "string"}}},
            },
        }

        result = resolve_json_schema_refs(schema)

        expected = {
            "type": "object",
            "properties": {
                "data": {
                    "type": "object",
                    "properties": {"user": {"type": "object", "properties": {"name": {"type": "string"}}}},
                }
            },
        }
        assert result == expected

    def test_resolve_refs_in_array(self):
        """Test resolving $ref references within arrays."""
        schema = {
            "type": "object",
            "properties": {"items": {"type": "array", "items": {"$ref": "#/$defs/Item"}}},
            "$defs": {"Item": {"type": "object", "properties": {"id": {"type": "string"}}}},
        }

        result = resolve_json_schema_refs(schema)

        expected = {
            "type": "object",
            "properties": {
                "items": {"type": "array", "items": {"type": "object", "properties": {"id": {"type": "string"}}}}
            },
        }
        assert result == expected

    def test_resolve_missing_ref(self):
        """Test handling of missing $ref definition."""
        schema = {
            "type": "object",
            "properties": {"user": {"$ref": "#/$defs/MissingUser"}},
            "$defs": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}},
        }

        # The module logger is patched so the warning can be asserted on.
        with patch("llama_stack.providers.utils.tools.mcp.logger") as mock_logger:
            result = resolve_json_schema_refs(schema)
            mock_logger.warning.assert_called_once_with("Referenced definition 'MissingUser' not found in $defs")

        # Should return the original $ref unchanged
        expected = {"type": "object", "properties": {"user": {"$ref": "#/$defs/MissingUser"}}}
        assert result == expected

    def test_resolve_unsupported_ref_format(self):
        """Test handling of unsupported $ref format."""
        schema = {"type": "object", "properties": {"user": {"$ref": "http://example.com/schema"}}, "$defs": {}}

        with patch("llama_stack.providers.utils.tools.mcp.logger") as mock_logger:
            result = resolve_json_schema_refs(schema)
            mock_logger.warning.assert_called_once_with("Unsupported $ref format: http://example.com/schema")

        # Should return the original $ref unchanged
        expected = {"type": "object", "properties": {"user": {"$ref": "http://example.com/schema"}}}
        assert result == expected

    def test_resolve_no_defs(self):
        """Test schema without $defs section."""
        schema = {"type": "object", "properties": {"name": {"type": "string"}}}

        result = resolve_json_schema_refs(schema)

        assert result == schema

    def test_resolve_non_dict_input(self):
        """Test with non-dictionary input."""
        # Non-dict values are expected to come back as-is.
        assert resolve_json_schema_refs("string") == "string"
        assert resolve_json_schema_refs(123) == 123
        assert resolve_json_schema_refs(["list"]) == ["list"]
        assert resolve_json_schema_refs(None) is None

    def test_resolve_preserves_original(self):
        """Test that original schema is not modified."""
        import copy  # local import: only this test needs it

        original_schema = {
            "type": "object",
            "properties": {"user": {"$ref": "#/$defs/User"}},
            "$defs": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}},
        }
        # Deep snapshot taken before the call so that a mutation at any
        # nesting level (not just a removed top-level key) is detected.
        snapshot = copy.deepcopy(original_schema)

        resolve_json_schema_refs(original_schema)

        # The reference and the definitions section must still be present ...
        assert "$ref" in original_schema["properties"]["user"]
        assert "$defs" in original_schema
        # ... and the whole structure must equal its pre-call state.
        assert original_schema == snapshot
class TestMCPProtocol:
    """Checks on the integer values carried by the MCPProtol enum members."""

    def test_protocol_values(self):
        """Verify that every protocol member has its expected integer value."""
        expected_values = (
            (MCPProtol.UNKNOWN, 0),
            (MCPProtol.STREAMABLE_HTTP, 1),
            (MCPProtol.SSE, 2),
        )
        for member, number in expected_values:
            assert member.value == number
class TestListMcpTools:
    """Exercise list_mcp_tools with the MCP client wrapper replaced by mocks."""

    @staticmethod
    def _fake_tool(name, description, input_schema):
        """Build a Mock exposing name, description and inputSchema attributes."""
        tool = Mock()
        # Set after construction: Mock(name=...) would name the mock itself
        # rather than create a ``name`` attribute.
        tool.name = name
        tool.description = description
        tool.inputSchema = input_schema
        return tool

    @staticmethod
    async def _list_tools_from(tools, endpoint, headers):
        """Call list_mcp_tools while the patched client_wrapper yields a session listing *tools*."""
        listing = Mock()
        listing.tools = tools
        session = Mock()
        session.list_tools = AsyncMock(return_value=listing)

        with patch("llama_stack.providers.utils.tools.mcp.client_wrapper") as wrapper:
            # client_wrapper(...) is used as an async context manager, so the
            # object it returns needs awaitable __aenter__/__aexit__ hooks.
            client_cm = wrapper.return_value
            client_cm.__aenter__ = AsyncMock(return_value=session)
            client_cm.__aexit__ = AsyncMock()
            return await list_mcp_tools(endpoint, headers)

    async def test_list_tools_success(self):
        """A single tool with two properties is converted into one tool definition."""
        url = "http://example.com/mcp"
        auth_headers = {"Authorization": "Bearer token"}
        input_schema = {
            "type": "object",
            "properties": {
                "param1": {"type": "string", "description": "First parameter", "default": "default_value"},
                "param2": {"type": "integer", "description": "Second parameter"},
            },
        }
        tool = self._fake_tool("test_tool", "A test tool", input_schema)

        response = await self._list_tools_from([tool], url, auth_headers)

        assert isinstance(response, ListToolDefsResponse)
        assert len(response.data) == 1

        listed = response.data[0]
        assert listed.name == "test_tool"
        assert listed.description == "A test tool"
        assert listed.metadata["endpoint"] == url

        # Both schema properties must surface as parameters.
        assert len(listed.parameters) == 2

        first = next(param for param in listed.parameters if param.name == "param1")
        assert first.parameter_type == "string"
        assert first.description == "First parameter"
        assert first.required is False  # a default is supplied
        assert first.default == "default_value"

        second = next(param for param in listed.parameters if param.name == "param2")
        assert second.parameter_type == "integer"
        assert second.description == "Second parameter"
        assert second.required is True  # no default is supplied

    async def test_list_tools_with_schema_refs(self):
        """A tool whose input schema uses a $ref still yields one parameter."""
        input_schema = {
            "type": "object",
            "properties": {"user": {"$ref": "#/$defs/User"}},
            "$defs": {
                "User": {"type": "object", "properties": {"name": {"type": "string", "description": "User name"}}}
            },
        }
        tool = self._fake_tool("ref_tool", "Tool with refs", input_schema)

        response = await self._list_tools_from([tool], "http://example.com/mcp", {})

        # Only the parameter count is checked; how the resolved nested object
        # is represented depends on the schema resolution and is not asserted.
        listed = response.data[0]
        assert len(listed.parameters) == 1

    async def test_list_tools_empty_result(self):
        """An MCP session reporting no tools produces an empty response."""
        response = await self._list_tools_from([], "http://example.com/mcp", {})

        assert isinstance(response, ListToolDefsResponse)
        assert len(response.data) == 0