mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
several fixes
This commit is contained in:
parent
bf8a73e09a
commit
cddc1f3524
15 changed files with 95 additions and 83 deletions
|
@ -47,7 +47,7 @@ from llama_stack.providers.datatypes import (
|
||||||
RemoteProviderSpec,
|
RemoteProviderSpec,
|
||||||
ScoringFunctionsProtocolPrivate,
|
ScoringFunctionsProtocolPrivate,
|
||||||
ShieldsProtocolPrivate,
|
ShieldsProtocolPrivate,
|
||||||
ToolsProtocolPrivate,
|
ToolGroupsProtocolPrivate,
|
||||||
VectorDBsProtocolPrivate,
|
VectorDBsProtocolPrivate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -93,7 +93,7 @@ def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
|
||||||
def additional_protocols_map() -> dict[Api, Any]:
|
def additional_protocols_map() -> dict[Api, Any]:
|
||||||
return {
|
return {
|
||||||
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
||||||
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
|
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
|
||||||
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
|
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
|
||||||
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
||||||
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
|
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
|
||||||
|
|
|
@ -11,7 +11,7 @@ from llama_stack.apis.common.content_types import (
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolsResponse,
|
||||||
RAGDocument,
|
RAGDocument,
|
||||||
RAGQueryConfig,
|
RAGQueryConfig,
|
||||||
RAGQueryResult,
|
RAGQueryResult,
|
||||||
|
@ -87,6 +87,6 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||||
) -> ListToolDefsResponse:
|
) -> ListToolsResponse:
|
||||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
|
return await self.routing_table.list_tools(tool_group_id)
|
||||||
|
|
|
@ -46,7 +46,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
||||||
elif api == Api.eval:
|
elif api == Api.eval:
|
||||||
return await p.register_benchmark(obj)
|
return await p.register_benchmark(obj)
|
||||||
elif api == Api.tool_runtime:
|
elif api == Api.tool_runtime:
|
||||||
return await p.register_tool(obj)
|
return await p.register_toolgroup(obj)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||||
elif api == Api.datasetio:
|
elif api == Api.datasetio:
|
||||||
return await p.unregister_dataset(obj.identifier)
|
return await p.unregister_dataset(obj.identifier)
|
||||||
elif api == Api.tool_runtime:
|
elif api == Api.tool_runtime:
|
||||||
return await p.unregister_tool(obj.identifier)
|
return await p.unregister_toolgroup(obj.identifier)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unregister not supported for {api}")
|
raise ValueError(f"Unregister not supported for {api}")
|
||||||
|
|
||||||
|
@ -136,7 +136,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
elif isinstance(self, BenchmarksRoutingTable):
|
elif isinstance(self, BenchmarksRoutingTable):
|
||||||
return ("Eval", "benchmark")
|
return ("Eval", "benchmark")
|
||||||
elif isinstance(self, ToolGroupsRoutingTable):
|
elif isinstance(self, ToolGroupsRoutingTable):
|
||||||
return ("Tools", "tool")
|
return ("ToolGroups", "tool_group")
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown routing table type")
|
raise ValueError("Unknown routing table type")
|
||||||
|
|
||||||
|
|
|
@ -24,9 +24,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||||
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
|
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
|
||||||
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
|
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
|
||||||
tool_name = routing_key
|
if routing_key in self.tool_to_toolgroup:
|
||||||
if tool_name in self.tool_to_toolgroup:
|
routing_key = self.tool_to_toolgroup[routing_key]
|
||||||
routing_key = self.tool_to_toolgroup[tool_name]
|
|
||||||
return super().get_provider_impl(routing_key, provider_id)
|
return super().get_provider_impl(routing_key, provider_id)
|
||||||
|
|
||||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||||
|
@ -39,11 +38,26 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
for toolgroup in toolgroups:
|
for toolgroup in toolgroups:
|
||||||
group_id = toolgroup.identifier
|
group_id = toolgroup.identifier
|
||||||
if group_id not in self.toolgroups_to_tools:
|
if group_id not in self.toolgroups_to_tools:
|
||||||
provider_impl = self.get_provider_impl(toolgroup.provider_id)
|
provider_impl = super().get_provider_impl(group_id, toolgroup.provider_id)
|
||||||
tools = await provider_impl.list_runtime_tools(group_id, toolgroup.mcp_endpoint)
|
tooldefs_response = await provider_impl.list_runtime_tools(group_id, toolgroup.mcp_endpoint)
|
||||||
|
|
||||||
self.toolgroups_to_tools[group_id] = tools.data
|
# TODO: kill this Tool vs ToolDef distinction
|
||||||
for tool in tools.data:
|
tooldefs = tooldefs_response.data
|
||||||
|
tools = []
|
||||||
|
for t in tooldefs:
|
||||||
|
tools.append(
|
||||||
|
Tool(
|
||||||
|
identifier=t.name,
|
||||||
|
toolgroup_id=group_id,
|
||||||
|
description=t.description or "",
|
||||||
|
parameters=t.parameters or [],
|
||||||
|
metadata=t.metadata,
|
||||||
|
provider_id=toolgroup.provider_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.toolgroups_to_tools[group_id] = tools
|
||||||
|
for tool in tools:
|
||||||
self.tool_to_toolgroup[tool.identifier] = group_id
|
self.tool_to_toolgroup[tool.identifier] = group_id
|
||||||
all_tools.extend(self.toolgroups_to_tools[group_id])
|
all_tools.extend(self.toolgroups_to_tools[group_id])
|
||||||
|
|
||||||
|
@ -74,15 +88,15 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
mcp_endpoint: URL | None = None,
|
mcp_endpoint: URL | None = None,
|
||||||
args: dict[str, Any] | None = None,
|
args: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
await self.dist_registry.register(
|
toolgroup = ToolGroupWithACL(
|
||||||
ToolGroupWithACL(
|
identifier=toolgroup_id,
|
||||||
identifier=toolgroup_id,
|
provider_id=provider_id,
|
||||||
provider_id=provider_id,
|
provider_resource_id=toolgroup_id,
|
||||||
provider_resource_id=toolgroup_id,
|
mcp_endpoint=mcp_endpoint,
|
||||||
mcp_endpoint=mcp_endpoint,
|
args=args,
|
||||||
args=args,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
await self.register_object(toolgroup)
|
||||||
|
return toolgroup
|
||||||
|
|
||||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
tool_group = await self.get_tool_group(toolgroup_id)
|
tool_group = await self.get_tool_group(toolgroup_id)
|
||||||
|
|
|
@ -16,7 +16,7 @@ from llama_stack.apis.datatypes import Api
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
from llama_stack.apis.scoring_functions import ScoringFn
|
from llama_stack.apis.scoring_functions import ScoringFn
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
from llama_stack.apis.tools import Tool
|
from llama_stack.apis.tools import ToolGroup
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
@ -74,10 +74,10 @@ class BenchmarksProtocolPrivate(Protocol):
|
||||||
async def register_benchmark(self, benchmark: Benchmark) -> None: ...
|
async def register_benchmark(self, benchmark: Benchmark) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class ToolsProtocolPrivate(Protocol):
|
class ToolGroupsProtocolPrivate(Protocol):
|
||||||
async def register_tool(self, tool: Tool) -> None: ...
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None: ...
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None: ...
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -25,14 +25,14 @@ from llama_stack.apis.tools import (
|
||||||
RAGQueryConfig,
|
RAGQueryConfig,
|
||||||
RAGQueryResult,
|
RAGQueryResult,
|
||||||
RAGToolRuntime,
|
RAGToolRuntime,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
content_from_doc,
|
content_from_doc,
|
||||||
|
@ -49,7 +49,7 @@ def make_random_string(length: int = 8):
|
||||||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
||||||
|
|
||||||
|
|
||||||
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: RagToolRuntimeConfig,
|
config: RagToolRuntimeConfig,
|
||||||
|
@ -66,10 +66,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||||
async def shutdown(self):
|
async def shutdown(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
async def insert(
|
async def insert(
|
||||||
|
|
|
@ -12,19 +12,19 @@ import httpx
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
|
|
||||||
from .config import BingSearchToolConfig
|
from .config import BingSearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: BingSearchToolConfig):
|
def __init__(self, config: BingSearchToolConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.url = "https://api.bing.microsoft.com/v7.0/search"
|
self.url = "https://api.bing.microsoft.com/v7.0/search"
|
||||||
|
@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
|
|
|
@ -11,30 +11,30 @@ import httpx
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
|
|
||||||
from .config import BraveSearchToolConfig
|
from .config import BraveSearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: BraveSearchToolConfig):
|
def __init__(self, config: BraveSearchToolConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
|
|
|
@ -11,12 +11,13 @@ from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.datatypes import Api
|
from llama_stack.apis.datatypes import Api
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
|
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
|
||||||
|
|
||||||
from .config import MCPProviderConfig
|
from .config import MCPProviderConfig
|
||||||
|
@ -24,13 +25,19 @@ from .config import MCPProviderConfig
|
||||||
logger = get_logger(__name__, category="tools")
|
logger = get_logger(__name__, category="tools")
|
||||||
|
|
||||||
|
|
||||||
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
|
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||||
) -> ListToolDefsResponse:
|
) -> ListToolDefsResponse:
|
||||||
|
|
|
@ -12,29 +12,29 @@ import httpx
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
|
|
||||||
from .config import TavilySearchToolConfig
|
from .config import TavilySearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: TavilySearchToolConfig):
|
def __init__(self, config: TavilySearchToolConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
|
|
|
@ -12,19 +12,19 @@ import httpx
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
|
|
||||||
from .config import WolframAlphaToolConfig
|
from .config import WolframAlphaToolConfig
|
||||||
|
|
||||||
|
|
||||||
class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: WolframAlphaToolConfig):
|
def __init__(self, config: WolframAlphaToolConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.url = "https://api.wolframalpha.com/v2/query"
|
self.url = "https://api.wolframalpha.com/v2/query"
|
||||||
|
@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
|
|
|
@ -25,10 +25,12 @@ def test_web_search_tool(llama_stack_client, sample_search_query):
|
||||||
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
||||||
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
||||||
|
|
||||||
|
tools = llama_stack_client.tool_runtime.list_tools()
|
||||||
|
assert any(tool.identifier == "web_search" for tool in tools)
|
||||||
|
|
||||||
response = llama_stack_client.tool_runtime.invoke_tool(
|
response = llama_stack_client.tool_runtime.invoke_tool(
|
||||||
tool_name="web_search", kwargs={"query": sample_search_query}
|
tool_name="web_search", kwargs={"query": sample_search_query}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify the response
|
# Verify the response
|
||||||
assert response.content is not None
|
assert response.content is not None
|
||||||
assert len(response.content) > 0
|
assert len(response.content) > 0
|
||||||
|
@ -49,11 +51,12 @@ def test_wolfram_alpha_tool(llama_stack_client, sample_wolfram_alpha_query):
|
||||||
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
|
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
|
||||||
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
|
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
|
||||||
|
|
||||||
|
tools = llama_stack_client.tool_runtime.list_tools()
|
||||||
|
assert any(tool.identifier == "wolfram_alpha" for tool in tools)
|
||||||
response = llama_stack_client.tool_runtime.invoke_tool(
|
response = llama_stack_client.tool_runtime.invoke_tool(
|
||||||
tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
|
tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
|
||||||
)
|
)
|
||||||
|
|
||||||
print(response.content)
|
|
||||||
assert response.content is not None
|
assert response.content is not None
|
||||||
assert len(response.content) > 0
|
assert len(response.content) > 0
|
||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
|
|
|
@ -31,13 +31,12 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
|
||||||
test_toolgroup_id = MCP_TOOLGROUP_ID
|
test_toolgroup_id = MCP_TOOLGROUP_ID
|
||||||
uri = mcp_server["server_url"]
|
uri = mcp_server["server_url"]
|
||||||
|
|
||||||
# registering itself should fail since it requires listing tools
|
# registering should not raise an error anymore even if you don't specify the auth token
|
||||||
with pytest.raises(Exception, match="Unauthorized"):
|
llama_stack_client.toolgroups.register(
|
||||||
llama_stack_client.toolgroups.register(
|
toolgroup_id=test_toolgroup_id,
|
||||||
toolgroup_id=test_toolgroup_id,
|
provider_id="model-context-protocol",
|
||||||
provider_id="model-context-protocol",
|
mcp_endpoint=dict(uri=uri),
|
||||||
mcp_endpoint=dict(uri=uri),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
provider_data = {
|
provider_data = {
|
||||||
"mcp_headers": {
|
"mcp_headers": {
|
||||||
|
@ -50,18 +49,9 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
|
||||||
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
|
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
with pytest.raises(Exception, match="Unauthorized"):
|
||||||
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers)
|
llama_stack_client.tools.list()
|
||||||
except Exception as e:
|
|
||||||
# An error is OK since the toolgroup may not exist
|
|
||||||
print(f"Error unregistering toolgroup: {e}")
|
|
||||||
|
|
||||||
llama_stack_client.toolgroups.register(
|
|
||||||
toolgroup_id=test_toolgroup_id,
|
|
||||||
provider_id="model-context-protocol",
|
|
||||||
mcp_endpoint=dict(uri=uri),
|
|
||||||
extra_headers=auth_headers,
|
|
||||||
)
|
|
||||||
response = llama_stack_client.tools.list(
|
response = llama_stack_client.tools.list(
|
||||||
toolgroup_id=test_toolgroup_id,
|
toolgroup_id=test_toolgroup_id,
|
||||||
extra_headers=auth_headers,
|
extra_headers=auth_headers,
|
||||||
|
|
|
@ -51,7 +51,5 @@ def test_register_and_unregister_toolgroup(llama_stack_client):
|
||||||
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
|
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
|
||||||
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
|
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
|
||||||
|
|
||||||
# Verify tools are also unregistered
|
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
|
||||||
unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
|
llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
|
||||||
assert isinstance(unregister_tools_list_response, list)
|
|
||||||
assert not unregister_tools_list_response
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataS
|
||||||
from llama_stack.apis.datatypes import Api
|
from llama_stack.apis.datatypes import Api
|
||||||
from llama_stack.apis.models.models import Model, ModelType
|
from llama_stack.apis.models.models import Model, ModelType
|
||||||
from llama_stack.apis.shields.shields import Shield
|
from llama_stack.apis.shields.shields import Shield
|
||||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter
|
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
|
||||||
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
|
||||||
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
|
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
|
||||||
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
|
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
|
||||||
|
@ -101,11 +101,11 @@ class ToolGroupsImpl(Impl):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(Api.tool_runtime)
|
super().__init__(Api.tool_runtime)
|
||||||
|
|
||||||
async def register_tool(self, tool):
|
async def register_toolgroup(self, toolgroup: ToolGroup):
|
||||||
return tool
|
return toolgroup
|
||||||
|
|
||||||
async def unregister_tool(self, tool_name: str):
|
async def unregister_toolgroup(self, toolgroup_id: str):
|
||||||
return tool_name
|
return toolgroup_id
|
||||||
|
|
||||||
async def list_runtime_tools(self, toolgroup_id, mcp_endpoint):
|
async def list_runtime_tools(self, toolgroup_id, mcp_endpoint):
|
||||||
return ListToolDefsResponse(
|
return ListToolDefsResponse(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue