mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
change mcp.py
This commit is contained in:
parent
aac42ddcc2
commit
38f79e6abe
1 changed files with 57 additions and 1 deletions
|
@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator
|
|||
from contextlib import asynccontextmanager
|
||||
from enum import Enum
|
||||
from typing import Any, cast
|
||||
import copy
|
||||
|
||||
import httpx
|
||||
from mcp import ClientSession, McpError
|
||||
|
@ -31,6 +32,57 @@ logger = get_logger(__name__, category="tools")
|
|||
protocol_cache = TTLDict(ttl_seconds=3600)
|
||||
|
||||
|
||||
def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Resolve JSON Schema $ref references using $defs.
|
||||
|
||||
This function takes a JSON schema that may contain $ref and $defs,
|
||||
and returns a new schema with all $ref references resolved inline.
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return schema
|
||||
|
||||
# Make a deep copy to avoid modifying the original
|
||||
resolved_schema = copy.deepcopy(schema)
|
||||
defs = resolved_schema.get("$defs", {})
|
||||
|
||||
def resolve_refs(obj: Any) -> Any:
|
||||
"""Recursively resolve $ref references in the schema."""
|
||||
if isinstance(obj, dict):
|
||||
if "$ref" in obj:
|
||||
# Extract the reference path
|
||||
ref_path = obj["$ref"]
|
||||
if ref_path.startswith("#/$defs/"):
|
||||
def_name = ref_path[len("#/$defs/"):]
|
||||
if def_name in defs:
|
||||
# Recursively resolve refs in the definition itself
|
||||
resolved_def = resolve_refs(defs[def_name])
|
||||
return resolved_def
|
||||
else:
|
||||
logger.warning(f"Referenced definition '{def_name}' not found in $defs")
|
||||
return obj
|
||||
else:
|
||||
logger.warning(f"Unsupported $ref format: {ref_path}")
|
||||
return obj
|
||||
else:
|
||||
# Recursively process all values in the dict
|
||||
return {key: resolve_refs(value) for key, value in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
# Recursively process all items in the list
|
||||
return [resolve_refs(item) for item in obj]
|
||||
else:
|
||||
# Return primitive values as-is
|
||||
return obj
|
||||
|
||||
# Resolve all refs in the schema
|
||||
resolved_schema = resolve_refs(resolved_schema)
|
||||
|
||||
# Remove the $defs section as it's no longer needed
|
||||
resolved_schema.pop("$defs", None)
|
||||
|
||||
return resolved_schema
|
||||
|
||||
|
||||
class MCPProtol(Enum):
|
||||
UNKNOWN = 0
|
||||
STREAMABLE_HTTP = 1
|
||||
|
@ -114,7 +166,11 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
|
|||
tools_result = await session.list_tools()
|
||||
for tool in tools_result.tools:
|
||||
parameters = []
|
||||
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
|
||||
|
||||
# Resolve $ref and $defs in the input schema
|
||||
resolved_input_schema = resolve_json_schema_refs(tool.inputSchema)
|
||||
|
||||
for param_name, param_schema in resolved_input_schema.get("properties", {}).items():
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name=param_name,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue