forked from phoenix-oss/llama-stack-mirror
fix(tools): do not index tools, only index toolgroups (#2261)
When registering a MCP endpoint, we cannot list tools (like we used to) since the MCP endpoint may be behind an auth wall. Registration can happen much sooner (via run.yaml). Instead, we do listing only when the _user_ actually calls listing. Furthermore, we cache the list in-memory in the server. Currently, the cache is not invalidated -- we may want to periodically re-list for MCP servers. Note that they must call `list_tools` before calling `invoke_tool` -- we use this critically. This will enable us to list MCP servers in run.yaml ## Test Plan Existing tests, updated tests accordingly.
This commit is contained in:
parent
5a422e236c
commit
ce33d02443
19 changed files with 131 additions and 153 deletions
|
@ -12,19 +12,19 @@ import httpx
|
|||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
|
||||
from .config import BingSearchToolConfig
|
||||
|
||||
|
||||
class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
def __init__(self, config: BingSearchToolConfig):
|
||||
self.config = config
|
||||
self.url = "https://api.bing.microsoft.com/v7.0/search"
|
||||
|
@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
return
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
|
|
|
@ -11,30 +11,30 @@ import httpx
|
|||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
|
||||
from .config import BraveSearchToolConfig
|
||||
|
||||
|
||||
class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
def __init__(self, config: BraveSearchToolConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
return
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
|
|
|
@ -11,12 +11,13 @@ from llama_stack.apis.common.content_types import URL
|
|||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
|
||||
|
||||
from .config import MCPProviderConfig
|
||||
|
@ -24,13 +25,19 @@ from .config import MCPProviderConfig
|
|||
logger = get_logger(__name__, category="tools")
|
||||
|
||||
|
||||
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
return
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
) -> ListToolDefsResponse:
|
||||
|
|
|
@ -12,29 +12,29 @@ import httpx
|
|||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
|
||||
from .config import TavilySearchToolConfig
|
||||
|
||||
|
||||
class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
def __init__(self, config: TavilySearchToolConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
return
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
|
|
|
@ -12,19 +12,19 @@ import httpx
|
|||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
|
||||
from .config import WolframAlphaToolConfig
|
||||
|
||||
|
||||
class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
def __init__(self, config: WolframAlphaToolConfig):
|
||||
self.config = config
|
||||
self.url = "https://api.wolframalpha.com/v2/query"
|
||||
|
@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
return
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue