precommit fix

This commit is contained in:
Kai Wu 2025-09-30 09:22:58 -07:00
parent 9adabb09cf
commit 774e4311ad
2 changed files with 88 additions and 210 deletions

View file

@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import copy
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from enum import Enum from enum import Enum
from typing import Any, cast from typing import Any, cast
import copy
import httpx import httpx
from mcp import ClientSession, McpError from mcp import ClientSession, McpError
@ -32,20 +32,19 @@ logger = get_logger(__name__, category="tools")
protocol_cache = TTLDict(ttl_seconds=3600) protocol_cache = TTLDict(ttl_seconds=3600)
def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]: def resolve_json_schema_refs(schema: Any) -> Any:
""" """
Resolve JSON Schema $ref references using $defs. Resolve JSON Schema $ref references using $defs.
This function takes a JSON schema that may contain $ref and $defs, This function takes a JSON schema that may contain $ref and $defs,
and returns a new schema with all $ref references resolved inline. and returns a new schema with all $ref references resolved inline.
""" """
if not isinstance(schema, dict): if not isinstance(schema, dict):
return schema return schema
# Make a deep copy to avoid modifying the original # Make a deep copy to avoid modifying the original
resolved_schema = copy.deepcopy(schema) resolved_schema = copy.deepcopy(schema)
defs = resolved_schema.get("$defs", {}) defs = resolved_schema.get("$defs", {})
def resolve_refs(obj: Any) -> Any: def resolve_refs(obj: Any) -> Any:
"""Recursively resolve $ref references in the schema.""" """Recursively resolve $ref references in the schema."""
if isinstance(obj, dict): if isinstance(obj, dict):
@ -53,7 +52,7 @@ def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
# Extract the reference path # Extract the reference path
ref_path = obj["$ref"] ref_path = obj["$ref"]
if ref_path.startswith("#/$defs/"): if ref_path.startswith("#/$defs/"):
def_name = ref_path[len("#/$defs/"):] def_name = ref_path[len("#/$defs/") :]
if def_name in defs: if def_name in defs:
# Recursively resolve refs in the definition itself # Recursively resolve refs in the definition itself
resolved_def = resolve_refs(defs[def_name]) resolved_def = resolve_refs(defs[def_name])
@ -73,13 +72,13 @@ def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
else: else:
# Return primitive values as-is # Return primitive values as-is
return obj return obj
# Resolve all refs in the schema # Resolve all refs in the schema
resolved_schema = resolve_refs(resolved_schema) resolved_schema = resolve_refs(resolved_schema)
# Remove the $defs section as it's no longer needed # Remove the $defs section as it's no longer needed
resolved_schema.pop("$defs", None) resolved_schema.pop("$defs", None)
return resolved_schema return resolved_schema
@ -166,10 +165,10 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
tools_result = await session.list_tools() tools_result = await session.list_tools()
for tool in tools_result.tools: for tool in tools_result.tools:
parameters = [] parameters = []
# Resolve $ref and $defs in the input schema # Resolve $ref and $defs in the input schema
resolved_input_schema = resolve_json_schema_refs(tool.inputSchema) resolved_input_schema = resolve_json_schema_refs(tool.inputSchema)
for param_name, param_schema in resolved_input_schema.get("properties", {}).items(): for param_name, param_schema in resolved_input_schema.get("properties", {}).items():
parameters.append( parameters.append(
ToolParameter( ToolParameter(

View file

@ -4,25 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import pytest from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock, Mock, patch, MagicMock
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator
import httpx
from mcp import McpError
from mcp import types as mcp_types
from llama_stack.apis.tools import ListToolDefsResponse
from llama_stack.providers.utils.tools.mcp import ( from llama_stack.providers.utils.tools.mcp import (
resolve_json_schema_refs,
MCPProtol, MCPProtol,
client_wrapper,
list_mcp_tools, list_mcp_tools,
invoke_mcp_tool, resolve_json_schema_refs,
protocol_cache,
) )
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter, ToolInvocationResult
from llama_stack.apis.common.content_types import TextContentItem, ImageContentItem
from llama_stack.core.datatypes import AuthenticationRequiredError
class TestResolveJsonSchemaRefs: class TestResolveJsonSchemaRefs:
@ -32,186 +21,103 @@ class TestResolveJsonSchemaRefs:
"""Test resolving a simple $ref reference.""" """Test resolving a simple $ref reference."""
schema = { schema = {
"type": "object", "type": "object",
"properties": { "properties": {"user": {"$ref": "#/$defs/User"}},
"user": {"$ref": "#/$defs/User"}
},
"$defs": { "$defs": {
"User": { "User": {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}}
"type": "object", },
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
}
}
}
} }
result = resolve_json_schema_refs(schema) result = resolve_json_schema_refs(schema)
expected = { expected = {
"type": "object", "type": "object",
"properties": { "properties": {
"user": { "user": {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}}
"type": "object", },
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
}
}
}
} }
assert result == expected assert result == expected
def test_resolve_nested_refs(self): def test_resolve_nested_refs(self):
"""Test resolving nested $ref references.""" """Test resolving nested $ref references."""
schema = { schema = {
"type": "object", "type": "object",
"properties": { "properties": {"data": {"$ref": "#/$defs/Container"}},
"data": {"$ref": "#/$defs/Container"}
},
"$defs": { "$defs": {
"Container": { "Container": {"type": "object", "properties": {"user": {"$ref": "#/$defs/User"}}},
"type": "object", "User": {"type": "object", "properties": {"name": {"type": "string"}}},
"properties": { },
"user": {"$ref": "#/$defs/User"}
}
},
"User": {
"type": "object",
"properties": {
"name": {"type": "string"}
}
}
}
} }
result = resolve_json_schema_refs(schema) result = resolve_json_schema_refs(schema)
expected = { expected = {
"type": "object", "type": "object",
"properties": { "properties": {
"data": { "data": {
"type": "object", "type": "object",
"properties": { "properties": {"user": {"type": "object", "properties": {"name": {"type": "string"}}}},
"user": {
"type": "object",
"properties": {
"name": {"type": "string"}
}
}
}
} }
} },
} }
assert result == expected assert result == expected
def test_resolve_refs_in_array(self): def test_resolve_refs_in_array(self):
"""Test resolving $ref references within arrays.""" """Test resolving $ref references within arrays."""
schema = { schema = {
"type": "object", "type": "object",
"properties": { "properties": {"items": {"type": "array", "items": {"$ref": "#/$defs/Item"}}},
"items": { "$defs": {"Item": {"type": "object", "properties": {"id": {"type": "string"}}}},
"type": "array",
"items": {"$ref": "#/$defs/Item"}
}
},
"$defs": {
"Item": {
"type": "object",
"properties": {
"id": {"type": "string"}
}
}
}
} }
result = resolve_json_schema_refs(schema) result = resolve_json_schema_refs(schema)
expected = { expected = {
"type": "object", "type": "object",
"properties": { "properties": {
"items": { "items": {"type": "array", "items": {"type": "object", "properties": {"id": {"type": "string"}}}}
"type": "array", },
"items": {
"type": "object",
"properties": {
"id": {"type": "string"}
}
}
}
}
} }
assert result == expected assert result == expected
def test_resolve_missing_ref(self): def test_resolve_missing_ref(self):
"""Test handling of missing $ref definition.""" """Test handling of missing $ref definition."""
schema = { schema = {
"type": "object", "type": "object",
"properties": { "properties": {"user": {"$ref": "#/$defs/MissingUser"}},
"user": {"$ref": "#/$defs/MissingUser"} "$defs": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}},
},
"$defs": {
"User": {
"type": "object",
"properties": {
"name": {"type": "string"}
}
}
}
} }
with patch('llama_stack.providers.utils.tools.mcp.logger') as mock_logger: with patch("llama_stack.providers.utils.tools.mcp.logger") as mock_logger:
result = resolve_json_schema_refs(schema) result = resolve_json_schema_refs(schema)
mock_logger.warning.assert_called_once_with("Referenced definition 'MissingUser' not found in $defs") mock_logger.warning.assert_called_once_with("Referenced definition 'MissingUser' not found in $defs")
# Should return the original $ref unchanged # Should return the original $ref unchanged
expected = { expected = {"type": "object", "properties": {"user": {"$ref": "#/$defs/MissingUser"}}}
"type": "object",
"properties": {
"user": {"$ref": "#/$defs/MissingUser"}
}
}
assert result == expected assert result == expected
def test_resolve_unsupported_ref_format(self): def test_resolve_unsupported_ref_format(self):
"""Test handling of unsupported $ref format.""" """Test handling of unsupported $ref format."""
schema = { schema = {"type": "object", "properties": {"user": {"$ref": "http://example.com/schema"}}, "$defs": {}}
"type": "object",
"properties": { with patch("llama_stack.providers.utils.tools.mcp.logger") as mock_logger:
"user": {"$ref": "http://example.com/schema"}
},
"$defs": {}
}
with patch('llama_stack.providers.utils.tools.mcp.logger') as mock_logger:
result = resolve_json_schema_refs(schema) result = resolve_json_schema_refs(schema)
mock_logger.warning.assert_called_once_with("Unsupported $ref format: http://example.com/schema") mock_logger.warning.assert_called_once_with("Unsupported $ref format: http://example.com/schema")
# Should return the original $ref unchanged # Should return the original $ref unchanged
expected = { expected = {"type": "object", "properties": {"user": {"$ref": "http://example.com/schema"}}}
"type": "object",
"properties": {
"user": {"$ref": "http://example.com/schema"}
}
}
assert result == expected assert result == expected
def test_resolve_no_defs(self): def test_resolve_no_defs(self):
"""Test schema without $defs section.""" """Test schema without $defs section."""
schema = { schema = {"type": "object", "properties": {"name": {"type": "string"}}}
"type": "object",
"properties": {
"name": {"type": "string"}
}
}
result = resolve_json_schema_refs(schema) result = resolve_json_schema_refs(schema)
assert result == schema assert result == schema
def test_resolve_non_dict_input(self): def test_resolve_non_dict_input(self):
@ -225,22 +131,12 @@ class TestResolveJsonSchemaRefs:
"""Test that original schema is not modified.""" """Test that original schema is not modified."""
original_schema = { original_schema = {
"type": "object", "type": "object",
"properties": { "properties": {"user": {"$ref": "#/$defs/User"}},
"user": {"$ref": "#/$defs/User"} "$defs": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}},
},
"$defs": {
"User": {
"type": "object",
"properties": {
"name": {"type": "string"}
}
}
}
} }
original_copy = original_schema.copy()
resolve_json_schema_refs(original_schema) resolve_json_schema_refs(original_schema)
# Original should be unchanged (but this is a shallow comparison) # Original should be unchanged (but this is a shallow comparison)
assert "$ref" in original_schema["properties"]["user"] assert "$ref" in original_schema["properties"]["user"]
assert "$defs" in original_schema assert "$defs" in original_schema
@ -259,12 +155,11 @@ class TestMCPProtocol:
class TestListMcpTools: class TestListMcpTools:
"""Test cases for list_mcp_tools function.""" """Test cases for list_mcp_tools function."""
@pytest.mark.asyncio
async def test_list_tools_success(self): async def test_list_tools_success(self):
"""Test successful listing of MCP tools.""" """Test successful listing of MCP tools."""
endpoint = "http://example.com/mcp" endpoint = "http://example.com/mcp"
headers = {"Authorization": "Bearer token"} headers = {"Authorization": "Bearer token"}
# Mock tool from MCP # Mock tool from MCP
mock_tool = Mock() mock_tool = Mock()
mock_tool.name = "test_tool" mock_tool.name = "test_tool"
@ -272,113 +167,97 @@ class TestListMcpTools:
mock_tool.inputSchema = { mock_tool.inputSchema = {
"type": "object", "type": "object",
"properties": { "properties": {
"param1": { "param1": {"type": "string", "description": "First parameter", "default": "default_value"},
"type": "string", "param2": {"type": "integer", "description": "Second parameter"},
"description": "First parameter", },
"default": "default_value"
},
"param2": {
"type": "integer",
"description": "Second parameter"
}
}
} }
mock_tools_result = Mock() mock_tools_result = Mock()
mock_tools_result.tools = [mock_tool] mock_tools_result.tools = [mock_tool]
mock_session = Mock() mock_session = Mock()
mock_session.list_tools = AsyncMock(return_value=mock_tools_result) mock_session.list_tools = AsyncMock(return_value=mock_tools_result)
with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper: with patch("llama_stack.providers.utils.tools.mcp.client_wrapper") as mock_wrapper:
mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session) mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session)
mock_wrapper.return_value.__aexit__ = AsyncMock() mock_wrapper.return_value.__aexit__ = AsyncMock()
result = await list_mcp_tools(endpoint, headers) result = await list_mcp_tools(endpoint, headers)
assert isinstance(result, ListToolDefsResponse) assert isinstance(result, ListToolDefsResponse)
assert len(result.data) == 1 assert len(result.data) == 1
tool_def = result.data[0] tool_def = result.data[0]
assert tool_def.name == "test_tool" assert tool_def.name == "test_tool"
assert tool_def.description == "A test tool" assert tool_def.description == "A test tool"
assert tool_def.metadata["endpoint"] == endpoint assert tool_def.metadata["endpoint"] == endpoint
# Check parameters # Check parameters
assert len(tool_def.parameters) == 2 assert len(tool_def.parameters) == 2
param1 = next(p for p in tool_def.parameters if p.name == "param1") param1 = next(p for p in tool_def.parameters if p.name == "param1")
assert param1.parameter_type == "string" assert param1.parameter_type == "string"
assert param1.description == "First parameter" assert param1.description == "First parameter"
assert param1.required is False # Has default value assert param1.required is False # Has default value
assert param1.default == "default_value" assert param1.default == "default_value"
param2 = next(p for p in tool_def.parameters if p.name == "param2") param2 = next(p for p in tool_def.parameters if p.name == "param2")
assert param2.parameter_type == "integer" assert param2.parameter_type == "integer"
assert param2.description == "Second parameter" assert param2.description == "Second parameter"
assert param2.required is True # No default value assert param2.required is True # No default value
@pytest.mark.asyncio
async def test_list_tools_with_schema_refs(self): async def test_list_tools_with_schema_refs(self):
"""Test listing tools with JSON Schema $refs.""" """Test listing tools with JSON Schema $refs."""
endpoint = "http://example.com/mcp" endpoint = "http://example.com/mcp"
headers = {} headers = {}
# Mock tool with $ref in schema # Mock tool with $ref in schema
mock_tool = Mock() mock_tool = Mock()
mock_tool.name = "ref_tool" mock_tool.name = "ref_tool"
mock_tool.description = "Tool with refs" mock_tool.description = "Tool with refs"
mock_tool.inputSchema = { mock_tool.inputSchema = {
"type": "object", "type": "object",
"properties": { "properties": {"user": {"$ref": "#/$defs/User"}},
"user": {"$ref": "#/$defs/User"}
},
"$defs": { "$defs": {
"User": { "User": {"type": "object", "properties": {"name": {"type": "string", "description": "User name"}}}
"type": "object", },
"properties": {
"name": {"type": "string", "description": "User name"}
}
}
}
} }
mock_tools_result = Mock() mock_tools_result = Mock()
mock_tools_result.tools = [mock_tool] mock_tools_result.tools = [mock_tool]
mock_session = Mock() mock_session = Mock()
mock_session.list_tools = AsyncMock(return_value=mock_tools_result) mock_session.list_tools = AsyncMock(return_value=mock_tools_result)
with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper: with patch("llama_stack.providers.utils.tools.mcp.client_wrapper") as mock_wrapper:
mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session) mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session)
mock_wrapper.return_value.__aexit__ = AsyncMock() mock_wrapper.return_value.__aexit__ = AsyncMock()
result = await list_mcp_tools(endpoint, headers) result = await list_mcp_tools(endpoint, headers)
# Should have resolved the $ref # Should have resolved the $ref
tool_def = result.data[0] tool_def = result.data[0]
assert len(tool_def.parameters) == 1 assert len(tool_def.parameters) == 1
# The user parameter should be flattened from the resolved $ref # The user parameter should be flattened from the resolved $ref
# Note: This depends on how the schema resolution works with nested objects # Note: This depends on how the schema resolution works with nested objects
@pytest.mark.asyncio
async def test_list_tools_empty_result(self): async def test_list_tools_empty_result(self):
"""Test listing tools when no tools are available.""" """Test listing tools when no tools are available."""
endpoint = "http://example.com/mcp" endpoint = "http://example.com/mcp"
headers = {} headers = {}
mock_tools_result = Mock() mock_tools_result = Mock()
mock_tools_result.tools = [] mock_tools_result.tools = []
mock_session = Mock() mock_session = Mock()
mock_session.list_tools = AsyncMock(return_value=mock_tools_result) mock_session.list_tools = AsyncMock(return_value=mock_tools_result)
with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper: with patch("llama_stack.providers.utils.tools.mcp.client_wrapper") as mock_wrapper:
mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session) mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session)
mock_wrapper.return_value.__aexit__ = AsyncMock() mock_wrapper.return_value.__aexit__ = AsyncMock()
result = await list_mcp_tools(endpoint, headers) result = await list_mcp_tools(endpoint, headers)
assert isinstance(result, ListToolDefsResponse) assert isinstance(result, ListToolDefsResponse)
assert len(result.data) == 0 assert len(result.data) == 0