simplify toolgroups registration

This commit is contained in:
Dinesh Yeduguru 2025-01-07 15:37:52 -08:00
parent ba242c04cc
commit f9a98c278a
15 changed files with 350 additions and 256 deletions

View file

@ -13,7 +13,7 @@ import secrets
import string
import uuid
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional
from typing import Any, AsyncGenerator, Dict, List, Optional
from urllib.parse import urlparse
import httpx
@ -21,8 +21,8 @@ from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDe
from llama_stack.apis.agents import (
AgentConfig,
AgentTool,
AgentToolWithArgs,
AgentToolGroup,
AgentToolGroupWithArgs,
AgentTurnCreateRequest,
AgentTurnResponseEvent,
AgentTurnResponseEventType,
@ -76,6 +76,10 @@ def make_random_string(length: int = 8):
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_TOOL_GROUP_ID = "builtin::memory"
MEMORY_QUERY_TOOL = "query_memory"
CODE_INTERPRETER_TOOL = "code_interpreter"
WEB_SEARCH_TOOL = "web_search"
class ChatAgent(ShieldRunnerMixin):
@ -192,7 +196,7 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
documents=request.documents,
tools_for_turn=request.tools,
toolgroups_for_turn=request.toolgroups,
):
if isinstance(chunk, CompletionMessage):
log.info(
@ -243,7 +247,7 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params: SamplingParams,
stream: bool = False,
documents: Optional[List[Document]] = None,
tools_for_turn: Optional[List[AgentTool]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
@ -266,7 +270,7 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params,
stream,
documents,
tools_for_turn,
toolgroups_for_turn,
):
if isinstance(res, bool):
return
@ -362,21 +366,24 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params: SamplingParams,
stream: bool = False,
documents: Optional[List[Document]] = None,
tools_for_turn: Optional[List[AgentTool]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator:
tool_args = {}
if tools_for_turn:
for tool in tools_for_turn:
if isinstance(tool, AgentToolWithArgs):
tool_args[tool.name] = tool.args
toolgroup_args = {}
for toolgroup in self.agent_config.toolgroups:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroup_args[toolgroup.name] = toolgroup.args
if toolgroups_for_turn:
for toolgroup in toolgroups_for_turn:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroup_args[toolgroup.name] = toolgroup.args
tool_defs = await self._get_tool_defs(tools_for_turn)
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
if documents:
await self.handle_documents(
session_id, documents, input_messages, tool_defs
)
if "memory" in tool_defs and len(input_messages) > 0:
with tracing.span("memory_tool") as span:
if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
with tracing.span(MEMORY_QUERY_TOOL) as span:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -386,18 +393,16 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
extra_args = tool_args.get("memory", {})
tool_args = {
# Query memory with the last message's content
"query": input_messages[-1],
**extra_args,
query_args = {
"messages": [msg.content for msg in input_messages],
**toolgroup_args.get(MEMORY_TOOL_GROUP_ID, {}),
}
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
if session_info.memory_bank_id:
tool_args["memory_bank_id"] = session_info.memory_bank_id
serialized_args = tracing.serialize_value(tool_args)
query_args["memory_bank_id"] = session_info.memory_bank_id
serialized_args = tracing.serialize_value(query_args)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
@ -415,8 +420,8 @@ class ChatAgent(ShieldRunnerMixin):
)
)
result = await self.tool_runtime_api.invoke_tool(
tool_name="memory",
args=tool_args,
tool_name=MEMORY_QUERY_TOOL,
args=query_args,
)
yield AgentTurnResponseStreamChunk(
@ -485,7 +490,8 @@ class ChatAgent(ShieldRunnerMixin):
tools=[
tool
for tool in tool_defs.values()
if tool.tool_name != "memory"
if tool_to_group.get(tool.tool_name, None)
!= MEMORY_TOOL_GROUP_ID
],
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
@ -632,6 +638,8 @@ class ChatAgent(ShieldRunnerMixin):
self.tool_runtime_api,
session_id,
[message],
toolgroup_args,
tool_to_group,
)
assert (
len(result_messages) == 1
@ -690,26 +698,37 @@ class ChatAgent(ShieldRunnerMixin):
n_iter += 1
async def _get_tool_defs(
self, tools_for_turn: Optional[List[AgentTool]]
self, toolgroups_for_turn: Optional[List[AgentToolGroup]]
) -> Dict[str, ToolDefinition]:
# Determine which tools to include
agent_config_tools = set(
tool.name if isinstance(tool, AgentToolWithArgs) else tool
for tool in self.agent_config.tools
agent_config_toolgroups = set(
(
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
for toolgroup in self.agent_config.toolgroups
)
tools_for_turn_set = (
agent_config_tools
if tools_for_turn is None
toolgroups_for_turn_set = (
agent_config_toolgroups
if toolgroups_for_turn is None
else {
tool.name if isinstance(tool, AgentToolWithArgs) else tool
for tool in tools_for_turn
(
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
for toolgroup in toolgroups_for_turn
}
)
ret = {}
tool_def_map = {}
tool_to_group = {}
for tool_def in self.agent_config.client_tools:
ret[tool_def.name] = ToolDefinition(
if tool_def_map.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
tool_def_map[tool_def.name] = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
@ -722,41 +741,42 @@ class ChatAgent(ShieldRunnerMixin):
for param in tool_def.parameters
},
)
for tool_name in agent_config_tools:
if tool_name not in tools_for_turn_set:
tool_to_group[tool_def.name] = "__client_tools__"
for toolgroup_name in agent_config_toolgroups:
if toolgroup_name not in toolgroups_for_turn_set:
continue
tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name)
for tool_def in tools:
if tool_def.built_in_type:
if tool_def_map.get(tool_def.built_in_type, None):
raise ValueError(
f"Tool {tool_def.built_in_type} already exists"
)
tool_def = await self.tool_groups_api.get_tool(tool_name)
if tool_def is None:
raise ValueError(f"Tool {tool_name} not found")
if tool_def.identifier.startswith("builtin::"):
built_in_type = tool_def.identifier[len("builtin::") :]
if built_in_type == "web_search":
built_in_type = "brave_search"
if built_in_type not in BuiltinTool.__members__:
raise ValueError(f"Unknown built-in tool: {built_in_type}")
ret[built_in_type] = ToolDefinition(
tool_name=BuiltinTool(built_in_type)
)
continue
ret[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
tool_def_map[tool_def.built_in_type] = ToolDefinition(
tool_name=tool_def.built_in_type
)
for param in tool_def.parameters
},
)
tool_to_group[tool_def.built_in_type] = tool_def.toolgroup_id
continue
return ret
if tool_def_map.get(tool_def.identifier, None):
raise ValueError(f"Tool {tool_def.identifier} already exists")
tool_def_map[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
return tool_def_map, tool_to_group
async def handle_documents(
self,
@ -765,8 +785,8 @@ class ChatAgent(ShieldRunnerMixin):
input_messages: List[Message],
tool_defs: Dict[str, ToolDefinition],
) -> None:
memory_tool = tool_defs.get("memory", None)
code_interpreter_tool = tool_defs.get("code_interpreter", None)
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
code_interpreter_tool = tool_defs.get(CODE_INTERPRETER_TOOL, None)
content_items = []
url_items = []
pattern = re.compile("^(https?://|file://|data:)")
@ -903,7 +923,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
async def execute_tool_call_maybe(
tool_runtime_api: ToolRuntime, session_id: str, messages: List[CompletionMessage]
tool_runtime_api: ToolRuntime,
session_id: str,
messages: List[CompletionMessage],
toolgroup_args: Dict[str, Dict[str, Any]],
tool_to_group: Dict[str, str],
) -> List[ToolResponseMessage]:
# While Tools.run interface takes a list of messages,
# All tools currently only run on a single message
@ -915,18 +939,26 @@ async def execute_tool_call_maybe(
tool_call = message.tool_calls[0]
name = tool_call.tool_name
group_name = tool_to_group.get(name, None)
if group_name is None:
raise ValueError(f"Tool {name} not found in any tool group")
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
tool_call_args = tool_call.arguments
tool_call_args.update(toolgroup_args.get(group_name, {}))
if isinstance(name, BuiltinTool):
if name == BuiltinTool.brave_search:
name = "builtin::web_search"
name = WEB_SEARCH_TOOL
else:
name = "builtin::" + name.value
name = name.value
result = await tool_runtime_api.invoke_tool(
tool_name=name,
args=dict(
session_id=session_id,
**tool_call.arguments,
**tool_call_args,
),
)
return [
ToolResponseMessage(
call_id=tool_call.call_id,

View file

@ -19,7 +19,7 @@ from llama_stack.apis.agents import (
Agents,
AgentSessionCreateResponse,
AgentStepResponse,
AgentTool,
AgentToolGroup,
AgentTurnCreateRequest,
Document,
Session,
@ -147,7 +147,7 @@ class MetaReferenceAgentsImpl(Agents):
ToolResponseMessage,
]
],
tools: Optional[List[AgentTool]] = None,
tools: Optional[List[AgentToolGroup]] = None,
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
) -> AsyncGenerator:

View file

@ -7,9 +7,16 @@
import logging
import tempfile
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor
@ -35,8 +42,22 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def unregister_tool(self, tool_id: str) -> None:
return
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
raise NotImplementedError("Code interpreter tool group not supported")
async def list_tools(
    self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
    """Describe the single tool exposed by the code interpreter runtime.

    Neither ``tool_group_id`` nor ``mcp_endpoint`` is consulted here; the
    same one-element list is built and returned on every call.

    Returns:
        A list holding one ``ToolDef`` named ``code_interpreter`` that takes
        a single string parameter, ``code``.
    """
    code_param = ToolParameter(
        name="code",
        description="The code to execute",
        parameter_type="string",
    )
    interpreter_def = ToolDef(
        name="code_interpreter",
        description="Execute code",
        parameters=[code_param],
    )
    return [interpreter_def]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]

View file

@ -5,6 +5,8 @@
# the root directory of this source tree.
from typing import List
from jinja2 import Template
from llama_stack.apis.inference import Message, UserMessage
@ -22,7 +24,7 @@ from .config import (
async def generate_rag_query(
config: MemoryQueryGeneratorConfig,
message: Message,
messages: List[Message],
**kwargs,
):
"""
@ -30,9 +32,9 @@ async def generate_rag_query(
retrieving relevant information from the memory bank.
"""
if config.type == MemoryQueryGenerator.default.value:
query = await default_rag_query_generator(config, message, **kwargs)
query = await default_rag_query_generator(config, messages, **kwargs)
elif config.type == MemoryQueryGenerator.llm.value:
query = await llm_rag_query_generator(config, message, **kwargs)
query = await llm_rag_query_generator(config, messages, **kwargs)
else:
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
return query
@ -40,21 +42,21 @@ async def generate_rag_query(
async def default_rag_query_generator(
config: DefaultMemoryQueryGeneratorConfig,
message: Message,
messages: List[Message],
**kwargs,
):
return interleaved_content_as_str(message.content)
return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
async def llm_rag_query_generator(
config: LLMMemoryQueryGeneratorConfig,
message: Message,
messages: List[Message],
**kwargs,
):
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
inference_api = kwargs["inference_api"]
m_dict = {"messages": [message.model_dump()]}
m_dict = {"messages": [message.model_dump() for message in messages]}
template = Template(config.template)
content = template.render(m_dict)

View file

@ -10,13 +10,14 @@ import secrets
import string
from typing import Any, Dict, List, Optional
from llama_stack.apis.inference import Inference, InterleavedContent, Message
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.tools import (
ToolDef,
ToolGroupDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -50,17 +51,31 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def initialize(self):
pass
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
return []
async def list_tools(
    self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
    """Describe the single tool exposed by the memory runtime.

    Neither ``tool_group_id`` nor ``mcp_endpoint`` is consulted here; the
    same one-element list is built and returned on every call.

    Returns:
        A list holding one ``ToolDef`` named ``memory`` that takes a single
        array parameter, ``messages``.
    """
    # The parameter is advertised as "messages" because that is the key this
    # runtime's invoke_tool requires ("messages are required" otherwise);
    # advertising "input_messages" would describe an argument that
    # invoke_tool never reads.
    # NOTE(review): the agent code in this change looks the tool up and
    # invokes it as "query_memory", while it is advertised here as "memory"
    # -- confirm which name is intended before relying on agent-side lookup.
    return [
        ToolDef(
            name="memory",
            description="Retrieve context from memory",
            parameters=[
                ToolParameter(
                    name="messages",
                    description="The input messages to search for",
                    parameter_type="array",
                ),
            ],
        )
    ]
async def _retrieve_context(
self, message: Message, bank_ids: List[str]
self, input_messages: List[str], bank_ids: List[str]
) -> Optional[List[InterleavedContent]]:
if not bank_ids:
return None
query = await generate_rag_query(
self.config.query_generator_config,
message,
input_messages,
inference_api=self.inference_api,
)
tasks = [
@ -106,17 +121,22 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
final_args = tool_group.args or {}
final_args.update(args)
config = MemoryToolConfig()
if tool.metadata.get("config") is not None:
if tool.metadata and tool.metadata.get("config") is not None:
config = MemoryToolConfig(**tool.metadata["config"])
if "memory_bank_id" in args:
bank_ids = [args["memory_bank_id"]]
if "memory_bank_ids" in final_args:
bank_ids = final_args["memory_bank_ids"]
else:
bank_ids = [
bank_config.bank_id for bank_config in config.memory_bank_configs
]
if "messages" not in final_args:
raise ValueError("messages are required")
context = await self._retrieve_context(
args["query"],
final_args["messages"],
bank_ids,
)
if context is None:

View file

@ -4,11 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
import requests
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -41,8 +48,22 @@ class BraveSearchToolRuntimeImpl(
)
return provider_data.api_key
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
raise NotImplementedError("Brave search tool group not supported")
async def list_tools(
    self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
    """Describe the single tool exposed by the Brave search runtime.

    Neither ``tool_group_id`` nor ``mcp_endpoint`` is consulted here; the
    same one-element list is built and returned on every call.

    Returns:
        A list holding one ``ToolDef`` named ``web_search`` that takes a
        single string parameter, ``query``.
    """
    query_param = ToolParameter(
        name="query",
        description="The query to search for",
        parameter_type="string",
    )
    search_def = ToolDef(
        name="web_search",
        description="Search the web for information",
        parameters=[query_param],
    )
    return [search_def]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]

View file

@ -4,20 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from mcp import ClientSession
from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
MCPToolGroupDef,
ToolDef,
ToolGroupDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
UserDefinedToolDef,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -31,12 +29,14 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def initialize(self):
pass
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
if not isinstance(tool_group, MCPToolGroupDef):
raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
async def list_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
tools = []
async with sse_client(tool_group.endpoint.uri) as streams:
async with sse_client(mcp_endpoint.uri) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
tools_result = await session.list_tools()
@ -53,12 +53,12 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
)
)
tools.append(
UserDefinedToolDef(
ToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,
metadata={
"endpoint": tool_group.endpoint.uri,
"endpoint": mcp_endpoint.uri,
},
)
)

View file

@ -5,11 +5,18 @@
# the root directory of this source tree.
import json
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
import requests
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -42,8 +49,22 @@ class TavilySearchToolRuntimeImpl(
)
return provider_data.api_key
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
raise NotImplementedError("Tavily search tool group not supported")
async def list_tools(
    self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
    """Describe the single tool exposed by the Tavily search runtime.

    Neither ``tool_group_id`` nor ``mcp_endpoint`` is consulted here; the
    same one-element list is built and returned on every call.

    Returns:
        A list holding one ``ToolDef`` named ``web_search`` that takes a
        single string parameter, ``query``.
    """
    tool_parameters = [
        ToolParameter(
            name="query",
            description="The query to search for",
            parameter_type="string",
        )
    ]
    web_search_def = ToolDef(
        name="web_search",
        description="Search the web for information",
        parameters=tool_parameters,
    )
    return [web_search_def]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]

View file

@ -45,8 +45,7 @@ def common_params(inference_model):
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
input_shields=[],
output_shields=[],
available_tools=[],
preprocessing_tools=[],
toolgroups=[],
max_infer_iters=5,
)
@ -83,27 +82,27 @@ def query_attachment_messages():
]
async def create_agent_turn_with_search_tool(
async def create_agent_turn_with_toolgroup(
agents_stack: Dict[str, object],
search_query_messages: List[object],
common_params: Dict[str, str],
tool_name: str,
toolgroup_name: str,
) -> None:
"""
Create an agent turn with a search tool.
Create an agent turn with a toolgroup.
Args:
agents_stack (Dict[str, object]): The agents stack.
search_query_messages (List[object]): The search query messages.
common_params (Dict[str, str]): The common parameters.
search_tool_definition (SearchToolDefinition): The search tool definition.
toolgroup_name (str): The name of the toolgroup.
"""
# Create an agent with the search tool
# Create an agent with the toolgroup
agent_config = AgentConfig(
**{
**common_params,
"tools": [tool_name],
"toolgroups": [toolgroup_name],
}
)
@ -249,7 +248,7 @@ class TestAgents:
agent_config = AgentConfig(
**{
**common_params,
"tools": ["memory"],
"toolgroups": ["builtin::memory"],
"tool_choice": ToolChoice.auto,
}
)
@ -289,13 +288,58 @@ class TestAgents:
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
await create_agent_turn_with_search_tool(
agents_stack,
search_query_messages,
common_params,
"brave_search",
# Create an agent with the toolgroup
agent_config = AgentConfig(
**{
**common_params,
"toolgroups": ["builtin::web_search"],
}
)
agent_id, session_id = await create_agent_session(
agents_stack.impls[Api.agents], agent_config
)
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=search_query_messages,
stream=True,
)
turn_response = [
chunk
async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
**turn_request
)
]
assert len(turn_response) > 0
assert all(
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
)
check_event_types(turn_response)
# Check for tool execution events
tool_execution_events = [
chunk
for chunk in turn_response
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
and chunk.event.payload.step_details.step_type
== StepType.tool_execution.value
]
assert len(tool_execution_events) > 0, "No tool execution events found"
# Check the tool execution details
tool_execution = tool_execution_events[0].event.payload.step_details
assert isinstance(tool_execution, ToolExecutionStep)
assert len(tool_execution.tool_calls) > 0
actual_tool_name = tool_execution.tool_calls[0].tool_name
assert actual_tool_name == "web_search"
assert len(tool_execution.tool_responses) > 0
check_turn_complete_event(turn_response, session_id, search_query_messages)
def check_event_types(turn_response):
event_types = [chunk.event.payload.event_type for chunk in turn_response]

View file

@ -8,16 +8,9 @@ import os
import pytest
import pytest_asyncio
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.apis.tools import (
BuiltInToolDef,
ToolGroupInput,
ToolParameter,
UserDefinedToolDef,
UserDefinedToolGroupDef,
)
from llama_stack.apis.tools import ToolGroupInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
@ -47,30 +40,7 @@ def tool_runtime_memory_and_search() -> ProviderFixture:
@pytest.fixture(scope="session")
def tool_group_input_memory() -> ToolGroupInput:
return ToolGroupInput(
tool_group_id="memory_group",
tool_group=UserDefinedToolGroupDef(
tools=[
UserDefinedToolDef(
name="memory",
description="Query the memory bank",
parameters=[
ToolParameter(
name="input_messages",
description="The input messages to search for in memory",
parameter_type="list",
required=True,
),
],
metadata={
"config": {
"memory_bank_configs": [
{"bank_id": "test_bank", "type": "vector"}
]
}
},
)
],
),
toolgroup_id="builtin::memory",
provider_id="memory-runtime",
)
@ -78,10 +48,7 @@ def tool_group_input_memory() -> ToolGroupInput:
@pytest.fixture(scope="session")
def tool_group_input_tavily_search() -> ToolGroupInput:
return ToolGroupInput(
tool_group_id="tavily_search_group",
tool_group=UserDefinedToolGroupDef(
tools=[BuiltInToolDef(built_in_type=BuiltinTool.brave_search, metadata={})],
),
toolgroup_id="builtin::web_search",
provider_id="tavily-search",
)

View file

@ -43,8 +43,8 @@ def sample_documents():
class TestTools:
@pytest.mark.asyncio
async def test_brave_search_tool(self, tools_stack, sample_search_query):
"""Test the Brave search tool functionality."""
async def test_web_search_tool(self, tools_stack, sample_search_query):
"""Test the web search tool functionality."""
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
@ -52,7 +52,7 @@ class TestTools:
# Execute the tool
response = await tools_impl.invoke_tool(
tool_name="brave_search", args={"query": sample_search_query}
tool_name="web_search", args={"query": sample_search_query}
)
# Verify the response
@ -89,11 +89,12 @@ class TestTools:
response = await tools_impl.invoke_tool(
tool_name="memory",
args={
"input_messages": [
"messages": [
UserMessage(
content="What are the main topics covered in the documentation?",
)
],
"memory_bank_ids": ["test_bank"],
},
)