diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py index b86f057bd..ac5e34783 100644 --- a/llama_stack/distribution/routing_tables/toolgroups.py +++ b/llama_stack/distribution/routing_tables/toolgroups.py @@ -52,9 +52,14 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): all_tools = [] for toolgroup in toolgroups: - if toolgroup.identifier not in self.toolgroups_to_tools: - await self._index_tools(toolgroup) - all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier]) + try: + if toolgroup.identifier not in self.toolgroups_to_tools: + await self._index_tools(toolgroup) + all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier]) + except Exception as e: + logger.warning(f"Error listing tools for toolgroup {toolgroup.identifier}: {e}") + logger.debug(e, exc_info=True) + continue return ListToolsResponse(data=all_tools) diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 0eeb68167..ec158de63 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -10,6 +10,7 @@ from unittest.mock import AsyncMock import pytest +from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datatypes import Api @@ -296,3 +297,26 @@ async def test_tool_groups_routing_table(cached_disk_dist_registry): await table.unregister_toolgroup(toolgroup_id="test-toolgroup") tool_groups = await table.list_tool_groups() assert len(tool_groups.data) == 0 + + +@pytest.mark.asyncio +async def test_tool_groups_routing_table_exception_handling(cached_disk_dist_registry): + """Test that the tool group routing table handles exceptions when listing tools, like if an MCP server is unreachable.""" + + exception_throwing_tool_groups_impl = ToolGroupsImpl() + exception_throwing_tool_groups_impl.list_runtime_tools = AsyncMock(side_effect=Exception("Test exception")) + + table = ToolGroupsRoutingTable( + {"test_provider": exception_throwing_tool_groups_impl}, cached_disk_dist_registry, {} + ) + await table.initialize() + + await table.register_tool_group( + toolgroup_id="test-toolgroup-exceptions", + provider_id="test_provider", + mcp_endpoint=URL(uri="http://localhost:8479/foo/bar"), + ) + + tools = await table.list_tools(toolgroup_id="test-toolgroup-exceptions") + + assert len(tools.data) == 0