precommit fix

This commit is contained in:
Kai Wu 2025-09-30 09:22:58 -07:00
parent 9adabb09cf
commit 774e4311ad
2 changed files with 88 additions and 210 deletions

View file

@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import copy
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from enum import Enum
from typing import Any, cast
import copy
import httpx
from mcp import ClientSession, McpError
@ -32,10 +32,9 @@ logger = get_logger(__name__, category="tools")
protocol_cache = TTLDict(ttl_seconds=3600)
def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
def resolve_json_schema_refs(schema: Any) -> Any:
"""
Resolve JSON Schema $ref references using $defs.
This function takes a JSON schema that may contain $ref and $defs,
and returns a new schema with all $ref references resolved inline.
"""

View file

@ -4,25 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from unittest.mock import AsyncMock, Mock, patch, MagicMock
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator
import httpx
from mcp import McpError
from mcp import types as mcp_types
from unittest.mock import AsyncMock, Mock, patch
from llama_stack.apis.tools import ListToolDefsResponse
from llama_stack.providers.utils.tools.mcp import (
resolve_json_schema_refs,
MCPProtol,
client_wrapper,
list_mcp_tools,
invoke_mcp_tool,
protocol_cache,
resolve_json_schema_refs,
)
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter, ToolInvocationResult
from llama_stack.apis.common.content_types import TextContentItem, ImageContentItem
from llama_stack.core.datatypes import AuthenticationRequiredError
class TestResolveJsonSchemaRefs:
@ -32,18 +21,10 @@ class TestResolveJsonSchemaRefs:
"""Test resolving a simple $ref reference."""
schema = {
"type": "object",
"properties": {
"user": {"$ref": "#/$defs/User"}
},
"properties": {"user": {"$ref": "#/$defs/User"}},
"$defs": {
"User": {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
}
}
}
"User": {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}}
},
}
result = resolve_json_schema_refs(schema)
@ -51,14 +32,8 @@ class TestResolveJsonSchemaRefs:
expected = {
"type": "object",
"properties": {
"user": {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
}
}
}
"user": {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}}
},
}
assert result == expected
@ -67,23 +42,11 @@ class TestResolveJsonSchemaRefs:
"""Test resolving nested $ref references."""
schema = {
"type": "object",
"properties": {
"data": {"$ref": "#/$defs/Container"}
},
"properties": {"data": {"$ref": "#/$defs/Container"}},
"$defs": {
"Container": {
"type": "object",
"properties": {
"user": {"$ref": "#/$defs/User"}
}
"Container": {"type": "object", "properties": {"user": {"$ref": "#/$defs/User"}}},
"User": {"type": "object", "properties": {"name": {"type": "string"}}},
},
"User": {
"type": "object",
"properties": {
"name": {"type": "string"}
}
}
}
}
result = resolve_json_schema_refs(schema)
@ -93,16 +56,9 @@ class TestResolveJsonSchemaRefs:
"properties": {
"data": {
"type": "object",
"properties": {
"user": {
"type": "object",
"properties": {
"name": {"type": "string"}
}
}
}
}
"properties": {"user": {"type": "object", "properties": {"name": {"type": "string"}}}},
}
},
}
assert result == expected
@ -111,20 +67,8 @@ class TestResolveJsonSchemaRefs:
"""Test resolving $ref references within arrays."""
schema = {
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {"$ref": "#/$defs/Item"}
}
},
"$defs": {
"Item": {
"type": "object",
"properties": {
"id": {"type": "string"}
}
}
}
"properties": {"items": {"type": "array", "items": {"$ref": "#/$defs/Item"}}},
"$defs": {"Item": {"type": "object", "properties": {"id": {"type": "string"}}}},
}
result = resolve_json_schema_refs(schema)
@ -132,16 +76,8 @@ class TestResolveJsonSchemaRefs:
expected = {
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "object",
"properties": {
"id": {"type": "string"}
}
}
}
}
"items": {"type": "array", "items": {"type": "object", "properties": {"id": {"type": "string"}}}}
},
}
assert result == expected
@ -150,65 +86,35 @@ class TestResolveJsonSchemaRefs:
"""Test handling of missing $ref definition."""
schema = {
"type": "object",
"properties": {
"user": {"$ref": "#/$defs/MissingUser"}
},
"$defs": {
"User": {
"type": "object",
"properties": {
"name": {"type": "string"}
}
}
}
"properties": {"user": {"$ref": "#/$defs/MissingUser"}},
"$defs": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}},
}
with patch('llama_stack.providers.utils.tools.mcp.logger') as mock_logger:
with patch("llama_stack.providers.utils.tools.mcp.logger") as mock_logger:
result = resolve_json_schema_refs(schema)
mock_logger.warning.assert_called_once_with("Referenced definition 'MissingUser' not found in $defs")
# Should return the original $ref unchanged
expected = {
"type": "object",
"properties": {
"user": {"$ref": "#/$defs/MissingUser"}
}
}
expected = {"type": "object", "properties": {"user": {"$ref": "#/$defs/MissingUser"}}}
assert result == expected
def test_resolve_unsupported_ref_format(self):
"""Test handling of unsupported $ref format."""
schema = {
"type": "object",
"properties": {
"user": {"$ref": "http://example.com/schema"}
},
"$defs": {}
}
schema = {"type": "object", "properties": {"user": {"$ref": "http://example.com/schema"}}, "$defs": {}}
with patch('llama_stack.providers.utils.tools.mcp.logger') as mock_logger:
with patch("llama_stack.providers.utils.tools.mcp.logger") as mock_logger:
result = resolve_json_schema_refs(schema)
mock_logger.warning.assert_called_once_with("Unsupported $ref format: http://example.com/schema")
# Should return the original $ref unchanged
expected = {
"type": "object",
"properties": {
"user": {"$ref": "http://example.com/schema"}
}
}
expected = {"type": "object", "properties": {"user": {"$ref": "http://example.com/schema"}}}
assert result == expected
def test_resolve_no_defs(self):
"""Test schema without $defs section."""
schema = {
"type": "object",
"properties": {
"name": {"type": "string"}
}
}
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
result = resolve_json_schema_refs(schema)
@ -225,20 +131,10 @@ class TestResolveJsonSchemaRefs:
"""Test that original schema is not modified."""
original_schema = {
"type": "object",
"properties": {
"user": {"$ref": "#/$defs/User"}
},
"$defs": {
"User": {
"type": "object",
"properties": {
"name": {"type": "string"}
}
}
}
"properties": {"user": {"$ref": "#/$defs/User"}},
"$defs": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}},
}
original_copy = original_schema.copy()
resolve_json_schema_refs(original_schema)
# Original should be unchanged (but this is a shallow comparison)
@ -259,7 +155,6 @@ class TestMCPProtocol:
class TestListMcpTools:
"""Test cases for list_mcp_tools function."""
@pytest.mark.asyncio
async def test_list_tools_success(self):
"""Test successful listing of MCP tools."""
endpoint = "http://example.com/mcp"
@ -272,16 +167,9 @@ class TestListMcpTools:
mock_tool.inputSchema = {
"type": "object",
"properties": {
"param1": {
"type": "string",
"description": "First parameter",
"default": "default_value"
"param1": {"type": "string", "description": "First parameter", "default": "default_value"},
"param2": {"type": "integer", "description": "Second parameter"},
},
"param2": {
"type": "integer",
"description": "Second parameter"
}
}
}
mock_tools_result = Mock()
@ -290,7 +178,7 @@ class TestListMcpTools:
mock_session = Mock()
mock_session.list_tools = AsyncMock(return_value=mock_tools_result)
with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper:
with patch("llama_stack.providers.utils.tools.mcp.client_wrapper") as mock_wrapper:
mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session)
mock_wrapper.return_value.__aexit__ = AsyncMock()
@ -318,7 +206,6 @@ class TestListMcpTools:
assert param2.description == "Second parameter"
assert param2.required is True # No default value
@pytest.mark.asyncio
async def test_list_tools_with_schema_refs(self):
"""Test listing tools with JSON Schema $refs."""
endpoint = "http://example.com/mcp"
@ -330,17 +217,10 @@ class TestListMcpTools:
mock_tool.description = "Tool with refs"
mock_tool.inputSchema = {
"type": "object",
"properties": {
"user": {"$ref": "#/$defs/User"}
},
"properties": {"user": {"$ref": "#/$defs/User"}},
"$defs": {
"User": {
"type": "object",
"properties": {
"name": {"type": "string", "description": "User name"}
}
}
}
"User": {"type": "object", "properties": {"name": {"type": "string", "description": "User name"}}}
},
}
mock_tools_result = Mock()
@ -349,7 +229,7 @@ class TestListMcpTools:
mock_session = Mock()
mock_session.list_tools = AsyncMock(return_value=mock_tools_result)
with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper:
with patch("llama_stack.providers.utils.tools.mcp.client_wrapper") as mock_wrapper:
mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session)
mock_wrapper.return_value.__aexit__ = AsyncMock()
@ -362,7 +242,6 @@ class TestListMcpTools:
# The user parameter should be flattened from the resolved $ref
# Note: This depends on how the schema resolution works with nested objects
@pytest.mark.asyncio
async def test_list_tools_empty_result(self):
"""Test listing tools when no tools are available."""
endpoint = "http://example.com/mcp"
@ -374,7 +253,7 @@ class TestListMcpTools:
mock_session = Mock()
mock_session.list_tools = AsyncMock(return_value=mock_tools_result)
with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper:
with patch("llama_stack.providers.utils.tools.mcp.client_wrapper") as mock_wrapper:
mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session)
mock_wrapper.return_value.__aexit__ = AsyncMock()