mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
precommit fix
This commit is contained in:
parent
9adabb09cf
commit
774e4311ad
2 changed files with 88 additions and 210 deletions
|
@ -4,11 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from enum import Enum
|
||||
from typing import Any, cast
|
||||
import copy
|
||||
|
||||
import httpx
|
||||
from mcp import ClientSession, McpError
|
||||
|
@ -32,20 +32,19 @@ logger = get_logger(__name__, category="tools")
|
|||
protocol_cache = TTLDict(ttl_seconds=3600)
|
||||
|
||||
|
||||
def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
def resolve_json_schema_refs(schema: Any) -> Any:
|
||||
"""
|
||||
Resolve JSON Schema $ref references using $defs.
|
||||
|
||||
This function takes a JSON schema that may contain $ref and $defs,
|
||||
and returns a new schema with all $ref references resolved inline.
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return schema
|
||||
|
||||
|
||||
# Make a deep copy to avoid modifying the original
|
||||
resolved_schema = copy.deepcopy(schema)
|
||||
defs = resolved_schema.get("$defs", {})
|
||||
|
||||
|
||||
def resolve_refs(obj: Any) -> Any:
|
||||
"""Recursively resolve $ref references in the schema."""
|
||||
if isinstance(obj, dict):
|
||||
|
@ -53,7 +52,7 @@ def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
|
|||
# Extract the reference path
|
||||
ref_path = obj["$ref"]
|
||||
if ref_path.startswith("#/$defs/"):
|
||||
def_name = ref_path[len("#/$defs/"):]
|
||||
def_name = ref_path[len("#/$defs/") :]
|
||||
if def_name in defs:
|
||||
# Recursively resolve refs in the definition itself
|
||||
resolved_def = resolve_refs(defs[def_name])
|
||||
|
@ -73,13 +72,13 @@ def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
|
|||
else:
|
||||
# Return primitive values as-is
|
||||
return obj
|
||||
|
||||
|
||||
# Resolve all refs in the schema
|
||||
resolved_schema = resolve_refs(resolved_schema)
|
||||
|
||||
|
||||
# Remove the $defs section as it's no longer needed
|
||||
resolved_schema.pop("$defs", None)
|
||||
|
||||
|
||||
return resolved_schema
|
||||
|
||||
|
||||
|
@ -166,10 +165,10 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
|
|||
tools_result = await session.list_tools()
|
||||
for tool in tools_result.tools:
|
||||
parameters = []
|
||||
|
||||
|
||||
# Resolve $ref and $defs in the input schema
|
||||
resolved_input_schema = resolve_json_schema_refs(tool.inputSchema)
|
||||
|
||||
|
||||
for param_name, param_schema in resolved_input_schema.get("properties", {}).items():
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue