change mcp.py

This commit is contained in:
Kai Wu 2025-09-29 13:31:23 -07:00
parent aac42ddcc2
commit 38f79e6abe

View file

@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from enum import Enum from enum import Enum
from typing import Any, cast from typing import Any, cast
import copy
import httpx import httpx
from mcp import ClientSession, McpError from mcp import ClientSession, McpError
@ -31,6 +32,57 @@ logger = get_logger(__name__, category="tools")
protocol_cache = TTLDict(ttl_seconds=3600) protocol_cache = TTLDict(ttl_seconds=3600)
def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
"""
Resolve JSON Schema $ref references using $defs.
This function takes a JSON schema that may contain $ref and $defs,
and returns a new schema with all $ref references resolved inline.
"""
if not isinstance(schema, dict):
return schema
# Make a deep copy to avoid modifying the original
resolved_schema = copy.deepcopy(schema)
defs = resolved_schema.get("$defs", {})
def resolve_refs(obj: Any) -> Any:
"""Recursively resolve $ref references in the schema."""
if isinstance(obj, dict):
if "$ref" in obj:
# Extract the reference path
ref_path = obj["$ref"]
if ref_path.startswith("#/$defs/"):
def_name = ref_path[len("#/$defs/"):]
if def_name in defs:
# Recursively resolve refs in the definition itself
resolved_def = resolve_refs(defs[def_name])
return resolved_def
else:
logger.warning(f"Referenced definition '{def_name}' not found in $defs")
return obj
else:
logger.warning(f"Unsupported $ref format: {ref_path}")
return obj
else:
# Recursively process all values in the dict
return {key: resolve_refs(value) for key, value in obj.items()}
elif isinstance(obj, list):
# Recursively process all items in the list
return [resolve_refs(item) for item in obj]
else:
# Return primitive values as-is
return obj
# Resolve all refs in the schema
resolved_schema = resolve_refs(resolved_schema)
# Remove the $defs section as it's no longer needed
resolved_schema.pop("$defs", None)
return resolved_schema
class MCPProtol(Enum): class MCPProtol(Enum):
UNKNOWN = 0 UNKNOWN = 0
STREAMABLE_HTTP = 1 STREAMABLE_HTTP = 1
@ -114,7 +166,11 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
tools_result = await session.list_tools() tools_result = await session.list_tools()
for tool in tools_result.tools: for tool in tools_result.tools:
parameters = [] parameters = []
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
# Resolve $ref and $defs in the input schema
resolved_input_schema = resolve_json_schema_refs(tool.inputSchema)
for param_name, param_schema in resolved_input_schema.get("properties", {}).items():
parameters.append( parameters.append(
ToolParameter( ToolParameter(
name=param_name, name=param_name,