precommit fix

This commit is contained in:
Kai Wu 2025-09-30 09:22:58 -07:00
parent 9adabb09cf
commit 774e4311ad
2 changed files with 88 additions and 210 deletions

View file

@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import copy
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from enum import Enum
from typing import Any, cast
import copy
import httpx
from mcp import ClientSession, McpError
@ -32,20 +32,19 @@ logger = get_logger(__name__, category="tools")
protocol_cache = TTLDict(ttl_seconds=3600)
def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
def resolve_json_schema_refs(schema: Any) -> Any:
"""
Resolve JSON Schema $ref references using $defs.
This function takes a JSON schema that may contain $ref and $defs,
and returns a new schema with all $ref references resolved inline.
"""
if not isinstance(schema, dict):
return schema
# Make a deep copy to avoid modifying the original
resolved_schema = copy.deepcopy(schema)
defs = resolved_schema.get("$defs", {})
def resolve_refs(obj: Any) -> Any:
"""Recursively resolve $ref references in the schema."""
if isinstance(obj, dict):
@ -53,7 +52,7 @@ def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
# Extract the reference path
ref_path = obj["$ref"]
if ref_path.startswith("#/$defs/"):
def_name = ref_path[len("#/$defs/"):]
def_name = ref_path[len("#/$defs/") :]
if def_name in defs:
# Recursively resolve refs in the definition itself
resolved_def = resolve_refs(defs[def_name])
@ -73,13 +72,13 @@ def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
else:
# Return primitive values as-is
return obj
# Resolve all refs in the schema
resolved_schema = resolve_refs(resolved_schema)
# Remove the $defs section as it's no longer needed
resolved_schema.pop("$defs", None)
return resolved_schema
@ -166,10 +165,10 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
# Resolve $ref and $defs in the input schema
resolved_input_schema = resolve_json_schema_refs(tool.inputSchema)
for param_name, param_schema in resolved_input_schema.get("properties", {}).items():
parameters.append(
ToolParameter(