From 774e4311ad6f324915a01100f5d62bc662e5c389 Mon Sep 17 00:00:00 2001 From: Kai Wu Date: Tue, 30 Sep 2025 09:22:58 -0700 Subject: [PATCH] precommit fix --- llama_stack/providers/utils/tools/mcp.py | 21 +- tests/unit/providers/utils/tools/test_mcp.py | 277 ++++++------------- 2 files changed, 88 insertions(+), 210 deletions(-) diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py index d182f25ae..8c814c51a 100644 --- a/llama_stack/providers/utils/tools/mcp.py +++ b/llama_stack/providers/utils/tools/mcp.py @@ -4,11 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import copy from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from enum import Enum from typing import Any, cast -import copy import httpx from mcp import ClientSession, McpError @@ -32,20 +32,19 @@ logger = get_logger(__name__, category="tools") protocol_cache = TTLDict(ttl_seconds=3600) -def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]: +def resolve_json_schema_refs(schema: Any) -> Any: """ Resolve JSON Schema $ref references using $defs. - This function takes a JSON schema that may contain $ref and $defs, and returns a new schema with all $ref references resolved inline. """ if not isinstance(schema, dict): return schema - + # Make a deep copy to avoid modifying the original resolved_schema = copy.deepcopy(schema) defs = resolved_schema.get("$defs", {}) - + def resolve_refs(obj: Any) -> Any: """Recursively resolve $ref references in the schema.""" if isinstance(obj, dict): @@ -53,7 +52,7 @@ def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]: # Extract the reference path ref_path = obj["$ref"] if ref_path.startswith("#/$defs/"): - def_name = ref_path[len("#/$defs/"):] + def_name = ref_path[len("#/$defs/") :] if def_name in defs: # Recursively resolve refs in the definition itself resolved_def = resolve_refs(defs[def_name]) @@ -73,13 +72,13 @@ def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]: else: # Return primitive values as-is return obj - + # Resolve all refs in the schema resolved_schema = resolve_refs(resolved_schema) - + # Remove the $defs section as it's no longer needed resolved_schema.pop("$defs", None) - + return resolved_schema @@ -166,10 +165,10 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs tools_result = await session.list_tools() for tool in tools_result.tools: parameters = [] - + # Resolve $ref and $defs in the input schema resolved_input_schema = resolve_json_schema_refs(tool.inputSchema) - + for param_name, param_schema in resolved_input_schema.get("properties", {}).items(): parameters.append( ToolParameter( diff --git a/tests/unit/providers/utils/tools/test_mcp.py b/tests/unit/providers/utils/tools/test_mcp.py index 80d574f2d..16f97fd54 100644 --- a/tests/unit/providers/utils/tools/test_mcp.py +++ b/tests/unit/providers/utils/tools/test_mcp.py @@ -4,25 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import pytest -from unittest.mock import AsyncMock, Mock, patch, MagicMock -from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator -import httpx -from mcp import McpError -from mcp import types as mcp_types +from unittest.mock import AsyncMock, Mock, patch +from llama_stack.apis.tools import ListToolDefsResponse from llama_stack.providers.utils.tools.mcp import ( - resolve_json_schema_refs, MCPProtol, - client_wrapper, list_mcp_tools, - invoke_mcp_tool, - protocol_cache, + resolve_json_schema_refs, ) -from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter, ToolInvocationResult -from llama_stack.apis.common.content_types import TextContentItem, ImageContentItem -from llama_stack.core.datatypes import AuthenticationRequiredError class TestResolveJsonSchemaRefs: @@ -32,186 +21,103 @@ class TestResolveJsonSchemaRefs: """Test resolving a simple $ref reference.""" schema = { "type": "object", - "properties": { - "user": {"$ref": "#/$defs/User"} - }, + "properties": {"user": {"$ref": "#/$defs/User"}}, "$defs": { - "User": { - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "integer"} - } - } - } + "User": {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}} + }, } - + result = resolve_json_schema_refs(schema) - + expected = { "type": "object", "properties": { - "user": { - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "integer"} - } - } - } + "user": {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}} + }, } - + assert result == expected def test_resolve_nested_refs(self): """Test resolving nested $ref references.""" schema = { "type": "object", - "properties": { - "data": {"$ref": "#/$defs/Container"} - }, + "properties": {"data": {"$ref": "#/$defs/Container"}}, "$defs": { - "Container": { - "type": "object", - "properties": { - "user": {"$ref": "#/$defs/User"} - } - }, - "User": { - "type": "object", - "properties": { - "name": {"type": "string"} - } - } - } + "Container": {"type": "object", "properties": {"user": {"$ref": "#/$defs/User"}}}, + "User": {"type": "object", "properties": {"name": {"type": "string"}}}, + }, } - + result = resolve_json_schema_refs(schema) - + expected = { "type": "object", "properties": { "data": { "type": "object", - "properties": { - "user": { - "type": "object", - "properties": { - "name": {"type": "string"} - } - } - } + "properties": {"user": {"type": "object", "properties": {"name": {"type": "string"}}}}, } - } + }, } - + assert result == expected def test_resolve_refs_in_array(self): """Test resolving $ref references within arrays.""" schema = { "type": "object", - "properties": { - "items": { - "type": "array", - "items": {"$ref": "#/$defs/Item"} - } - }, - "$defs": { - "Item": { - "type": "object", - "properties": { - "id": {"type": "string"} - } - } - } + "properties": {"items": {"type": "array", "items": {"$ref": "#/$defs/Item"}}}, + "$defs": {"Item": {"type": "object", "properties": {"id": {"type": "string"}}}}, } - + result = resolve_json_schema_refs(schema) - + expected = { "type": "object", "properties": { - "items": { - "type": "array", - "items": { - "type": "object", - "properties": { - "id": {"type": "string"} - } - } - } - } + "items": {"type": "array", "items": {"type": "object", "properties": {"id": {"type": "string"}}}} + }, } - + assert result == expected def test_resolve_missing_ref(self): """Test handling of missing $ref definition.""" schema = { "type": "object", - "properties": { - "user": {"$ref": "#/$defs/MissingUser"} - }, - "$defs": { - "User": { - "type": "object", - "properties": { - "name": {"type": "string"} - } - } - } + "properties": {"user": {"$ref": "#/$defs/MissingUser"}}, + "$defs": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}}, } - - with patch('llama_stack.providers.utils.tools.mcp.logger') as mock_logger: + + with patch("llama_stack.providers.utils.tools.mcp.logger") as mock_logger: result = resolve_json_schema_refs(schema) mock_logger.warning.assert_called_once_with("Referenced definition 'MissingUser' not found in $defs") - + # Should return the original $ref unchanged - expected = { - "type": "object", - "properties": { - "user": {"$ref": "#/$defs/MissingUser"} - } - } - + expected = {"type": "object", "properties": {"user": {"$ref": "#/$defs/MissingUser"}}} + assert result == expected def test_resolve_unsupported_ref_format(self): """Test handling of unsupported $ref format.""" - schema = { - "type": "object", - "properties": { - "user": {"$ref": "http://example.com/schema"} - }, - "$defs": {} - } - - with patch('llama_stack.providers.utils.tools.mcp.logger') as mock_logger: + schema = {"type": "object", "properties": {"user": {"$ref": "http://example.com/schema"}}, "$defs": {}} + + with patch("llama_stack.providers.utils.tools.mcp.logger") as mock_logger: result = resolve_json_schema_refs(schema) mock_logger.warning.assert_called_once_with("Unsupported $ref format: http://example.com/schema") - + # Should return the original $ref unchanged - expected = { - "type": "object", - "properties": { - "user": {"$ref": "http://example.com/schema"} - } - } - + expected = {"type": "object", "properties": {"user": {"$ref": "http://example.com/schema"}}} + assert result == expected def test_resolve_no_defs(self): """Test schema without $defs section.""" - schema = { - "type": "object", - "properties": { - "name": {"type": "string"} - } - } - + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + result = resolve_json_schema_refs(schema) - + assert result == schema def test_resolve_non_dict_input(self): @@ -225,22 +131,12 @@ class TestResolveJsonSchemaRefs: """Test that original schema is not modified.""" original_schema = { "type": "object", - "properties": { - "user": {"$ref": "#/$defs/User"} - }, - "$defs": { - "User": { - "type": "object", - "properties": { - "name": {"type": "string"} - } - } - } + "properties": {"user": {"$ref": "#/$defs/User"}}, + "$defs": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}}, } - - original_copy = original_schema.copy() + resolve_json_schema_refs(original_schema) - + # Original should be unchanged (but this is a shallow comparison) assert "$ref" in original_schema["properties"]["user"] assert "$defs" in original_schema @@ -259,12 +155,11 @@ class TestMCPProtocol: class TestListMcpTools: """Test cases for list_mcp_tools function.""" - @pytest.mark.asyncio async def test_list_tools_success(self): """Test successful listing of MCP tools.""" endpoint = "http://example.com/mcp" headers = {"Authorization": "Bearer token"} - + # Mock tool from MCP mock_tool = Mock() mock_tool.name = "test_tool" @@ -272,113 +167,97 @@ class TestListMcpTools: mock_tool.inputSchema = { "type": "object", "properties": { - "param1": { - "type": "string", - "description": "First parameter", - "default": "default_value" - }, - "param2": { - "type": "integer", - "description": "Second parameter" - } - } + "param1": {"type": "string", "description": "First parameter", "default": "default_value"}, + "param2": {"type": "integer", "description": "Second parameter"}, + }, } - + mock_tools_result = Mock() mock_tools_result.tools = [mock_tool] - + mock_session = Mock() mock_session.list_tools = AsyncMock(return_value=mock_tools_result) - - with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper: + + with patch("llama_stack.providers.utils.tools.mcp.client_wrapper") as mock_wrapper: mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session) mock_wrapper.return_value.__aexit__ = AsyncMock() - + result = await list_mcp_tools(endpoint, headers) - + assert isinstance(result, ListToolDefsResponse) assert len(result.data) == 1 - + tool_def = result.data[0] assert tool_def.name == "test_tool" assert tool_def.description == "A test tool" assert tool_def.metadata["endpoint"] == endpoint - + # Check parameters assert len(tool_def.parameters) == 2 - + param1 = next(p for p in tool_def.parameters if p.name == "param1") assert param1.parameter_type == "string" assert param1.description == "First parameter" assert param1.required is False # Has default value assert param1.default == "default_value" - + param2 = next(p for p in tool_def.parameters if p.name == "param2") assert param2.parameter_type == "integer" assert param2.description == "Second parameter" assert param2.required is True # No default value - @pytest.mark.asyncio async def test_list_tools_with_schema_refs(self): """Test listing tools with JSON Schema $refs.""" endpoint = "http://example.com/mcp" headers = {} - + # Mock tool with $ref in schema mock_tool = Mock() mock_tool.name = "ref_tool" mock_tool.description = "Tool with refs" mock_tool.inputSchema = { "type": "object", - "properties": { - "user": {"$ref": "#/$defs/User"} - }, + "properties": {"user": {"$ref": "#/$defs/User"}}, "$defs": { - "User": { - "type": "object", - "properties": { - "name": {"type": "string", "description": "User name"} - } - } - } + "User": {"type": "object", "properties": {"name": {"type": "string", "description": "User name"}}} + }, } - + mock_tools_result = Mock() mock_tools_result.tools = [mock_tool] - + mock_session = Mock() mock_session.list_tools = AsyncMock(return_value=mock_tools_result) - - with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper: + + with patch("llama_stack.providers.utils.tools.mcp.client_wrapper") as mock_wrapper: mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session) mock_wrapper.return_value.__aexit__ = AsyncMock() - + result = await list_mcp_tools(endpoint, headers) - + # Should have resolved the $ref tool_def = result.data[0] assert len(tool_def.parameters) == 1 - + # The user parameter should be flattened from the resolved $ref # Note: This depends on how the schema resolution works with nested objects - @pytest.mark.asyncio async def test_list_tools_empty_result(self): """Test listing tools when no tools are available.""" endpoint = "http://example.com/mcp" headers = {} - + mock_tools_result = Mock() mock_tools_result.tools = [] - + mock_session = Mock() mock_session.list_tools = AsyncMock(return_value=mock_tools_result) - - with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper: + + with patch("llama_stack.providers.utils.tools.mcp.client_wrapper") as mock_wrapper: mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session) mock_wrapper.return_value.__aexit__ = AsyncMock() - + result = await list_mcp_tools(endpoint, headers) - + assert isinstance(result, ListToolDefsResponse) assert len(result.data) == 0