simplify toolgroups registration

This commit is contained in:
Dinesh Yeduguru 2025-01-07 15:37:52 -08:00
parent ba242c04cc
commit f9a98c278a
15 changed files with 350 additions and 256 deletions

View file

@ -137,15 +137,15 @@ class Session(BaseModel):
memory_bank: Optional[MemoryBank] = None
class AgentToolWithArgs(BaseModel):
class AgentToolGroupWithArgs(BaseModel):
name: str
args: Dict[str, Any]
AgentTool = register_schema(
AgentToolGroup = register_schema(
Union[
str,
AgentToolWithArgs,
AgentToolGroupWithArgs,
],
name="AgentTool",
)
@ -156,7 +156,7 @@ class AgentConfigCommon(BaseModel):
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
tools: Optional[List[AgentTool]] = Field(default_factory=list)
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
@ -278,7 +278,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
]
documents: Optional[List[Document]] = None
tools: Optional[List[AgentTool]] = None
toolgroups: Optional[List[AgentToolGroup]] = None
stream: Optional[bool] = False
@ -317,7 +317,7 @@ class Agents(Protocol):
],
stream: Optional[bool] = False,
documents: Optional[List[Document]] = None,
tools: Optional[List[AgentTool]] = None,
tools: Optional[List[AgentToolGroup]] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/turn/get")

View file

@ -5,10 +5,10 @@
# the root directory of this source tree.
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional
from llama_models.llama3.api.datatypes import BuiltinTool, ToolPromptFormat
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable
@ -22,7 +22,7 @@ class ToolParameter(BaseModel):
name: str
parameter_type: str
description: str
required: bool
required: bool = Field(default=True)
default: Optional[Any] = None
@ -36,7 +36,7 @@ class ToolHost(Enum):
@json_schema_type
class Tool(Resource):
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
tool_group: str
toolgroup_id: str
tool_host: ToolHost
description: str
parameters: List[ToolParameter]
@ -58,41 +58,19 @@ class ToolDef(BaseModel):
)
@json_schema_type
class MCPToolGroupDef(BaseModel):
"""
A tool group that is defined in a model context protocol server.
Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
"""
type: Literal["model_context_protocol"] = "model_context_protocol"
endpoint: URL
@json_schema_type
class UserDefinedToolGroupDef(BaseModel):
type: Literal["user_defined"] = "user_defined"
tools: List[ToolDef]
ToolGroupDef = register_schema(
Annotated[
Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
],
name="ToolGroupDef",
)
@json_schema_type
class ToolGroupInput(BaseModel):
tool_group_id: str
tool_group_def: ToolGroupDef
provider_id: Optional[str] = None
toolgroup_id: str
provider_id: str
args: Optional[Dict[str, Any]] = None
mcp_endpoint: Optional[URL] = None
@json_schema_type
class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
mcp_endpoint: Optional[URL] = None
args: Optional[Dict[str, Any]] = None
@json_schema_type
@ -104,6 +82,7 @@ class ToolInvocationResult(BaseModel):
class ToolStore(Protocol):
def get_tool(self, tool_name: str) -> Tool: ...
def get_tool_group(self, tool_group_id: str) -> ToolGroup: ...
@runtime_checkable
@ -112,9 +91,10 @@ class ToolGroups(Protocol):
@webmethod(route="/toolgroups/register", method="POST")
async def register_tool_group(
self,
tool_group_id: str,
tool_group_def: ToolGroupDef,
provider_id: Optional[str] = None,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: Optional[URL] = None,
args: Optional[Dict[str, Any]] = None,
) -> None:
"""Register a tool group"""
...
@ -122,7 +102,7 @@ class ToolGroups(Protocol):
@webmethod(route="/toolgroups/get", method="GET")
async def get_tool_group(
self,
tool_group_id: str,
toolgroup_id: str,
) -> ToolGroup: ...
@webmethod(route="/toolgroups/list", method="GET")
@ -149,8 +129,10 @@ class ToolGroups(Protocol):
class ToolRuntime(Protocol):
tool_store: ToolStore
@webmethod(route="/tool-runtime/discover", method="POST")
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ...
@webmethod(route="/tool-runtime/list-tools", method="POST")
async def list_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(

View file

@ -6,7 +6,7 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
AppEvalTaskConfig,
@ -38,7 +38,7 @@ from llama_stack.apis.scoring import (
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolRuntime
from llama_stack.apis.tools import ToolDef, ToolRuntime
from llama_stack.providers.datatypes import RoutingTable
@ -417,7 +417,9 @@ class ToolRuntimeRouter(ToolRuntime):
args=args,
)
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
return await self.routing_table.get_provider_impl(
tool_group.name
).discover_tools(tool_group)
async def list_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
tool_group_id, mcp_endpoint
)

View file

@ -26,15 +26,7 @@ from llama_stack.apis.scoring_functions import (
ScoringFunctions,
)
from llama_stack.apis.shields import Shield, Shields
from llama_stack.apis.tools import (
MCPToolGroupDef,
Tool,
ToolGroup,
ToolGroupDef,
ToolGroups,
ToolHost,
UserDefinedToolGroupDef,
)
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost
from llama_stack.distribution.datatypes import (
RoutableObject,
RoutableObjectWithProvider,
@ -496,51 +488,38 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
tools = await self.get_all_with_type("tool")
if tool_group_id:
tools = [tool for tool in tools if tool.tool_group == tool_group_id]
tools = [tool for tool in tools if tool.toolgroup_id == tool_group_id]
return tools
async def list_tool_groups(self) -> List[ToolGroup]:
return await self.get_all_with_type("tool_group")
async def get_tool_group(self, tool_group_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", tool_group_id)
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", toolgroup_id)
async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name)
async def register_tool_group(
self,
tool_group_id: str,
tool_group_def: ToolGroupDef,
provider_id: Optional[str] = None,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: Optional[URL] = None,
args: Optional[Dict[str, Any]] = None,
) -> None:
tools = []
tool_defs = []
tool_host = ToolHost.distribution
if provider_id is None:
if len(self.impls_by_provider_id.keys()) > 1:
raise ValueError(
f"No provider_id specified and multiple providers available. Please specify a provider_id. Available providers: {', '.join(self.impls_by_provider_id.keys())}"
)
provider_id = list(self.impls_by_provider_id.keys())[0]
# parse tool group to the type if dict
tool_group_def = TypeAdapter(ToolGroupDef).validate_python(tool_group_def)
if isinstance(tool_group_def, MCPToolGroupDef):
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
tool_group_def
)
tool_host = ToolHost.model_context_protocol
elif isinstance(tool_group_def, UserDefinedToolGroupDef):
tool_defs = tool_group_def.tools
else:
raise ValueError(f"Unknown tool group: {tool_group_def}")
tool_defs = await self.impls_by_provider_id[provider_id].list_tools(
toolgroup_id, mcp_endpoint
)
tool_host = (
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
)
for tool_def in tool_defs:
tools.append(
Tool(
identifier=tool_def.name,
tool_group=tool_group_id,
toolgroup_id=toolgroup_id,
description=tool_def.description or "",
parameters=tool_def.parameters or [],
provider_id=provider_id,
@ -565,9 +544,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
await self.dist_registry.register(
ToolGroup(
identifier=tool_group_id,
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=tool_group_id,
provider_resource_id=toolgroup_id,
mcp_endpoint=mcp_endpoint,
args=args,
)
)

View file

@ -13,7 +13,7 @@ import secrets
import string
import uuid
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional
from typing import Any, AsyncGenerator, Dict, List, Optional
from urllib.parse import urlparse
import httpx
@ -21,8 +21,8 @@ from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDe
from llama_stack.apis.agents import (
AgentConfig,
AgentTool,
AgentToolWithArgs,
AgentToolGroup,
AgentToolGroupWithArgs,
AgentTurnCreateRequest,
AgentTurnResponseEvent,
AgentTurnResponseEventType,
@ -76,6 +76,10 @@ def make_random_string(length: int = 8):
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_TOOL_GROUP_ID = "builtin::memory"
MEMORY_QUERY_TOOL = "query_memory"
CODE_INTERPRETER_TOOL = "code_interpreter"
WEB_SEARCH_TOOL = "web_search"
class ChatAgent(ShieldRunnerMixin):
@ -192,7 +196,7 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
documents=request.documents,
tools_for_turn=request.tools,
toolgroups_for_turn=request.toolgroups,
):
if isinstance(chunk, CompletionMessage):
log.info(
@ -243,7 +247,7 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params: SamplingParams,
stream: bool = False,
documents: Optional[List[Document]] = None,
tools_for_turn: Optional[List[AgentTool]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
@ -266,7 +270,7 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params,
stream,
documents,
tools_for_turn,
toolgroups_for_turn,
):
if isinstance(res, bool):
return
@ -362,21 +366,24 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params: SamplingParams,
stream: bool = False,
documents: Optional[List[Document]] = None,
tools_for_turn: Optional[List[AgentTool]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator:
tool_args = {}
if tools_for_turn:
for tool in tools_for_turn:
if isinstance(tool, AgentToolWithArgs):
tool_args[tool.name] = tool.args
toolgroup_args = {}
for toolgroup in self.agent_config.toolgroups:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroup_args[toolgroup.name] = toolgroup.args
if toolgroups_for_turn:
for toolgroup in toolgroups_for_turn:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroup_args[toolgroup.name] = toolgroup.args
tool_defs = await self._get_tool_defs(tools_for_turn)
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
if documents:
await self.handle_documents(
session_id, documents, input_messages, tool_defs
)
if "memory" in tool_defs and len(input_messages) > 0:
with tracing.span("memory_tool") as span:
if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
with tracing.span(MEMORY_QUERY_TOOL) as span:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -386,18 +393,16 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
extra_args = tool_args.get("memory", {})
tool_args = {
# Query memory with the last message's content
"query": input_messages[-1],
**extra_args,
query_args = {
"messages": [msg.content for msg in input_messages],
**toolgroup_args.get(MEMORY_TOOL_GROUP_ID, {}),
}
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
if session_info.memory_bank_id:
tool_args["memory_bank_id"] = session_info.memory_bank_id
serialized_args = tracing.serialize_value(tool_args)
query_args["memory_bank_id"] = session_info.memory_bank_id
serialized_args = tracing.serialize_value(query_args)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
@ -415,8 +420,8 @@ class ChatAgent(ShieldRunnerMixin):
)
)
result = await self.tool_runtime_api.invoke_tool(
tool_name="memory",
args=tool_args,
tool_name=MEMORY_QUERY_TOOL,
args=query_args,
)
yield AgentTurnResponseStreamChunk(
@ -485,7 +490,8 @@ class ChatAgent(ShieldRunnerMixin):
tools=[
tool
for tool in tool_defs.values()
if tool.tool_name != "memory"
if tool_to_group.get(tool.tool_name, None)
!= MEMORY_TOOL_GROUP_ID
],
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
@ -632,6 +638,8 @@ class ChatAgent(ShieldRunnerMixin):
self.tool_runtime_api,
session_id,
[message],
toolgroup_args,
tool_to_group,
)
assert (
len(result_messages) == 1
@ -690,26 +698,37 @@ class ChatAgent(ShieldRunnerMixin):
n_iter += 1
async def _get_tool_defs(
self, tools_for_turn: Optional[List[AgentTool]]
self, toolgroups_for_turn: Optional[List[AgentToolGroup]]
) -> Dict[str, ToolDefinition]:
# Determine which tools to include
agent_config_tools = set(
tool.name if isinstance(tool, AgentToolWithArgs) else tool
for tool in self.agent_config.tools
agent_config_toolgroups = set(
(
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
for toolgroup in self.agent_config.toolgroups
)
tools_for_turn_set = (
agent_config_tools
if tools_for_turn is None
toolgroups_for_turn_set = (
agent_config_toolgroups
if toolgroups_for_turn is None
else {
tool.name if isinstance(tool, AgentToolWithArgs) else tool
for tool in tools_for_turn
(
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
for toolgroup in toolgroups_for_turn
}
)
ret = {}
tool_def_map = {}
tool_to_group = {}
for tool_def in self.agent_config.client_tools:
ret[tool_def.name] = ToolDefinition(
if tool_def_map.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
tool_def_map[tool_def.name] = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
@ -722,41 +741,42 @@ class ChatAgent(ShieldRunnerMixin):
for param in tool_def.parameters
},
)
for tool_name in agent_config_tools:
if tool_name not in tools_for_turn_set:
tool_to_group[tool_def.name] = "__client_tools__"
for toolgroup_name in agent_config_toolgroups:
if toolgroup_name not in toolgroups_for_turn_set:
continue
tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name)
for tool_def in tools:
if tool_def.built_in_type:
if tool_def_map.get(tool_def.built_in_type, None):
raise ValueError(
f"Tool {tool_def.built_in_type} already exists"
)
tool_def = await self.tool_groups_api.get_tool(tool_name)
if tool_def is None:
raise ValueError(f"Tool {tool_name} not found")
if tool_def.identifier.startswith("builtin::"):
built_in_type = tool_def.identifier[len("builtin::") :]
if built_in_type == "web_search":
built_in_type = "brave_search"
if built_in_type not in BuiltinTool.__members__:
raise ValueError(f"Unknown built-in tool: {built_in_type}")
ret[built_in_type] = ToolDefinition(
tool_name=BuiltinTool(built_in_type)
)
continue
ret[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
tool_def_map[tool_def.built_in_type] = ToolDefinition(
tool_name=tool_def.built_in_type
)
for param in tool_def.parameters
},
)
tool_to_group[tool_def.built_in_type] = tool_def.toolgroup_id
continue
return ret
if tool_def_map.get(tool_def.identifier, None):
raise ValueError(f"Tool {tool_def.identifier} already exists")
tool_def_map[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
return tool_def_map, tool_to_group
async def handle_documents(
self,
@ -765,8 +785,8 @@ class ChatAgent(ShieldRunnerMixin):
input_messages: List[Message],
tool_defs: Dict[str, ToolDefinition],
) -> None:
memory_tool = tool_defs.get("memory", None)
code_interpreter_tool = tool_defs.get("code_interpreter", None)
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
code_interpreter_tool = tool_defs.get(CODE_INTERPRETER_TOOL, None)
content_items = []
url_items = []
pattern = re.compile("^(https?://|file://|data:)")
@ -903,7 +923,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
async def execute_tool_call_maybe(
tool_runtime_api: ToolRuntime, session_id: str, messages: List[CompletionMessage]
tool_runtime_api: ToolRuntime,
session_id: str,
messages: List[CompletionMessage],
toolgroup_args: Dict[str, Dict[str, Any]],
tool_to_group: Dict[str, str],
) -> List[ToolResponseMessage]:
# While Tools.run interface takes a list of messages,
# All tools currently only run on a single message
@ -915,18 +939,26 @@ async def execute_tool_call_maybe(
tool_call = message.tool_calls[0]
name = tool_call.tool_name
group_name = tool_to_group.get(name, None)
if group_name is None:
raise ValueError(f"Tool {name} not found in any tool group")
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
tool_call_args = tool_call.arguments
tool_call_args.update(toolgroup_args.get(group_name, {}))
if isinstance(name, BuiltinTool):
if name == BuiltinTool.brave_search:
name = "builtin::web_search"
name = WEB_SEARCH_TOOL
else:
name = "builtin::" + name.value
name = name.value
result = await tool_runtime_api.invoke_tool(
tool_name=name,
args=dict(
session_id=session_id,
**tool_call.arguments,
**tool_call_args,
),
)
return [
ToolResponseMessage(
call_id=tool_call.call_id,

View file

@ -19,7 +19,7 @@ from llama_stack.apis.agents import (
Agents,
AgentSessionCreateResponse,
AgentStepResponse,
AgentTool,
AgentToolGroup,
AgentTurnCreateRequest,
Document,
Session,
@ -147,7 +147,7 @@ class MetaReferenceAgentsImpl(Agents):
ToolResponseMessage,
]
],
tools: Optional[List[AgentTool]] = None,
tools: Optional[List[AgentToolGroup]] = None,
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
) -> AsyncGenerator:

View file

@ -7,9 +7,16 @@
import logging
import tempfile
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor
@ -35,8 +42,22 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def unregister_tool(self, tool_id: str) -> None:
return
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
raise NotImplementedError("Code interpreter tool group not supported")
async def list_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="code_interpreter",
description="Execute code",
parameters=[
ToolParameter(
name="code",
description="The code to execute",
parameter_type="string",
),
],
)
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]

View file

@ -5,6 +5,8 @@
# the root directory of this source tree.
from typing import List
from jinja2 import Template
from llama_stack.apis.inference import Message, UserMessage
@ -22,7 +24,7 @@ from .config import (
async def generate_rag_query(
config: MemoryQueryGeneratorConfig,
message: Message,
messages: List[Message],
**kwargs,
):
"""
@ -30,9 +32,9 @@ async def generate_rag_query(
retrieving relevant information from the memory bank.
"""
if config.type == MemoryQueryGenerator.default.value:
query = await default_rag_query_generator(config, message, **kwargs)
query = await default_rag_query_generator(config, messages, **kwargs)
elif config.type == MemoryQueryGenerator.llm.value:
query = await llm_rag_query_generator(config, message, **kwargs)
query = await llm_rag_query_generator(config, messages, **kwargs)
else:
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
return query
@ -40,21 +42,21 @@ async def generate_rag_query(
async def default_rag_query_generator(
config: DefaultMemoryQueryGeneratorConfig,
message: Message,
messages: List[Message],
**kwargs,
):
return interleaved_content_as_str(message.content)
return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
async def llm_rag_query_generator(
config: LLMMemoryQueryGeneratorConfig,
message: Message,
messages: List[Message],
**kwargs,
):
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
inference_api = kwargs["inference_api"]
m_dict = {"messages": [message.model_dump()]}
m_dict = {"messages": [message.model_dump() for message in messages]}
template = Template(config.template)
content = template.render(m_dict)

View file

@ -10,13 +10,14 @@ import secrets
import string
from typing import Any, Dict, List, Optional
from llama_stack.apis.inference import Inference, InterleavedContent, Message
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.tools import (
ToolDef,
ToolGroupDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -50,17 +51,31 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def initialize(self):
pass
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
return []
async def list_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="memory",
description="Retrieve context from memory",
parameters=[
ToolParameter(
name="input_messages",
description="The input messages to search for",
parameter_type="array",
),
],
)
]
async def _retrieve_context(
self, message: Message, bank_ids: List[str]
self, input_messages: List[str], bank_ids: List[str]
) -> Optional[List[InterleavedContent]]:
if not bank_ids:
return None
query = await generate_rag_query(
self.config.query_generator_config,
message,
input_messages,
inference_api=self.inference_api,
)
tasks = [
@ -106,17 +121,22 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
final_args = tool_group.args or {}
final_args.update(args)
config = MemoryToolConfig()
if tool.metadata.get("config") is not None:
if tool.metadata and tool.metadata.get("config") is not None:
config = MemoryToolConfig(**tool.metadata["config"])
if "memory_bank_id" in args:
bank_ids = [args["memory_bank_id"]]
if "memory_bank_ids" in final_args:
bank_ids = final_args["memory_bank_ids"]
else:
bank_ids = [
bank_config.bank_id for bank_config in config.memory_bank_configs
]
if "messages" not in final_args:
raise ValueError("messages are required")
context = await self._retrieve_context(
args["query"],
final_args["messages"],
bank_ids,
)
if context is None:

View file

@ -4,11 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
import requests
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -41,8 +48,22 @@ class BraveSearchToolRuntimeImpl(
)
return provider_data.api_key
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
raise NotImplementedError("Brave search tool group not supported")
async def list_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="web_search",
description="Search the web for information",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
)
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]

View file

@ -4,20 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from mcp import ClientSession
from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
MCPToolGroupDef,
ToolDef,
ToolGroupDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
UserDefinedToolDef,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -31,12 +29,14 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def initialize(self):
pass
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
if not isinstance(tool_group, MCPToolGroupDef):
raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
async def list_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
tools = []
async with sse_client(tool_group.endpoint.uri) as streams:
async with sse_client(mcp_endpoint.uri) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
tools_result = await session.list_tools()
@ -53,12 +53,12 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
)
)
tools.append(
UserDefinedToolDef(
ToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,
metadata={
"endpoint": tool_group.endpoint.uri,
"endpoint": mcp_endpoint.uri,
},
)
)

View file

@ -5,11 +5,18 @@
# the root directory of this source tree.
import json
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
import requests
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -42,8 +49,22 @@ class TavilySearchToolRuntimeImpl(
)
return provider_data.api_key
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
raise NotImplementedError("Tavily search tool group not supported")
async def list_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="web_search",
description="Search the web for information",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
)
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]

View file

@ -45,8 +45,7 @@ def common_params(inference_model):
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
input_shields=[],
output_shields=[],
available_tools=[],
preprocessing_tools=[],
toolgroups=[],
max_infer_iters=5,
)
@ -83,27 +82,27 @@ def query_attachment_messages():
]
async def create_agent_turn_with_search_tool(
async def create_agent_turn_with_toolgroup(
agents_stack: Dict[str, object],
search_query_messages: List[object],
common_params: Dict[str, str],
tool_name: str,
toolgroup_name: str,
) -> None:
"""
Create an agent turn with a search tool.
Create an agent turn with a toolgroup.
Args:
agents_stack (Dict[str, object]): The agents stack.
search_query_messages (List[object]): The search query messages.
common_params (Dict[str, str]): The common parameters.
search_tool_definition (SearchToolDefinition): The search tool definition.
toolgroup_name (str): The name of the toolgroup.
"""
# Create an agent with the search tool
# Create an agent with the toolgroup
agent_config = AgentConfig(
**{
**common_params,
"tools": [tool_name],
"toolgroups": [toolgroup_name],
}
)
@ -249,7 +248,7 @@ class TestAgents:
agent_config = AgentConfig(
**{
**common_params,
"tools": ["memory"],
"toolgroups": ["builtin::memory"],
"tool_choice": ToolChoice.auto,
}
)
@ -289,13 +288,58 @@ class TestAgents:
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
await create_agent_turn_with_search_tool(
agents_stack,
search_query_messages,
common_params,
"brave_search",
# Create an agent with the toolgroup
agent_config = AgentConfig(
**{
**common_params,
"toolgroups": ["builtin::web_search"],
}
)
agent_id, session_id = await create_agent_session(
agents_stack.impls[Api.agents], agent_config
)
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=search_query_messages,
stream=True,
)
turn_response = [
chunk
async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
**turn_request
)
]
assert len(turn_response) > 0
assert all(
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
)
check_event_types(turn_response)
# Check for tool execution events
tool_execution_events = [
chunk
for chunk in turn_response
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
and chunk.event.payload.step_details.step_type
== StepType.tool_execution.value
]
assert len(tool_execution_events) > 0, "No tool execution events found"
# Check the tool execution details
tool_execution = tool_execution_events[0].event.payload.step_details
assert isinstance(tool_execution, ToolExecutionStep)
assert len(tool_execution.tool_calls) > 0
actual_tool_name = tool_execution.tool_calls[0].tool_name
assert actual_tool_name == "web_search"
assert len(tool_execution.tool_responses) > 0
check_turn_complete_event(turn_response, session_id, search_query_messages)
def check_event_types(turn_response):
event_types = [chunk.event.payload.event_type for chunk in turn_response]

View file

@ -8,16 +8,9 @@ import os
import pytest
import pytest_asyncio
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.apis.tools import (
BuiltInToolDef,
ToolGroupInput,
ToolParameter,
UserDefinedToolDef,
UserDefinedToolGroupDef,
)
from llama_stack.apis.tools import ToolGroupInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
@ -47,30 +40,7 @@ def tool_runtime_memory_and_search() -> ProviderFixture:
@pytest.fixture(scope="session")
def tool_group_input_memory() -> ToolGroupInput:
return ToolGroupInput(
tool_group_id="memory_group",
tool_group=UserDefinedToolGroupDef(
tools=[
UserDefinedToolDef(
name="memory",
description="Query the memory bank",
parameters=[
ToolParameter(
name="input_messages",
description="The input messages to search for in memory",
parameter_type="list",
required=True,
),
],
metadata={
"config": {
"memory_bank_configs": [
{"bank_id": "test_bank", "type": "vector"}
]
}
},
)
],
),
toolgroup_id="builtin::memory",
provider_id="memory-runtime",
)
@ -78,10 +48,7 @@ def tool_group_input_memory() -> ToolGroupInput:
@pytest.fixture(scope="session")
def tool_group_input_tavily_search() -> ToolGroupInput:
return ToolGroupInput(
tool_group_id="tavily_search_group",
tool_group=UserDefinedToolGroupDef(
tools=[BuiltInToolDef(built_in_type=BuiltinTool.brave_search, metadata={})],
),
toolgroup_id="builtin::web_search",
provider_id="tavily-search",
)

View file

@ -43,8 +43,8 @@ def sample_documents():
class TestTools:
@pytest.mark.asyncio
async def test_brave_search_tool(self, tools_stack, sample_search_query):
"""Test the Brave search tool functionality."""
async def test_web_search_tool(self, tools_stack, sample_search_query):
"""Test the web search tool functionality."""
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
@ -52,7 +52,7 @@ class TestTools:
# Execute the tool
response = await tools_impl.invoke_tool(
tool_name="brave_search", args={"query": sample_search_query}
tool_name="web_search", args={"query": sample_search_query}
)
# Verify the response
@ -89,11 +89,12 @@ class TestTools:
response = await tools_impl.invoke_tool(
tool_name="memory",
args={
"input_messages": [
"messages": [
UserMessage(
content="What are the main topics covered in the documentation?",
)
],
"memory_bank_ids": ["test_bank"],
},
)