diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py index b86f057bd..6a2e5a860 100644 --- a/llama_stack/distribution/routing_tables/toolgroups.py +++ b/llama_stack/distribution/routing_tables/toolgroups.py @@ -8,7 +8,7 @@ from typing import Any from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups -from llama_stack.distribution.datatypes import ToolGroupWithOwner +from llama_stack.distribution.datatypes import AuthenticationRequiredError, ToolGroupWithOwner from llama_stack.log import get_logger from .common import CommonRoutingTableImpl @@ -53,7 +53,18 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): all_tools = [] for toolgroup in toolgroups: if toolgroup.identifier not in self.toolgroups_to_tools: - await self._index_tools(toolgroup) + try: + await self._index_tools(toolgroup) + except AuthenticationRequiredError: + # Send authentication errors back to the client so it knows + # that it needs to supply credentials for remote MCP servers. + raise + except Exception as e: + # Other errors that the client cannot fix are logged and + # those specific toolgroups are skipped. + logger.warning(f"Error listing tools for toolgroup {toolgroup.identifier}: {e}") + logger.debug(e, exc_info=True) + continue all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier]) return ListToolsResponse(data=all_tools) diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 0eeb68167..ec158de63 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -10,6 +10,7 @@ from unittest.mock import AsyncMock import pytest +from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datatypes import Api @@ -296,3 +297,26 @@ async def test_tool_groups_routing_table(cached_disk_dist_registry): await table.unregister_toolgroup(toolgroup_id="test-toolgroup") tool_groups = await table.list_tool_groups() assert len(tool_groups.data) == 0 + + +@pytest.mark.asyncio +async def test_tool_groups_routing_table_exception_handling(cached_disk_dist_registry): + """Test that the tool group routing table handles exceptions when listing tools, like if an MCP server is unreachable.""" + + exception_throwing_tool_groups_impl = ToolGroupsImpl() + exception_throwing_tool_groups_impl.list_runtime_tools = AsyncMock(side_effect=Exception("Test exception")) + + table = ToolGroupsRoutingTable( + {"test_provider": exception_throwing_tool_groups_impl}, cached_disk_dist_registry, {} + ) + await table.initialize() + + await table.register_tool_group( + toolgroup_id="test-toolgroup-exceptions", + provider_id="test_provider", + mcp_endpoint=URL(uri="http://localhost:8479/foo/bar"), + ) + + tools = await table.list_tools(toolgroup_id="test-toolgroup-exceptions") + + assert len(tools.data) == 0