mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
simplify toolgroups registration
This commit is contained in:
parent
ba242c04cc
commit
f9a98c278a
15 changed files with 350 additions and 256 deletions
|
@ -137,15 +137,15 @@ class Session(BaseModel):
|
|||
memory_bank: Optional[MemoryBank] = None
|
||||
|
||||
|
||||
class AgentToolWithArgs(BaseModel):
|
||||
class AgentToolGroupWithArgs(BaseModel):
|
||||
name: str
|
||||
args: Dict[str, Any]
|
||||
|
||||
|
||||
AgentTool = register_schema(
|
||||
AgentToolGroup = register_schema(
|
||||
Union[
|
||||
str,
|
||||
AgentToolWithArgs,
|
||||
AgentToolGroupWithArgs,
|
||||
],
|
||||
name="AgentTool",
|
||||
)
|
||||
|
@ -156,7 +156,7 @@ class AgentConfigCommon(BaseModel):
|
|||
|
||||
input_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
tools: Optional[List[AgentTool]] = Field(default_factory=list)
|
||||
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
|
||||
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
|
||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
|
@ -278,7 +278,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
|||
]
|
||||
|
||||
documents: Optional[List[Document]] = None
|
||||
tools: Optional[List[AgentTool]] = None
|
||||
toolgroups: Optional[List[AgentToolGroup]] = None
|
||||
|
||||
stream: Optional[bool] = False
|
||||
|
||||
|
@ -317,7 +317,7 @@ class Agents(Protocol):
|
|||
],
|
||||
stream: Optional[bool] = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
tools: Optional[List[AgentTool]] = None,
|
||||
tools: Optional[List[AgentToolGroup]] = None,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||
|
||||
@webmethod(route="/agents/turn/get")
|
||||
|
|
|
@ -5,10 +5,10 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import BuiltinTool, ToolPromptFormat
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
||||
|
@ -22,7 +22,7 @@ class ToolParameter(BaseModel):
|
|||
name: str
|
||||
parameter_type: str
|
||||
description: str
|
||||
required: bool
|
||||
required: bool = Field(default=True)
|
||||
default: Optional[Any] = None
|
||||
|
||||
|
||||
|
@ -36,7 +36,7 @@ class ToolHost(Enum):
|
|||
@json_schema_type
|
||||
class Tool(Resource):
|
||||
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
|
||||
tool_group: str
|
||||
toolgroup_id: str
|
||||
tool_host: ToolHost
|
||||
description: str
|
||||
parameters: List[ToolParameter]
|
||||
|
@ -58,41 +58,19 @@ class ToolDef(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MCPToolGroupDef(BaseModel):
|
||||
"""
|
||||
A tool group that is defined by in a model context protocol server.
|
||||
Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
|
||||
"""
|
||||
|
||||
type: Literal["model_context_protocol"] = "model_context_protocol"
|
||||
endpoint: URL
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UserDefinedToolGroupDef(BaseModel):
|
||||
type: Literal["user_defined"] = "user_defined"
|
||||
tools: List[ToolDef]
|
||||
|
||||
|
||||
ToolGroupDef = register_schema(
|
||||
Annotated[
|
||||
Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
|
||||
],
|
||||
name="ToolGroupDef",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolGroupInput(BaseModel):
|
||||
tool_group_id: str
|
||||
tool_group_def: ToolGroupDef
|
||||
provider_id: Optional[str] = None
|
||||
toolgroup_id: str
|
||||
provider_id: str
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
mcp_endpoint: Optional[URL] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolGroup(Resource):
|
||||
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
|
||||
mcp_endpoint: Optional[URL] = None
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -104,6 +82,7 @@ class ToolInvocationResult(BaseModel):
|
|||
|
||||
class ToolStore(Protocol):
|
||||
def get_tool(self, tool_name: str) -> Tool: ...
|
||||
def get_tool_group(self, tool_group_id: str) -> ToolGroup: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
@ -112,9 +91,10 @@ class ToolGroups(Protocol):
|
|||
@webmethod(route="/toolgroups/register", method="POST")
|
||||
async def register_tool_group(
|
||||
self,
|
||||
tool_group_id: str,
|
||||
tool_group_def: ToolGroupDef,
|
||||
provider_id: Optional[str] = None,
|
||||
toolgroup_id: str,
|
||||
provider_id: str,
|
||||
mcp_endpoint: Optional[URL] = None,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Register a tool group"""
|
||||
...
|
||||
|
@ -122,7 +102,7 @@ class ToolGroups(Protocol):
|
|||
@webmethod(route="/toolgroups/get", method="GET")
|
||||
async def get_tool_group(
|
||||
self,
|
||||
tool_group_id: str,
|
||||
toolgroup_id: str,
|
||||
) -> ToolGroup: ...
|
||||
|
||||
@webmethod(route="/toolgroups/list", method="GET")
|
||||
|
@ -149,8 +129,10 @@ class ToolGroups(Protocol):
|
|||
class ToolRuntime(Protocol):
|
||||
tool_store: ToolStore
|
||||
|
||||
@webmethod(route="/tool-runtime/discover", method="POST")
|
||||
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ...
|
||||
@webmethod(route="/tool-runtime/list-tools", method="POST")
|
||||
async def list_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]: ...
|
||||
|
||||
@webmethod(route="/tool-runtime/invoke", method="POST")
|
||||
async def invoke_tool(
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, URL
|
||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.eval import (
|
||||
AppEvalTaskConfig,
|
||||
|
@ -38,7 +38,7 @@ from llama_stack.apis.scoring import (
|
|||
ScoringFnParams,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolRuntime
|
||||
from llama_stack.apis.tools import ToolDef, ToolRuntime
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
|
||||
|
@ -417,7 +417,9 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
args=args,
|
||||
)
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
|
||||
return await self.routing_table.get_provider_impl(
|
||||
tool_group.name
|
||||
).discover_tools(tool_group)
|
||||
async def list_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
|
||||
tool_group_id, mcp_endpoint
|
||||
)
|
||||
|
|
|
@ -26,15 +26,7 @@ from llama_stack.apis.scoring_functions import (
|
|||
ScoringFunctions,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield, Shields
|
||||
from llama_stack.apis.tools import (
|
||||
MCPToolGroupDef,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
ToolGroupDef,
|
||||
ToolGroups,
|
||||
ToolHost,
|
||||
UserDefinedToolGroupDef,
|
||||
)
|
||||
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost
|
||||
from llama_stack.distribution.datatypes import (
|
||||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
|
@ -496,51 +488,38 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
|
||||
tools = await self.get_all_with_type("tool")
|
||||
if tool_group_id:
|
||||
tools = [tool for tool in tools if tool.tool_group == tool_group_id]
|
||||
tools = [tool for tool in tools if tool.toolgroup_id == tool_group_id]
|
||||
return tools
|
||||
|
||||
async def list_tool_groups(self) -> List[ToolGroup]:
|
||||
return await self.get_all_with_type("tool_group")
|
||||
|
||||
async def get_tool_group(self, tool_group_id: str) -> ToolGroup:
|
||||
return await self.get_object_by_identifier("tool_group", tool_group_id)
|
||||
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
|
||||
return await self.get_object_by_identifier("tool_group", toolgroup_id)
|
||||
|
||||
async def get_tool(self, tool_name: str) -> Tool:
|
||||
return await self.get_object_by_identifier("tool", tool_name)
|
||||
|
||||
async def register_tool_group(
|
||||
self,
|
||||
tool_group_id: str,
|
||||
tool_group_def: ToolGroupDef,
|
||||
provider_id: Optional[str] = None,
|
||||
toolgroup_id: str,
|
||||
provider_id: str,
|
||||
mcp_endpoint: Optional[URL] = None,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
tools = []
|
||||
tool_defs = []
|
||||
tool_host = ToolHost.distribution
|
||||
if provider_id is None:
|
||||
if len(self.impls_by_provider_id.keys()) > 1:
|
||||
raise ValueError(
|
||||
f"No provider_id specified and multiple providers available. Please specify a provider_id. Available providers: {', '.join(self.impls_by_provider_id.keys())}"
|
||||
)
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
|
||||
# parse tool group to the type if dict
|
||||
tool_group_def = TypeAdapter(ToolGroupDef).validate_python(tool_group_def)
|
||||
if isinstance(tool_group_def, MCPToolGroupDef):
|
||||
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
|
||||
tool_group_def
|
||||
)
|
||||
tool_host = ToolHost.model_context_protocol
|
||||
elif isinstance(tool_group_def, UserDefinedToolGroupDef):
|
||||
tool_defs = tool_group_def.tools
|
||||
else:
|
||||
raise ValueError(f"Unknown tool group: {tool_group_def}")
|
||||
tool_defs = await self.impls_by_provider_id[provider_id].list_tools(
|
||||
toolgroup_id, mcp_endpoint
|
||||
)
|
||||
tool_host = (
|
||||
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
|
||||
)
|
||||
|
||||
for tool_def in tool_defs:
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=tool_def.name,
|
||||
tool_group=tool_group_id,
|
||||
toolgroup_id=toolgroup_id,
|
||||
description=tool_def.description or "",
|
||||
parameters=tool_def.parameters or [],
|
||||
provider_id=provider_id,
|
||||
|
@ -565,9 +544,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
|
||||
await self.dist_registry.register(
|
||||
ToolGroup(
|
||||
identifier=tool_group_id,
|
||||
identifier=toolgroup_id,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=tool_group_id,
|
||||
provider_resource_id=toolgroup_id,
|
||||
mcp_endpoint=mcp_endpoint,
|
||||
args=args,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ import secrets
|
|||
import string
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
@ -21,8 +21,8 @@ from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDe
|
|||
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentTool,
|
||||
AgentToolWithArgs,
|
||||
AgentToolGroup,
|
||||
AgentToolGroupWithArgs,
|
||||
AgentTurnCreateRequest,
|
||||
AgentTurnResponseEvent,
|
||||
AgentTurnResponseEventType,
|
||||
|
@ -76,6 +76,10 @@ def make_random_string(length: int = 8):
|
|||
|
||||
|
||||
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||
MEMORY_TOOL_GROUP_ID = "builtin::memory"
|
||||
MEMORY_QUERY_TOOL = "query_memory"
|
||||
CODE_INTERPRETER_TOOL = "code_interpreter"
|
||||
WEB_SEARCH_TOOL = "web_search"
|
||||
|
||||
|
||||
class ChatAgent(ShieldRunnerMixin):
|
||||
|
@ -192,7 +196,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
documents=request.documents,
|
||||
tools_for_turn=request.tools,
|
||||
toolgroups_for_turn=request.toolgroups,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
log.info(
|
||||
|
@ -243,7 +247,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
tools_for_turn: Optional[List[AgentTool]] = None,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> AsyncGenerator:
|
||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
||||
|
@ -266,7 +270,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params,
|
||||
stream,
|
||||
documents,
|
||||
tools_for_turn,
|
||||
toolgroups_for_turn,
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
@ -362,21 +366,24 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
tools_for_turn: Optional[List[AgentTool]] = None,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> AsyncGenerator:
|
||||
tool_args = {}
|
||||
if tools_for_turn:
|
||||
for tool in tools_for_turn:
|
||||
if isinstance(tool, AgentToolWithArgs):
|
||||
tool_args[tool.name] = tool.args
|
||||
toolgroup_args = {}
|
||||
for toolgroup in self.agent_config.toolgroups:
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
toolgroup_args[toolgroup.name] = toolgroup.args
|
||||
if toolgroups_for_turn:
|
||||
for toolgroup in toolgroups_for_turn:
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
toolgroup_args[toolgroup.name] = toolgroup.args
|
||||
|
||||
tool_defs = await self._get_tool_defs(tools_for_turn)
|
||||
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
||||
if documents:
|
||||
await self.handle_documents(
|
||||
session_id, documents, input_messages, tool_defs
|
||||
)
|
||||
if "memory" in tool_defs and len(input_messages) > 0:
|
||||
with tracing.span("memory_tool") as span:
|
||||
if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
|
||||
with tracing.span(MEMORY_QUERY_TOOL) as span:
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
@ -386,18 +393,16 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
extra_args = tool_args.get("memory", {})
|
||||
tool_args = {
|
||||
# Query memory with the last message's content
|
||||
"query": input_messages[-1],
|
||||
**extra_args,
|
||||
query_args = {
|
||||
"messages": [msg.content for msg in input_messages],
|
||||
**toolgroup_args.get(MEMORY_TOOL_GROUP_ID, {}),
|
||||
}
|
||||
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
# if the session has a memory bank id, let the memory tool use it
|
||||
if session_info.memory_bank_id:
|
||||
tool_args["memory_bank_id"] = session_info.memory_bank_id
|
||||
serialized_args = tracing.serialize_value(tool_args)
|
||||
query_args["memory_bank_id"] = session_info.memory_bank_id
|
||||
serialized_args = tracing.serialize_value(query_args)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
|
@ -415,8 +420,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name="memory",
|
||||
args=tool_args,
|
||||
tool_name=MEMORY_QUERY_TOOL,
|
||||
args=query_args,
|
||||
)
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
@ -485,7 +490,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tools=[
|
||||
tool
|
||||
for tool in tool_defs.values()
|
||||
if tool.tool_name != "memory"
|
||||
if tool_to_group.get(tool.tool_name, None)
|
||||
!= MEMORY_TOOL_GROUP_ID
|
||||
],
|
||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||
stream=True,
|
||||
|
@ -632,6 +638,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self.tool_runtime_api,
|
||||
session_id,
|
||||
[message],
|
||||
toolgroup_args,
|
||||
tool_to_group,
|
||||
)
|
||||
assert (
|
||||
len(result_messages) == 1
|
||||
|
@ -690,26 +698,37 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
n_iter += 1
|
||||
|
||||
async def _get_tool_defs(
|
||||
self, tools_for_turn: Optional[List[AgentTool]]
|
||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]]
|
||||
) -> Dict[str, ToolDefinition]:
|
||||
# Determine which tools to include
|
||||
agent_config_tools = set(
|
||||
tool.name if isinstance(tool, AgentToolWithArgs) else tool
|
||||
for tool in self.agent_config.tools
|
||||
agent_config_toolgroups = set(
|
||||
(
|
||||
toolgroup.name
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs)
|
||||
else toolgroup
|
||||
)
|
||||
for toolgroup in self.agent_config.toolgroups
|
||||
)
|
||||
tools_for_turn_set = (
|
||||
agent_config_tools
|
||||
if tools_for_turn is None
|
||||
toolgroups_for_turn_set = (
|
||||
agent_config_toolgroups
|
||||
if toolgroups_for_turn is None
|
||||
else {
|
||||
tool.name if isinstance(tool, AgentToolWithArgs) else tool
|
||||
for tool in tools_for_turn
|
||||
(
|
||||
toolgroup.name
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs)
|
||||
else toolgroup
|
||||
)
|
||||
for toolgroup in toolgroups_for_turn
|
||||
}
|
||||
)
|
||||
|
||||
ret = {}
|
||||
tool_def_map = {}
|
||||
tool_to_group = {}
|
||||
|
||||
for tool_def in self.agent_config.client_tools:
|
||||
ret[tool_def.name] = ToolDefinition(
|
||||
if tool_def_map.get(tool_def.name, None):
|
||||
raise ValueError(f"Tool {tool_def.name} already exists")
|
||||
tool_def_map[tool_def.name] = ToolDefinition(
|
||||
tool_name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
|
@ -722,41 +741,42 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
|
||||
for tool_name in agent_config_tools:
|
||||
if tool_name not in tools_for_turn_set:
|
||||
tool_to_group[tool_def.name] = "__client_tools__"
|
||||
for toolgroup_name in agent_config_toolgroups:
|
||||
if toolgroup_name not in toolgroups_for_turn_set:
|
||||
continue
|
||||
tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name)
|
||||
for tool_def in tools:
|
||||
if tool_def.built_in_type:
|
||||
if tool_def_map.get(tool_def.built_in_type, None):
|
||||
raise ValueError(
|
||||
f"Tool {tool_def.built_in_type} already exists"
|
||||
)
|
||||
|
||||
tool_def = await self.tool_groups_api.get_tool(tool_name)
|
||||
if tool_def is None:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
|
||||
if tool_def.identifier.startswith("builtin::"):
|
||||
built_in_type = tool_def.identifier[len("builtin::") :]
|
||||
if built_in_type == "web_search":
|
||||
built_in_type = "brave_search"
|
||||
if built_in_type not in BuiltinTool.__members__:
|
||||
raise ValueError(f"Unknown built-in tool: {built_in_type}")
|
||||
ret[built_in_type] = ToolDefinition(
|
||||
tool_name=BuiltinTool(built_in_type)
|
||||
)
|
||||
continue
|
||||
|
||||
ret[tool_def.identifier] = ToolDefinition(
|
||||
tool_name=tool_def.identifier,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
tool_def_map[tool_def.built_in_type] = ToolDefinition(
|
||||
tool_name=tool_def.built_in_type
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.built_in_type] = tool_def.toolgroup_id
|
||||
continue
|
||||
|
||||
return ret
|
||||
if tool_def_map.get(tool_def.identifier, None):
|
||||
raise ValueError(f"Tool {tool_def.identifier} already exists")
|
||||
tool_def_map[tool_def.identifier] = ToolDefinition(
|
||||
tool_name=tool_def.identifier,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
|
||||
|
||||
return tool_def_map, tool_to_group
|
||||
|
||||
async def handle_documents(
|
||||
self,
|
||||
|
@ -765,8 +785,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages: List[Message],
|
||||
tool_defs: Dict[str, ToolDefinition],
|
||||
) -> None:
|
||||
memory_tool = tool_defs.get("memory", None)
|
||||
code_interpreter_tool = tool_defs.get("code_interpreter", None)
|
||||
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
|
||||
code_interpreter_tool = tool_defs.get(CODE_INTERPRETER_TOOL, None)
|
||||
content_items = []
|
||||
url_items = []
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
|
@ -903,7 +923,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
|
||||
|
||||
async def execute_tool_call_maybe(
|
||||
tool_runtime_api: ToolRuntime, session_id: str, messages: List[CompletionMessage]
|
||||
tool_runtime_api: ToolRuntime,
|
||||
session_id: str,
|
||||
messages: List[CompletionMessage],
|
||||
toolgroup_args: Dict[str, Dict[str, Any]],
|
||||
tool_to_group: Dict[str, str],
|
||||
) -> List[ToolResponseMessage]:
|
||||
# While Tools.run interface takes a list of messages,
|
||||
# All tools currently only run on a single message
|
||||
|
@ -915,18 +939,26 @@ async def execute_tool_call_maybe(
|
|||
|
||||
tool_call = message.tool_calls[0]
|
||||
name = tool_call.tool_name
|
||||
group_name = tool_to_group.get(name, None)
|
||||
if group_name is None:
|
||||
raise ValueError(f"Tool {name} not found in any tool group")
|
||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||
tool_call_args = tool_call.arguments
|
||||
tool_call_args.update(toolgroup_args.get(group_name, {}))
|
||||
if isinstance(name, BuiltinTool):
|
||||
if name == BuiltinTool.brave_search:
|
||||
name = "builtin::web_search"
|
||||
name = WEB_SEARCH_TOOL
|
||||
else:
|
||||
name = "builtin::" + name.value
|
||||
name = name.value
|
||||
|
||||
result = await tool_runtime_api.invoke_tool(
|
||||
tool_name=name,
|
||||
args=dict(
|
||||
session_id=session_id,
|
||||
**tool_call.arguments,
|
||||
**tool_call_args,
|
||||
),
|
||||
)
|
||||
|
||||
return [
|
||||
ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
|
|
|
@ -19,7 +19,7 @@ from llama_stack.apis.agents import (
|
|||
Agents,
|
||||
AgentSessionCreateResponse,
|
||||
AgentStepResponse,
|
||||
AgentTool,
|
||||
AgentToolGroup,
|
||||
AgentTurnCreateRequest,
|
||||
Document,
|
||||
Session,
|
||||
|
@ -147,7 +147,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
ToolResponseMessage,
|
||||
]
|
||||
],
|
||||
tools: Optional[List[AgentTool]] = None,
|
||||
tools: Optional[List[AgentToolGroup]] = None,
|
||||
documents: Optional[List[Document]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
|
|
|
@ -7,9 +7,16 @@
|
|||
|
||||
import logging
|
||||
import tempfile
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor
|
||||
|
@ -35,8 +42,22 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
return
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
|
||||
raise NotImplementedError("Code interpreter tool group not supported")
|
||||
async def list_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
name="code_interpreter",
|
||||
description="Execute code",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="code",
|
||||
description="The code to execute",
|
||||
parameter_type="string",
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import List
|
||||
|
||||
from jinja2 import Template
|
||||
|
||||
from llama_stack.apis.inference import Message, UserMessage
|
||||
|
@ -22,7 +24,7 @@ from .config import (
|
|||
|
||||
async def generate_rag_query(
|
||||
config: MemoryQueryGeneratorConfig,
|
||||
message: Message,
|
||||
messages: List[Message],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
@ -30,9 +32,9 @@ async def generate_rag_query(
|
|||
retrieving relevant information from the memory bank.
|
||||
"""
|
||||
if config.type == MemoryQueryGenerator.default.value:
|
||||
query = await default_rag_query_generator(config, message, **kwargs)
|
||||
query = await default_rag_query_generator(config, messages, **kwargs)
|
||||
elif config.type == MemoryQueryGenerator.llm.value:
|
||||
query = await llm_rag_query_generator(config, message, **kwargs)
|
||||
query = await llm_rag_query_generator(config, messages, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
|
||||
return query
|
||||
|
@ -40,21 +42,21 @@ async def generate_rag_query(
|
|||
|
||||
async def default_rag_query_generator(
|
||||
config: DefaultMemoryQueryGeneratorConfig,
|
||||
message: Message,
|
||||
messages: List[Message],
|
||||
**kwargs,
|
||||
):
|
||||
return interleaved_content_as_str(message.content)
|
||||
return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
|
||||
|
||||
|
||||
async def llm_rag_query_generator(
|
||||
config: LLMMemoryQueryGeneratorConfig,
|
||||
message: Message,
|
||||
messages: List[Message],
|
||||
**kwargs,
|
||||
):
|
||||
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
|
||||
inference_api = kwargs["inference_api"]
|
||||
|
||||
m_dict = {"messages": [message.model_dump()]}
|
||||
m_dict = {"messages": [message.model_dump() for message in messages]}
|
||||
|
||||
template = Template(config.template)
|
||||
content = template.render(m_dict)
|
||||
|
|
|
@ -10,13 +10,14 @@ import secrets
|
|||
import string
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent, Message
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.tools import (
|
||||
ToolDef,
|
||||
ToolGroupDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
@ -50,17 +51,31 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
|
||||
return []
|
||||
async def list_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
name="memory",
|
||||
description="Retrieve context from memory",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="input_messages",
|
||||
description="The input messages to search for",
|
||||
parameter_type="array",
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
async def _retrieve_context(
|
||||
self, message: Message, bank_ids: List[str]
|
||||
self, input_messages: List[str], bank_ids: List[str]
|
||||
) -> Optional[List[InterleavedContent]]:
|
||||
if not bank_ids:
|
||||
return None
|
||||
query = await generate_rag_query(
|
||||
self.config.query_generator_config,
|
||||
message,
|
||||
input_messages,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
tasks = [
|
||||
|
@ -106,17 +121,22 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
self, tool_name: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
tool = await self.tool_store.get_tool(tool_name)
|
||||
tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
|
||||
final_args = tool_group.args or {}
|
||||
final_args.update(args)
|
||||
config = MemoryToolConfig()
|
||||
if tool.metadata.get("config") is not None:
|
||||
if tool.metadata and tool.metadata.get("config") is not None:
|
||||
config = MemoryToolConfig(**tool.metadata["config"])
|
||||
if "memory_bank_id" in args:
|
||||
bank_ids = [args["memory_bank_id"]]
|
||||
if "memory_bank_ids" in final_args:
|
||||
bank_ids = final_args["memory_bank_ids"]
|
||||
else:
|
||||
bank_ids = [
|
||||
bank_config.bank_id for bank_config in config.memory_bank_configs
|
||||
]
|
||||
if "messages" not in final_args:
|
||||
raise ValueError("messages are required")
|
||||
context = await self._retrieve_context(
|
||||
args["query"],
|
||||
final_args["messages"],
|
||||
bank_ids,
|
||||
)
|
||||
if context is None:
|
||||
|
|
|
@ -4,11 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
|
@ -41,8 +48,22 @@ class BraveSearchToolRuntimeImpl(
|
|||
)
|
||||
return provider_data.api_key
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
|
||||
raise NotImplementedError("Brave search tool group not supported")
|
||||
async def list_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
name="web_search",
|
||||
description="Search the web for information",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
|
|
|
@ -4,20 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
MCPToolGroupDef,
|
||||
ToolDef,
|
||||
ToolGroupDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
UserDefinedToolDef,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
|
@ -31,12 +29,14 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
|
||||
if not isinstance(tool_group, MCPToolGroupDef):
|
||||
raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
|
||||
async def list_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
if mcp_endpoint is None:
|
||||
raise ValueError("mcp_endpoint is required")
|
||||
|
||||
tools = []
|
||||
async with sse_client(tool_group.endpoint.uri) as streams:
|
||||
async with sse_client(mcp_endpoint.uri) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
tools_result = await session.list_tools()
|
||||
|
@ -53,12 +53,12 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
)
|
||||
)
|
||||
tools.append(
|
||||
UserDefinedToolDef(
|
||||
ToolDef(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=parameters,
|
||||
metadata={
|
||||
"endpoint": tool_group.endpoint.uri,
|
||||
"endpoint": mcp_endpoint.uri,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
|
|
@ -5,11 +5,18 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
|
@ -42,8 +49,22 @@ class TavilySearchToolRuntimeImpl(
|
|||
)
|
||||
return provider_data.api_key
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
|
||||
raise NotImplementedError("Tavily search tool group not supported")
|
||||
async def list_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
name="web_search",
|
||||
description="Search the web for information",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
|
|
|
@ -45,8 +45,7 @@ def common_params(inference_model):
|
|||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
available_tools=[],
|
||||
preprocessing_tools=[],
|
||||
toolgroups=[],
|
||||
max_infer_iters=5,
|
||||
)
|
||||
|
||||
|
@ -83,27 +82,27 @@ def query_attachment_messages():
|
|||
]
|
||||
|
||||
|
||||
async def create_agent_turn_with_search_tool(
|
||||
async def create_agent_turn_with_toolgroup(
|
||||
agents_stack: Dict[str, object],
|
||||
search_query_messages: List[object],
|
||||
common_params: Dict[str, str],
|
||||
tool_name: str,
|
||||
toolgroup_name: str,
|
||||
) -> None:
|
||||
"""
|
||||
Create an agent turn with a search tool.
|
||||
Create an agent turn with a toolgroup.
|
||||
|
||||
Args:
|
||||
agents_stack (Dict[str, object]): The agents stack.
|
||||
search_query_messages (List[object]): The search query messages.
|
||||
common_params (Dict[str, str]): The common parameters.
|
||||
search_tool_definition (SearchToolDefinition): The search tool definition.
|
||||
toolgroup_name (str): The name of the toolgroup.
|
||||
"""
|
||||
|
||||
# Create an agent with the search tool
|
||||
# Create an agent with the toolgroup
|
||||
agent_config = AgentConfig(
|
||||
**{
|
||||
**common_params,
|
||||
"tools": [tool_name],
|
||||
"toolgroups": [toolgroup_name],
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -249,7 +248,7 @@ class TestAgents:
|
|||
agent_config = AgentConfig(
|
||||
**{
|
||||
**common_params,
|
||||
"tools": ["memory"],
|
||||
"toolgroups": ["builtin::memory"],
|
||||
"tool_choice": ToolChoice.auto,
|
||||
}
|
||||
)
|
||||
|
@ -289,13 +288,58 @@ class TestAgents:
|
|||
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
||||
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
||||
|
||||
await create_agent_turn_with_search_tool(
|
||||
agents_stack,
|
||||
search_query_messages,
|
||||
common_params,
|
||||
"brave_search",
|
||||
# Create an agent with the toolgroup
|
||||
agent_config = AgentConfig(
|
||||
**{
|
||||
**common_params,
|
||||
"toolgroups": ["builtin::web_search"],
|
||||
}
|
||||
)
|
||||
|
||||
agent_id, session_id = await create_agent_session(
|
||||
agents_stack.impls[Api.agents], agent_config
|
||||
)
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=search_query_messages,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
turn_response = [
|
||||
chunk
|
||||
async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
|
||||
**turn_request
|
||||
)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
|
||||
)
|
||||
|
||||
check_event_types(turn_response)
|
||||
|
||||
# Check for tool execution events
|
||||
tool_execution_events = [
|
||||
chunk
|
||||
for chunk in turn_response
|
||||
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
|
||||
and chunk.event.payload.step_details.step_type
|
||||
== StepType.tool_execution.value
|
||||
]
|
||||
assert len(tool_execution_events) > 0, "No tool execution events found"
|
||||
|
||||
# Check the tool execution details
|
||||
tool_execution = tool_execution_events[0].event.payload.step_details
|
||||
assert isinstance(tool_execution, ToolExecutionStep)
|
||||
assert len(tool_execution.tool_calls) > 0
|
||||
actual_tool_name = tool_execution.tool_calls[0].tool_name
|
||||
assert actual_tool_name == "web_search"
|
||||
assert len(tool_execution.tool_responses) > 0
|
||||
|
||||
check_turn_complete_event(turn_response, session_id, search_query_messages)
|
||||
|
||||
|
||||
def check_event_types(turn_response):
|
||||
event_types = [chunk.event.payload.event_type for chunk in turn_response]
|
||||
|
|
|
@ -8,16 +8,9 @@ import os
|
|||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||
|
||||
from llama_stack.apis.models import ModelInput, ModelType
|
||||
from llama_stack.apis.tools import (
|
||||
BuiltInToolDef,
|
||||
ToolGroupInput,
|
||||
ToolParameter,
|
||||
UserDefinedToolDef,
|
||||
UserDefinedToolGroupDef,
|
||||
)
|
||||
from llama_stack.apis.tools import ToolGroupInput
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||
|
||||
|
@ -47,30 +40,7 @@ def tool_runtime_memory_and_search() -> ProviderFixture:
|
|||
@pytest.fixture(scope="session")
|
||||
def tool_group_input_memory() -> ToolGroupInput:
|
||||
return ToolGroupInput(
|
||||
tool_group_id="memory_group",
|
||||
tool_group=UserDefinedToolGroupDef(
|
||||
tools=[
|
||||
UserDefinedToolDef(
|
||||
name="memory",
|
||||
description="Query the memory bank",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="input_messages",
|
||||
description="The input messages to search for in memory",
|
||||
parameter_type="list",
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
metadata={
|
||||
"config": {
|
||||
"memory_bank_configs": [
|
||||
{"bank_id": "test_bank", "type": "vector"}
|
||||
]
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
),
|
||||
toolgroup_id="builtin::memory",
|
||||
provider_id="memory-runtime",
|
||||
)
|
||||
|
||||
|
@ -78,10 +48,7 @@ def tool_group_input_memory() -> ToolGroupInput:
|
|||
@pytest.fixture(scope="session")
|
||||
def tool_group_input_tavily_search() -> ToolGroupInput:
|
||||
return ToolGroupInput(
|
||||
tool_group_id="tavily_search_group",
|
||||
tool_group=UserDefinedToolGroupDef(
|
||||
tools=[BuiltInToolDef(built_in_type=BuiltinTool.brave_search, metadata={})],
|
||||
),
|
||||
toolgroup_id="builtin::web_search",
|
||||
provider_id="tavily-search",
|
||||
)
|
||||
|
||||
|
|
|
@ -43,8 +43,8 @@ def sample_documents():
|
|||
|
||||
class TestTools:
|
||||
@pytest.mark.asyncio
|
||||
async def test_brave_search_tool(self, tools_stack, sample_search_query):
|
||||
"""Test the Brave search tool functionality."""
|
||||
async def test_web_search_tool(self, tools_stack, sample_search_query):
|
||||
"""Test the web search tool functionality."""
|
||||
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
||||
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
||||
|
||||
|
@ -52,7 +52,7 @@ class TestTools:
|
|||
|
||||
# Execute the tool
|
||||
response = await tools_impl.invoke_tool(
|
||||
tool_name="brave_search", args={"query": sample_search_query}
|
||||
tool_name="web_search", args={"query": sample_search_query}
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
|
@ -89,11 +89,12 @@ class TestTools:
|
|||
response = await tools_impl.invoke_tool(
|
||||
tool_name="memory",
|
||||
args={
|
||||
"input_messages": [
|
||||
"messages": [
|
||||
UserMessage(
|
||||
content="What are the main topics covered in the documentation?",
|
||||
)
|
||||
],
|
||||
"memory_bank_ids": ["test_bank"],
|
||||
},
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue