change mcp.py

This commit is contained in:
Kai Wu 2025-09-29 13:31:23 -07:00
parent aac42ddcc2
commit 38f79e6abe

View file

@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from enum import Enum
from typing import Any, cast
import copy
import httpx
from mcp import ClientSession, McpError
@ -31,6 +32,57 @@ logger = get_logger(__name__, category="tools")
protocol_cache = TTLDict(ttl_seconds=3600)
def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
"""
Resolve JSON Schema $ref references using $defs.
This function takes a JSON schema that may contain $ref and $defs,
and returns a new schema with all $ref references resolved inline.
"""
if not isinstance(schema, dict):
return schema
# Make a deep copy to avoid modifying the original
resolved_schema = copy.deepcopy(schema)
defs = resolved_schema.get("$defs", {})
def resolve_refs(obj: Any) -> Any:
"""Recursively resolve $ref references in the schema."""
if isinstance(obj, dict):
if "$ref" in obj:
# Extract the reference path
ref_path = obj["$ref"]
if ref_path.startswith("#/$defs/"):
def_name = ref_path[len("#/$defs/"):]
if def_name in defs:
# Recursively resolve refs in the definition itself
resolved_def = resolve_refs(defs[def_name])
return resolved_def
else:
logger.warning(f"Referenced definition '{def_name}' not found in $defs")
return obj
else:
logger.warning(f"Unsupported $ref format: {ref_path}")
return obj
else:
# Recursively process all values in the dict
return {key: resolve_refs(value) for key, value in obj.items()}
elif isinstance(obj, list):
# Recursively process all items in the list
return [resolve_refs(item) for item in obj]
else:
# Return primitive values as-is
return obj
# Resolve all refs in the schema
resolved_schema = resolve_refs(resolved_schema)
# Remove the $defs section as it's no longer needed
resolved_schema.pop("$defs", None)
return resolved_schema
class MCPProtol(Enum):
UNKNOWN = 0
STREAMABLE_HTTP = 1
@ -114,7 +166,11 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
# Resolve $ref and $defs in the input schema
resolved_input_schema = resolve_json_schema_refs(tool.inputSchema)
for param_name, param_schema in resolved_input_schema.get("properties", {}).items():
parameters.append(
ToolParameter(
name=param_name,