mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
change mcp.py
This commit is contained in:
parent
aac42ddcc2
commit
38f79e6abe
1 changed files with 57 additions and 1 deletions
|
@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
import copy
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from mcp import ClientSession, McpError
|
from mcp import ClientSession, McpError
|
||||||
|
@ -31,6 +32,57 @@ logger = get_logger(__name__, category="tools")
|
||||||
protocol_cache = TTLDict(ttl_seconds=3600)
|
protocol_cache = TTLDict(ttl_seconds=3600)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Resolve JSON Schema $ref references using $defs.
|
||||||
|
|
||||||
|
This function takes a JSON schema that may contain $ref and $defs,
|
||||||
|
and returns a new schema with all $ref references resolved inline.
|
||||||
|
"""
|
||||||
|
if not isinstance(schema, dict):
|
||||||
|
return schema
|
||||||
|
|
||||||
|
# Make a deep copy to avoid modifying the original
|
||||||
|
resolved_schema = copy.deepcopy(schema)
|
||||||
|
defs = resolved_schema.get("$defs", {})
|
||||||
|
|
||||||
|
def resolve_refs(obj: Any) -> Any:
|
||||||
|
"""Recursively resolve $ref references in the schema."""
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
if "$ref" in obj:
|
||||||
|
# Extract the reference path
|
||||||
|
ref_path = obj["$ref"]
|
||||||
|
if ref_path.startswith("#/$defs/"):
|
||||||
|
def_name = ref_path[len("#/$defs/"):]
|
||||||
|
if def_name in defs:
|
||||||
|
# Recursively resolve refs in the definition itself
|
||||||
|
resolved_def = resolve_refs(defs[def_name])
|
||||||
|
return resolved_def
|
||||||
|
else:
|
||||||
|
logger.warning(f"Referenced definition '{def_name}' not found in $defs")
|
||||||
|
return obj
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unsupported $ref format: {ref_path}")
|
||||||
|
return obj
|
||||||
|
else:
|
||||||
|
# Recursively process all values in the dict
|
||||||
|
return {key: resolve_refs(value) for key, value in obj.items()}
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
# Recursively process all items in the list
|
||||||
|
return [resolve_refs(item) for item in obj]
|
||||||
|
else:
|
||||||
|
# Return primitive values as-is
|
||||||
|
return obj
|
||||||
|
|
||||||
|
# Resolve all refs in the schema
|
||||||
|
resolved_schema = resolve_refs(resolved_schema)
|
||||||
|
|
||||||
|
# Remove the $defs section as it's no longer needed
|
||||||
|
resolved_schema.pop("$defs", None)
|
||||||
|
|
||||||
|
return resolved_schema
|
||||||
|
|
||||||
|
|
||||||
class MCPProtol(Enum):
|
class MCPProtol(Enum):
|
||||||
UNKNOWN = 0
|
UNKNOWN = 0
|
||||||
STREAMABLE_HTTP = 1
|
STREAMABLE_HTTP = 1
|
||||||
|
@ -114,7 +166,11 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
|
||||||
tools_result = await session.list_tools()
|
tools_result = await session.list_tools()
|
||||||
for tool in tools_result.tools:
|
for tool in tools_result.tools:
|
||||||
parameters = []
|
parameters = []
|
||||||
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
|
|
||||||
|
# Resolve $ref and $defs in the input schema
|
||||||
|
resolved_input_schema = resolve_json_schema_refs(tool.inputSchema)
|
||||||
|
|
||||||
|
for param_name, param_schema in resolved_input_schema.get("properties", {}).items():
|
||||||
parameters.append(
|
parameters.append(
|
||||||
ToolParameter(
|
ToolParameter(
|
||||||
name=param_name,
|
name=param_name,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue