mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
keep related tests only
This commit is contained in:
parent
2e70c54f42
commit
9adabb09cf
1 changed files with 0 additions and 253 deletions
|
@ -256,145 +256,6 @@ class TestMCPProtocol:
|
|||
assert MCPProtol.SSE.value == 2
|
||||
|
||||
|
||||
class TestClientWrapper:
|
||||
"""Test cases for client_wrapper function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client_session(self):
|
||||
"""Mock ClientSession for testing."""
|
||||
session = Mock()
|
||||
session.initialize = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client_streams(self):
|
||||
"""Mock client streams."""
|
||||
return (Mock(), Mock())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_streamable_http_connection(self, mock_client_session, mock_client_streams):
|
||||
"""Test successful connection with STREAMABLE_HTTP protocol."""
|
||||
endpoint = "http://example.com/mcp"
|
||||
headers = {"Authorization": "Bearer token"}
|
||||
|
||||
# Create a proper context manager mock
|
||||
mock_http_context = AsyncMock()
|
||||
mock_http_context.__aenter__ = AsyncMock(return_value=mock_client_streams)
|
||||
mock_http_context.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_session_context = AsyncMock()
|
||||
mock_session_context.__aenter__ = AsyncMock(return_value=mock_client_session)
|
||||
mock_session_context.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch('llama_stack.providers.utils.tools.mcp.streamablehttp_client') as mock_http_client, \
|
||||
patch('llama_stack.providers.utils.tools.mcp.ClientSession') as mock_session_class:
|
||||
|
||||
mock_http_client.return_value = mock_http_context
|
||||
mock_session_class.return_value = mock_session_context
|
||||
|
||||
async with client_wrapper(endpoint, headers) as session:
|
||||
assert session == mock_client_session
|
||||
mock_client_session.initialize.assert_called_once()
|
||||
assert protocol_cache.get(endpoint) == MCPProtol.STREAMABLE_HTTP
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cached_protocol_preference(self, mock_client_session, mock_client_streams):
|
||||
"""Test that cached protocol is tried first."""
|
||||
endpoint = "http://example.com/mcp"
|
||||
headers = {"Authorization": "Bearer token"}
|
||||
|
||||
# Set SSE as cached protocol
|
||||
protocol_cache[endpoint] = MCPProtol.SSE
|
||||
|
||||
# Create proper context manager mocks
|
||||
mock_sse_context = AsyncMock()
|
||||
mock_sse_context.__aenter__ = AsyncMock(return_value=mock_client_streams)
|
||||
mock_sse_context.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_session_context = AsyncMock()
|
||||
mock_session_context.__aenter__ = AsyncMock(return_value=mock_client_session)
|
||||
mock_session_context.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch('llama_stack.providers.utils.tools.mcp.sse_client') as mock_sse_client, \
|
||||
patch('llama_stack.providers.utils.tools.mcp.streamablehttp_client') as mock_http_client, \
|
||||
patch('llama_stack.providers.utils.tools.mcp.ClientSession') as mock_session_class:
|
||||
|
||||
mock_sse_client.return_value = mock_sse_context
|
||||
mock_session_class.return_value = mock_session_context
|
||||
|
||||
async with client_wrapper(endpoint, headers) as session:
|
||||
assert session == mock_client_session
|
||||
# SSE should be tried first due to cache
|
||||
mock_sse_client.assert_called_once()
|
||||
mock_http_client.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authentication_error_raises_exception(self):
|
||||
"""Test that 401 errors raise AuthenticationRequiredError."""
|
||||
endpoint = "http://example.com/mcp"
|
||||
headers = {"Authorization": "Bearer invalid"}
|
||||
|
||||
protocol_cache.clear()
|
||||
|
||||
# Create a proper HTTP 401 error
|
||||
response = Mock()
|
||||
response.status_code = 401
|
||||
http_error = httpx.HTTPStatusError("Unauthorized", request=Mock(), response=response)
|
||||
|
||||
with patch('llama_stack.providers.utils.tools.mcp.streamablehttp_client') as mock_http_client:
|
||||
mock_http_client.return_value.__aenter__ = AsyncMock(side_effect=http_error)
|
||||
mock_http_client.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
with pytest.raises(AuthenticationRequiredError):
|
||||
async with client_wrapper(endpoint, headers):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_error_handling(self):
|
||||
"""Test handling of connection errors."""
|
||||
endpoint = "http://example.com/mcp"
|
||||
headers = {}
|
||||
|
||||
protocol_cache.clear()
|
||||
|
||||
connect_error = httpx.ConnectError("Connection refused")
|
||||
|
||||
with patch('llama_stack.providers.utils.tools.mcp.streamablehttp_client') as mock_http_client, \
|
||||
patch('llama_stack.providers.utils.tools.mcp.sse_client') as mock_sse_client:
|
||||
|
||||
mock_http_client.return_value.__aenter__ = AsyncMock(side_effect=connect_error)
|
||||
mock_http_client.return_value.__aexit__ = AsyncMock()
|
||||
mock_sse_client.return_value.__aenter__ = AsyncMock(side_effect=connect_error)
|
||||
mock_sse_client.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
with pytest.raises(ConnectionError, match="Failed to connect to MCP server"):
|
||||
async with client_wrapper(endpoint, headers):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_error_handling(self):
|
||||
"""Test handling of timeout errors."""
|
||||
endpoint = "http://example.com/mcp"
|
||||
headers = {}
|
||||
|
||||
protocol_cache.clear()
|
||||
|
||||
timeout_error = httpx.TimeoutException("Request timeout")
|
||||
|
||||
with patch('llama_stack.providers.utils.tools.mcp.streamablehttp_client') as mock_http_client, \
|
||||
patch('llama_stack.providers.utils.tools.mcp.sse_client') as mock_sse_client:
|
||||
|
||||
mock_http_client.return_value.__aenter__ = AsyncMock(side_effect=timeout_error)
|
||||
mock_http_client.return_value.__aexit__ = AsyncMock()
|
||||
mock_sse_client.return_value.__aenter__ = AsyncMock(side_effect=timeout_error)
|
||||
mock_sse_client.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
with pytest.raises(TimeoutError, match="MCP server.*timed out"):
|
||||
async with client_wrapper(endpoint, headers):
|
||||
pass
|
||||
|
||||
|
||||
class TestListMcpTools:
|
||||
"""Test cases for list_mcp_tools function."""
|
||||
|
||||
|
@ -521,117 +382,3 @@ class TestListMcpTools:
|
|||
|
||||
assert isinstance(result, ListToolDefsResponse)
|
||||
assert len(result.data) == 0
|
||||
|
||||
|
||||
class TestInvokeMcpTool:
|
||||
"""Test cases for invoke_mcp_tool function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_tool_success_with_text_content(self):
|
||||
"""Test successful tool invocation with text content."""
|
||||
endpoint = "http://example.com/mcp"
|
||||
headers = {}
|
||||
tool_name = "test_tool"
|
||||
kwargs = {"param1": "value1", "param2": 42}
|
||||
|
||||
# Mock MCP text content
|
||||
mock_text_content = mcp_types.TextContent(type="text", text="Tool output text")
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.content = [mock_text_content]
|
||||
mock_result.isError = False
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.call_tool = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper:
|
||||
mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_wrapper.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
result = await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)
|
||||
|
||||
assert isinstance(result, ToolInvocationResult)
|
||||
assert result.error_code == 0
|
||||
assert len(result.content) == 1
|
||||
|
||||
content_item = result.content[0]
|
||||
assert isinstance(content_item, TextContentItem)
|
||||
assert content_item.text == "Tool output text"
|
||||
|
||||
mock_session.call_tool.assert_called_once_with(tool_name, kwargs)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_tool_with_error(self):
|
||||
"""Test tool invocation when tool returns an error."""
|
||||
endpoint = "http://example.com/mcp"
|
||||
headers = {}
|
||||
tool_name = "error_tool"
|
||||
kwargs = {}
|
||||
|
||||
mock_text_content = mcp_types.TextContent(type="text", text="Error occurred")
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.content = [mock_text_content]
|
||||
mock_result.isError = True
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.call_tool = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper:
|
||||
mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_wrapper.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
result = await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)
|
||||
|
||||
assert isinstance(result, ToolInvocationResult)
|
||||
assert result.error_code == 1
|
||||
assert len(result.content) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_tool_with_embedded_resource_warning(self):
|
||||
"""Test tool invocation with unsupported EmbeddedResource content."""
|
||||
endpoint = "http://example.com/mcp"
|
||||
headers = {}
|
||||
tool_name = "resource_tool"
|
||||
kwargs = {}
|
||||
|
||||
# Mock MCP embedded resource content
|
||||
mock_embedded_resource = mcp_types.EmbeddedResource(
|
||||
type="resource",
|
||||
resource={
|
||||
"uri": "file:///example.txt",
|
||||
"text": "Resource content"
|
||||
}
|
||||
)
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.content = [mock_embedded_resource]
|
||||
mock_result.isError = False
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.call_tool = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper, \
|
||||
patch('llama_stack.providers.utils.tools.mcp.logger') as mock_logger:
|
||||
|
||||
mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_wrapper.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
result = await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)
|
||||
|
||||
assert isinstance(result, ToolInvocationResult)
|
||||
assert result.error_code == 0
|
||||
assert len(result.content) == 0 # EmbeddedResource is skipped
|
||||
|
||||
# Should log a warning
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "EmbeddedResource is not supported" in str(mock_logger.warning.call_args)
|
||||
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_protocol_cache():
|
||||
"""Clear protocol cache before each test."""
|
||||
protocol_cache.clear()
|
||||
yield
|
||||
protocol_cache.clear()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue