mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
precommit fix
This commit is contained in:
parent
9adabb09cf
commit
774e4311ad
2 changed files with 88 additions and 210 deletions
|
@ -4,11 +4,11 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import copy
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
import copy
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from mcp import ClientSession, McpError
|
from mcp import ClientSession, McpError
|
||||||
|
@ -32,20 +32,19 @@ logger = get_logger(__name__, category="tools")
|
||||||
protocol_cache = TTLDict(ttl_seconds=3600)
|
protocol_cache = TTLDict(ttl_seconds=3600)
|
||||||
|
|
||||||
|
|
||||||
def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
|
def resolve_json_schema_refs(schema: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
Resolve JSON Schema $ref references using $defs.
|
Resolve JSON Schema $ref references using $defs.
|
||||||
|
|
||||||
This function takes a JSON schema that may contain $ref and $defs,
|
This function takes a JSON schema that may contain $ref and $defs,
|
||||||
and returns a new schema with all $ref references resolved inline.
|
and returns a new schema with all $ref references resolved inline.
|
||||||
"""
|
"""
|
||||||
if not isinstance(schema, dict):
|
if not isinstance(schema, dict):
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
# Make a deep copy to avoid modifying the original
|
# Make a deep copy to avoid modifying the original
|
||||||
resolved_schema = copy.deepcopy(schema)
|
resolved_schema = copy.deepcopy(schema)
|
||||||
defs = resolved_schema.get("$defs", {})
|
defs = resolved_schema.get("$defs", {})
|
||||||
|
|
||||||
def resolve_refs(obj: Any) -> Any:
|
def resolve_refs(obj: Any) -> Any:
|
||||||
"""Recursively resolve $ref references in the schema."""
|
"""Recursively resolve $ref references in the schema."""
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
|
@ -53,7 +52,7 @@ def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
|
||||||
# Extract the reference path
|
# Extract the reference path
|
||||||
ref_path = obj["$ref"]
|
ref_path = obj["$ref"]
|
||||||
if ref_path.startswith("#/$defs/"):
|
if ref_path.startswith("#/$defs/"):
|
||||||
def_name = ref_path[len("#/$defs/"):]
|
def_name = ref_path[len("#/$defs/") :]
|
||||||
if def_name in defs:
|
if def_name in defs:
|
||||||
# Recursively resolve refs in the definition itself
|
# Recursively resolve refs in the definition itself
|
||||||
resolved_def = resolve_refs(defs[def_name])
|
resolved_def = resolve_refs(defs[def_name])
|
||||||
|
@ -73,13 +72,13 @@ def resolve_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
|
||||||
else:
|
else:
|
||||||
# Return primitive values as-is
|
# Return primitive values as-is
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
# Resolve all refs in the schema
|
# Resolve all refs in the schema
|
||||||
resolved_schema = resolve_refs(resolved_schema)
|
resolved_schema = resolve_refs(resolved_schema)
|
||||||
|
|
||||||
# Remove the $defs section as it's no longer needed
|
# Remove the $defs section as it's no longer needed
|
||||||
resolved_schema.pop("$defs", None)
|
resolved_schema.pop("$defs", None)
|
||||||
|
|
||||||
return resolved_schema
|
return resolved_schema
|
||||||
|
|
||||||
|
|
||||||
|
@ -166,10 +165,10 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
|
||||||
tools_result = await session.list_tools()
|
tools_result = await session.list_tools()
|
||||||
for tool in tools_result.tools:
|
for tool in tools_result.tools:
|
||||||
parameters = []
|
parameters = []
|
||||||
|
|
||||||
# Resolve $ref and $defs in the input schema
|
# Resolve $ref and $defs in the input schema
|
||||||
resolved_input_schema = resolve_json_schema_refs(tool.inputSchema)
|
resolved_input_schema = resolve_json_schema_refs(tool.inputSchema)
|
||||||
|
|
||||||
for param_name, param_schema in resolved_input_schema.get("properties", {}).items():
|
for param_name, param_schema in resolved_input_schema.get("properties", {}).items():
|
||||||
parameters.append(
|
parameters.append(
|
||||||
ToolParameter(
|
ToolParameter(
|
||||||
|
|
|
@ -4,25 +4,14 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import pytest
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
from unittest.mock import AsyncMock, Mock, patch, MagicMock
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from typing import Any, AsyncGenerator
|
|
||||||
import httpx
|
|
||||||
from mcp import McpError
|
|
||||||
from mcp import types as mcp_types
|
|
||||||
|
|
||||||
|
from llama_stack.apis.tools import ListToolDefsResponse
|
||||||
from llama_stack.providers.utils.tools.mcp import (
|
from llama_stack.providers.utils.tools.mcp import (
|
||||||
resolve_json_schema_refs,
|
|
||||||
MCPProtol,
|
MCPProtol,
|
||||||
client_wrapper,
|
|
||||||
list_mcp_tools,
|
list_mcp_tools,
|
||||||
invoke_mcp_tool,
|
resolve_json_schema_refs,
|
||||||
protocol_cache,
|
|
||||||
)
|
)
|
||||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter, ToolInvocationResult
|
|
||||||
from llama_stack.apis.common.content_types import TextContentItem, ImageContentItem
|
|
||||||
from llama_stack.core.datatypes import AuthenticationRequiredError
|
|
||||||
|
|
||||||
|
|
||||||
class TestResolveJsonSchemaRefs:
|
class TestResolveJsonSchemaRefs:
|
||||||
|
@ -32,186 +21,103 @@ class TestResolveJsonSchemaRefs:
|
||||||
"""Test resolving a simple $ref reference."""
|
"""Test resolving a simple $ref reference."""
|
||||||
schema = {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {"user": {"$ref": "#/$defs/User"}},
|
||||||
"user": {"$ref": "#/$defs/User"}
|
|
||||||
},
|
|
||||||
"$defs": {
|
"$defs": {
|
||||||
"User": {
|
"User": {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}}
|
||||||
"type": "object",
|
},
|
||||||
"properties": {
|
|
||||||
"name": {"type": "string"},
|
|
||||||
"age": {"type": "integer"}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result = resolve_json_schema_refs(schema)
|
result = resolve_json_schema_refs(schema)
|
||||||
|
|
||||||
expected = {
|
expected = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"user": {
|
"user": {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}}
|
||||||
"type": "object",
|
},
|
||||||
"properties": {
|
|
||||||
"name": {"type": "string"},
|
|
||||||
"age": {"type": "integer"}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
def test_resolve_nested_refs(self):
|
def test_resolve_nested_refs(self):
|
||||||
"""Test resolving nested $ref references."""
|
"""Test resolving nested $ref references."""
|
||||||
schema = {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {"data": {"$ref": "#/$defs/Container"}},
|
||||||
"data": {"$ref": "#/$defs/Container"}
|
|
||||||
},
|
|
||||||
"$defs": {
|
"$defs": {
|
||||||
"Container": {
|
"Container": {"type": "object", "properties": {"user": {"$ref": "#/$defs/User"}}},
|
||||||
"type": "object",
|
"User": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||||
"properties": {
|
},
|
||||||
"user": {"$ref": "#/$defs/User"}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"User": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"name": {"type": "string"}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result = resolve_json_schema_refs(schema)
|
result = resolve_json_schema_refs(schema)
|
||||||
|
|
||||||
expected = {
|
expected = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"data": {
|
"data": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {"user": {"type": "object", "properties": {"name": {"type": "string"}}}},
|
||||||
"user": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"name": {"type": "string"}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
def test_resolve_refs_in_array(self):
|
def test_resolve_refs_in_array(self):
|
||||||
"""Test resolving $ref references within arrays."""
|
"""Test resolving $ref references within arrays."""
|
||||||
schema = {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {"items": {"type": "array", "items": {"$ref": "#/$defs/Item"}}},
|
||||||
"items": {
|
"$defs": {"Item": {"type": "object", "properties": {"id": {"type": "string"}}}},
|
||||||
"type": "array",
|
|
||||||
"items": {"$ref": "#/$defs/Item"}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"$defs": {
|
|
||||||
"Item": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"id": {"type": "string"}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result = resolve_json_schema_refs(schema)
|
result = resolve_json_schema_refs(schema)
|
||||||
|
|
||||||
expected = {
|
expected = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"items": {
|
"items": {"type": "array", "items": {"type": "object", "properties": {"id": {"type": "string"}}}}
|
||||||
"type": "array",
|
},
|
||||||
"items": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"id": {"type": "string"}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
def test_resolve_missing_ref(self):
|
def test_resolve_missing_ref(self):
|
||||||
"""Test handling of missing $ref definition."""
|
"""Test handling of missing $ref definition."""
|
||||||
schema = {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {"user": {"$ref": "#/$defs/MissingUser"}},
|
||||||
"user": {"$ref": "#/$defs/MissingUser"}
|
"$defs": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}},
|
||||||
},
|
|
||||||
"$defs": {
|
|
||||||
"User": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"name": {"type": "string"}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
with patch('llama_stack.providers.utils.tools.mcp.logger') as mock_logger:
|
with patch("llama_stack.providers.utils.tools.mcp.logger") as mock_logger:
|
||||||
result = resolve_json_schema_refs(schema)
|
result = resolve_json_schema_refs(schema)
|
||||||
mock_logger.warning.assert_called_once_with("Referenced definition 'MissingUser' not found in $defs")
|
mock_logger.warning.assert_called_once_with("Referenced definition 'MissingUser' not found in $defs")
|
||||||
|
|
||||||
# Should return the original $ref unchanged
|
# Should return the original $ref unchanged
|
||||||
expected = {
|
expected = {"type": "object", "properties": {"user": {"$ref": "#/$defs/MissingUser"}}}
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"user": {"$ref": "#/$defs/MissingUser"}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
def test_resolve_unsupported_ref_format(self):
|
def test_resolve_unsupported_ref_format(self):
|
||||||
"""Test handling of unsupported $ref format."""
|
"""Test handling of unsupported $ref format."""
|
||||||
schema = {
|
schema = {"type": "object", "properties": {"user": {"$ref": "http://example.com/schema"}}, "$defs": {}}
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
with patch("llama_stack.providers.utils.tools.mcp.logger") as mock_logger:
|
||||||
"user": {"$ref": "http://example.com/schema"}
|
|
||||||
},
|
|
||||||
"$defs": {}
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch('llama_stack.providers.utils.tools.mcp.logger') as mock_logger:
|
|
||||||
result = resolve_json_schema_refs(schema)
|
result = resolve_json_schema_refs(schema)
|
||||||
mock_logger.warning.assert_called_once_with("Unsupported $ref format: http://example.com/schema")
|
mock_logger.warning.assert_called_once_with("Unsupported $ref format: http://example.com/schema")
|
||||||
|
|
||||||
# Should return the original $ref unchanged
|
# Should return the original $ref unchanged
|
||||||
expected = {
|
expected = {"type": "object", "properties": {"user": {"$ref": "http://example.com/schema"}}}
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"user": {"$ref": "http://example.com/schema"}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
def test_resolve_no_defs(self):
|
def test_resolve_no_defs(self):
|
||||||
"""Test schema without $defs section."""
|
"""Test schema without $defs section."""
|
||||||
schema = {
|
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"name": {"type": "string"}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
result = resolve_json_schema_refs(schema)
|
result = resolve_json_schema_refs(schema)
|
||||||
|
|
||||||
assert result == schema
|
assert result == schema
|
||||||
|
|
||||||
def test_resolve_non_dict_input(self):
|
def test_resolve_non_dict_input(self):
|
||||||
|
@ -225,22 +131,12 @@ class TestResolveJsonSchemaRefs:
|
||||||
"""Test that original schema is not modified."""
|
"""Test that original schema is not modified."""
|
||||||
original_schema = {
|
original_schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {"user": {"$ref": "#/$defs/User"}},
|
||||||
"user": {"$ref": "#/$defs/User"}
|
"$defs": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}},
|
||||||
},
|
|
||||||
"$defs": {
|
|
||||||
"User": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"name": {"type": "string"}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
original_copy = original_schema.copy()
|
|
||||||
resolve_json_schema_refs(original_schema)
|
resolve_json_schema_refs(original_schema)
|
||||||
|
|
||||||
# Original should be unchanged (but this is a shallow comparison)
|
# Original should be unchanged (but this is a shallow comparison)
|
||||||
assert "$ref" in original_schema["properties"]["user"]
|
assert "$ref" in original_schema["properties"]["user"]
|
||||||
assert "$defs" in original_schema
|
assert "$defs" in original_schema
|
||||||
|
@ -259,12 +155,11 @@ class TestMCPProtocol:
|
||||||
class TestListMcpTools:
|
class TestListMcpTools:
|
||||||
"""Test cases for list_mcp_tools function."""
|
"""Test cases for list_mcp_tools function."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tools_success(self):
|
async def test_list_tools_success(self):
|
||||||
"""Test successful listing of MCP tools."""
|
"""Test successful listing of MCP tools."""
|
||||||
endpoint = "http://example.com/mcp"
|
endpoint = "http://example.com/mcp"
|
||||||
headers = {"Authorization": "Bearer token"}
|
headers = {"Authorization": "Bearer token"}
|
||||||
|
|
||||||
# Mock tool from MCP
|
# Mock tool from MCP
|
||||||
mock_tool = Mock()
|
mock_tool = Mock()
|
||||||
mock_tool.name = "test_tool"
|
mock_tool.name = "test_tool"
|
||||||
|
@ -272,113 +167,97 @@ class TestListMcpTools:
|
||||||
mock_tool.inputSchema = {
|
mock_tool.inputSchema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"param1": {
|
"param1": {"type": "string", "description": "First parameter", "default": "default_value"},
|
||||||
"type": "string",
|
"param2": {"type": "integer", "description": "Second parameter"},
|
||||||
"description": "First parameter",
|
},
|
||||||
"default": "default_value"
|
|
||||||
},
|
|
||||||
"param2": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "Second parameter"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_tools_result = Mock()
|
mock_tools_result = Mock()
|
||||||
mock_tools_result.tools = [mock_tool]
|
mock_tools_result.tools = [mock_tool]
|
||||||
|
|
||||||
mock_session = Mock()
|
mock_session = Mock()
|
||||||
mock_session.list_tools = AsyncMock(return_value=mock_tools_result)
|
mock_session.list_tools = AsyncMock(return_value=mock_tools_result)
|
||||||
|
|
||||||
with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper:
|
with patch("llama_stack.providers.utils.tools.mcp.client_wrapper") as mock_wrapper:
|
||||||
mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||||
mock_wrapper.return_value.__aexit__ = AsyncMock()
|
mock_wrapper.return_value.__aexit__ = AsyncMock()
|
||||||
|
|
||||||
result = await list_mcp_tools(endpoint, headers)
|
result = await list_mcp_tools(endpoint, headers)
|
||||||
|
|
||||||
assert isinstance(result, ListToolDefsResponse)
|
assert isinstance(result, ListToolDefsResponse)
|
||||||
assert len(result.data) == 1
|
assert len(result.data) == 1
|
||||||
|
|
||||||
tool_def = result.data[0]
|
tool_def = result.data[0]
|
||||||
assert tool_def.name == "test_tool"
|
assert tool_def.name == "test_tool"
|
||||||
assert tool_def.description == "A test tool"
|
assert tool_def.description == "A test tool"
|
||||||
assert tool_def.metadata["endpoint"] == endpoint
|
assert tool_def.metadata["endpoint"] == endpoint
|
||||||
|
|
||||||
# Check parameters
|
# Check parameters
|
||||||
assert len(tool_def.parameters) == 2
|
assert len(tool_def.parameters) == 2
|
||||||
|
|
||||||
param1 = next(p for p in tool_def.parameters if p.name == "param1")
|
param1 = next(p for p in tool_def.parameters if p.name == "param1")
|
||||||
assert param1.parameter_type == "string"
|
assert param1.parameter_type == "string"
|
||||||
assert param1.description == "First parameter"
|
assert param1.description == "First parameter"
|
||||||
assert param1.required is False # Has default value
|
assert param1.required is False # Has default value
|
||||||
assert param1.default == "default_value"
|
assert param1.default == "default_value"
|
||||||
|
|
||||||
param2 = next(p for p in tool_def.parameters if p.name == "param2")
|
param2 = next(p for p in tool_def.parameters if p.name == "param2")
|
||||||
assert param2.parameter_type == "integer"
|
assert param2.parameter_type == "integer"
|
||||||
assert param2.description == "Second parameter"
|
assert param2.description == "Second parameter"
|
||||||
assert param2.required is True # No default value
|
assert param2.required is True # No default value
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tools_with_schema_refs(self):
|
async def test_list_tools_with_schema_refs(self):
|
||||||
"""Test listing tools with JSON Schema $refs."""
|
"""Test listing tools with JSON Schema $refs."""
|
||||||
endpoint = "http://example.com/mcp"
|
endpoint = "http://example.com/mcp"
|
||||||
headers = {}
|
headers = {}
|
||||||
|
|
||||||
# Mock tool with $ref in schema
|
# Mock tool with $ref in schema
|
||||||
mock_tool = Mock()
|
mock_tool = Mock()
|
||||||
mock_tool.name = "ref_tool"
|
mock_tool.name = "ref_tool"
|
||||||
mock_tool.description = "Tool with refs"
|
mock_tool.description = "Tool with refs"
|
||||||
mock_tool.inputSchema = {
|
mock_tool.inputSchema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {"user": {"$ref": "#/$defs/User"}},
|
||||||
"user": {"$ref": "#/$defs/User"}
|
|
||||||
},
|
|
||||||
"$defs": {
|
"$defs": {
|
||||||
"User": {
|
"User": {"type": "object", "properties": {"name": {"type": "string", "description": "User name"}}}
|
||||||
"type": "object",
|
},
|
||||||
"properties": {
|
|
||||||
"name": {"type": "string", "description": "User name"}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_tools_result = Mock()
|
mock_tools_result = Mock()
|
||||||
mock_tools_result.tools = [mock_tool]
|
mock_tools_result.tools = [mock_tool]
|
||||||
|
|
||||||
mock_session = Mock()
|
mock_session = Mock()
|
||||||
mock_session.list_tools = AsyncMock(return_value=mock_tools_result)
|
mock_session.list_tools = AsyncMock(return_value=mock_tools_result)
|
||||||
|
|
||||||
with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper:
|
with patch("llama_stack.providers.utils.tools.mcp.client_wrapper") as mock_wrapper:
|
||||||
mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||||
mock_wrapper.return_value.__aexit__ = AsyncMock()
|
mock_wrapper.return_value.__aexit__ = AsyncMock()
|
||||||
|
|
||||||
result = await list_mcp_tools(endpoint, headers)
|
result = await list_mcp_tools(endpoint, headers)
|
||||||
|
|
||||||
# Should have resolved the $ref
|
# Should have resolved the $ref
|
||||||
tool_def = result.data[0]
|
tool_def = result.data[0]
|
||||||
assert len(tool_def.parameters) == 1
|
assert len(tool_def.parameters) == 1
|
||||||
|
|
||||||
# The user parameter should be flattened from the resolved $ref
|
# The user parameter should be flattened from the resolved $ref
|
||||||
# Note: This depends on how the schema resolution works with nested objects
|
# Note: This depends on how the schema resolution works with nested objects
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tools_empty_result(self):
|
async def test_list_tools_empty_result(self):
|
||||||
"""Test listing tools when no tools are available."""
|
"""Test listing tools when no tools are available."""
|
||||||
endpoint = "http://example.com/mcp"
|
endpoint = "http://example.com/mcp"
|
||||||
headers = {}
|
headers = {}
|
||||||
|
|
||||||
mock_tools_result = Mock()
|
mock_tools_result = Mock()
|
||||||
mock_tools_result.tools = []
|
mock_tools_result.tools = []
|
||||||
|
|
||||||
mock_session = Mock()
|
mock_session = Mock()
|
||||||
mock_session.list_tools = AsyncMock(return_value=mock_tools_result)
|
mock_session.list_tools = AsyncMock(return_value=mock_tools_result)
|
||||||
|
|
||||||
with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper:
|
with patch("llama_stack.providers.utils.tools.mcp.client_wrapper") as mock_wrapper:
|
||||||
mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||||
mock_wrapper.return_value.__aexit__ = AsyncMock()
|
mock_wrapper.return_value.__aexit__ = AsyncMock()
|
||||||
|
|
||||||
result = await list_mcp_tools(endpoint, headers)
|
result = await list_mcp_tools(endpoint, headers)
|
||||||
|
|
||||||
assert isinstance(result, ListToolDefsResponse)
|
assert isinstance(result, ListToolDefsResponse)
|
||||||
assert len(result.data) == 0
|
assert len(result.data) == 0
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue