mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Merge c5f65b08bb
into 2ebc172f33
This commit is contained in:
commit
9fb3dd1b1b
2 changed files with 37 additions and 2 deletions
|
@ -8,7 +8,7 @@ from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||||
from llama_stack.distribution.datatypes import ToolGroupWithOwner
|
from llama_stack.distribution.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl
|
from .common import CommonRoutingTableImpl
|
||||||
|
@ -53,7 +53,18 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
all_tools = []
|
all_tools = []
|
||||||
for toolgroup in toolgroups:
|
for toolgroup in toolgroups:
|
||||||
if toolgroup.identifier not in self.toolgroups_to_tools:
|
if toolgroup.identifier not in self.toolgroups_to_tools:
|
||||||
await self._index_tools(toolgroup)
|
try:
|
||||||
|
await self._index_tools(toolgroup)
|
||||||
|
except AuthenticationRequiredError:
|
||||||
|
# Send authentication errors back to the client so it knows
|
||||||
|
# that it needs to supply credentials for remote MCP servers.
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# Other errors that the client cannot fix are logged and
|
||||||
|
# those specific toolgroups are skipped.
|
||||||
|
logger.warning(f"Error listing tools for toolgroup {toolgroup.identifier}: {e}")
|
||||||
|
logger.debug(e, exc_info=True)
|
||||||
|
continue
|
||||||
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
|
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
|
||||||
|
|
||||||
return ListToolsResponse(data=all_tools)
|
return ListToolsResponse(data=all_tools)
|
||||||
|
|
|
@ -10,6 +10,7 @@ from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.common.type_system import NumberType
|
from llama_stack.apis.common.type_system import NumberType
|
||||||
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
||||||
from llama_stack.apis.datatypes import Api
|
from llama_stack.apis.datatypes import Api
|
||||||
|
@ -296,3 +297,26 @@ async def test_tool_groups_routing_table(cached_disk_dist_registry):
|
||||||
await table.unregister_toolgroup(toolgroup_id="test-toolgroup")
|
await table.unregister_toolgroup(toolgroup_id="test-toolgroup")
|
||||||
tool_groups = await table.list_tool_groups()
|
tool_groups = await table.list_tool_groups()
|
||||||
assert len(tool_groups.data) == 0
|
assert len(tool_groups.data) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_groups_routing_table_exception_handling(cached_disk_dist_registry):
|
||||||
|
"""Test that the tool group routing table handles exceptions when listing tools, like if an MCP server is unreachable."""
|
||||||
|
|
||||||
|
exception_throwing_tool_groups_impl = ToolGroupsImpl()
|
||||||
|
exception_throwing_tool_groups_impl.list_runtime_tools = AsyncMock(side_effect=Exception("Test exception"))
|
||||||
|
|
||||||
|
table = ToolGroupsRoutingTable(
|
||||||
|
{"test_provider": exception_throwing_tool_groups_impl}, cached_disk_dist_registry, {}
|
||||||
|
)
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
await table.register_tool_group(
|
||||||
|
toolgroup_id="test-toolgroup-exceptions",
|
||||||
|
provider_id="test_provider",
|
||||||
|
mcp_endpoint=URL(uri="http://localhost:8479/foo/bar"),
|
||||||
|
)
|
||||||
|
|
||||||
|
tools = await table.list_tools(toolgroup_id="test-toolgroup-exceptions")
|
||||||
|
|
||||||
|
assert len(tools.data) == 0
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue