This commit is contained in:
Ben Browning 2025-07-11 16:26:08 +02:00 committed by GitHub
commit 9fb3dd1b1b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 37 additions and 2 deletions

View file

@ -8,7 +8,7 @@ from typing import Any
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
from llama_stack.distribution.datatypes import ToolGroupWithOwner
from llama_stack.distribution.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
@ -53,7 +53,18 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
all_tools = []
for toolgroup in toolgroups:
if toolgroup.identifier not in self.toolgroups_to_tools:
await self._index_tools(toolgroup)
try:
await self._index_tools(toolgroup)
except AuthenticationRequiredError:
# Send authentication errors back to the client so it knows
# that it needs to supply credentials for remote MCP servers.
raise
except Exception as e:
# Other errors that the client cannot fix are logged and
# those specific toolgroups are skipped.
logger.warning(f"Error listing tools for toolgroup {toolgroup.identifier}: {e}")
logger.debug(e, exc_info=True)
continue
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
return ListToolsResponse(data=all_tools)

View file

@ -10,6 +10,7 @@ from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
@ -296,3 +297,26 @@ async def test_tool_groups_routing_table(cached_disk_dist_registry):
await table.unregister_toolgroup(toolgroup_id="test-toolgroup")
tool_groups = await table.list_tool_groups()
assert len(tool_groups.data) == 0
@pytest.mark.asyncio
async def test_tool_groups_routing_table_exception_handling(cached_disk_dist_registry):
"""Test that the tool group routing table handles exceptions when listing tools, like if an MCP server is unreachable."""
exception_throwing_tool_groups_impl = ToolGroupsImpl()
exception_throwing_tool_groups_impl.list_runtime_tools = AsyncMock(side_effect=Exception("Test exception"))
table = ToolGroupsRoutingTable(
{"test_provider": exception_throwing_tool_groups_impl}, cached_disk_dist_registry, {}
)
await table.initialize()
await table.register_tool_group(
toolgroup_id="test-toolgroup-exceptions",
provider_id="test_provider",
mcp_endpoint=URL(uri="http://localhost:8479/foo/bar"),
)
tools = await table.list_tools(toolgroup_id="test-toolgroup-exceptions")
assert len(tools.data) == 0