diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py index 155f7eff8..d182f25ae 100644 --- a/llama_stack/providers/utils/tools/mcp.py +++ b/llama_stack/providers/utils/tools/mcp.py @@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from enum import Enum from typing import Any, cast +import copy import httpx from mcp import ClientSession, McpError @@ -31,6 +32,57 @@ logger = get_logger(__name__, category="tools") protocol_cache = TTLDict(ttl_seconds=3600) +def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]: + """ + Resolve JSON Schema $ref references using $defs. + + This function takes a JSON schema that may contain $ref and $defs, + and returns a new schema with all $ref references resolved inline. + """ + if not isinstance(schema, dict): + return schema + + # Make a deep copy to avoid modifying the original + resolved_schema = copy.deepcopy(schema) + defs = resolved_schema.get("$defs", {}) + + def resolve_refs(obj: Any) -> Any: + """Recursively resolve $ref references in the schema.""" + if isinstance(obj, dict): + if "$ref" in obj: + # Extract the reference path + ref_path = obj["$ref"] + if ref_path.startswith("#/$defs/"): + def_name = ref_path[len("#/$defs/"):] + if def_name in defs: + # Recursively resolve refs in the definition itself + resolved_def = resolve_refs(defs[def_name]) + return resolved_def + else: + logger.warning(f"Referenced definition '{def_name}' not found in $defs") + return obj + else: + logger.warning(f"Unsupported $ref format: {ref_path}") + return obj + else: + # Recursively process all values in the dict + return {key: resolve_refs(value) for key, value in obj.items()} + elif isinstance(obj, list): + # Recursively process all items in the list + return [resolve_refs(item) for item in obj] + else: + # Return primitive values as-is + return obj + + # Resolve all refs in the schema + resolved_schema = resolve_refs(resolved_schema) + + # Remove the $defs section as it's no longer needed + resolved_schema.pop("$defs", None) + + return resolved_schema + + class MCPProtol(Enum): UNKNOWN = 0 STREAMABLE_HTTP = 1 @@ -114,7 +166,11 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs tools_result = await session.list_tools() for tool in tools_result.tools: parameters = [] - for param_name, param_schema in tool.inputSchema.get("properties", {}).items(): + + # Resolve $ref and $defs in the input schema + resolved_input_schema = resolve_json_schema_refs(tool.inputSchema) + + for param_name, param_schema in resolved_input_schema.get("properties", {}).items(): parameters.append( ToolParameter( name=param_name,