simplify toolgroups registration

This commit is contained in:
Dinesh Yeduguru 2025-01-07 15:37:52 -08:00
parent ba242c04cc
commit f9a98c278a
15 changed files with 350 additions and 256 deletions

View file

@ -137,15 +137,15 @@ class Session(BaseModel):
memory_bank: Optional[MemoryBank] = None memory_bank: Optional[MemoryBank] = None
class AgentToolWithArgs(BaseModel): class AgentToolGroupWithArgs(BaseModel):
name: str name: str
args: Dict[str, Any] args: Dict[str, Any]
AgentTool = register_schema( AgentToolGroup = register_schema(
Union[ Union[
str, str,
AgentToolWithArgs, AgentToolGroupWithArgs,
], ],
name="AgentTool", name="AgentTool",
) )
@ -156,7 +156,7 @@ class AgentConfigCommon(BaseModel):
input_shields: Optional[List[str]] = Field(default_factory=list) input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list)
tools: Optional[List[AgentTool]] = Field(default_factory=list) toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
client_tools: Optional[List[ToolDef]] = Field(default_factory=list) client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field( tool_prompt_format: Optional[ToolPromptFormat] = Field(
@ -278,7 +278,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
] ]
documents: Optional[List[Document]] = None documents: Optional[List[Document]] = None
tools: Optional[List[AgentTool]] = None toolgroups: Optional[List[AgentToolGroup]] = None
stream: Optional[bool] = False stream: Optional[bool] = False
@ -317,7 +317,7 @@ class Agents(Protocol):
], ],
stream: Optional[bool] = False, stream: Optional[bool] = False,
documents: Optional[List[Document]] = None, documents: Optional[List[Document]] = None,
tools: Optional[List[AgentTool]] = None, tools: Optional[List[AgentToolGroup]] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/turn/get") @webmethod(route="/agents/turn/get")

View file

@ -5,10 +5,10 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Union from typing import Any, Dict, List, Literal, Optional
from llama_models.llama3.api.datatypes import BuiltinTool, ToolPromptFormat from llama_models.llama3.api.datatypes import BuiltinTool, ToolPromptFormat
from llama_models.schema_utils import json_schema_type, register_schema, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable from typing_extensions import Protocol, runtime_checkable
@ -22,7 +22,7 @@ class ToolParameter(BaseModel):
name: str name: str
parameter_type: str parameter_type: str
description: str description: str
required: bool required: bool = Field(default=True)
default: Optional[Any] = None default: Optional[Any] = None
@ -36,7 +36,7 @@ class ToolHost(Enum):
@json_schema_type @json_schema_type
class Tool(Resource): class Tool(Resource):
type: Literal[ResourceType.tool.value] = ResourceType.tool.value type: Literal[ResourceType.tool.value] = ResourceType.tool.value
tool_group: str toolgroup_id: str
tool_host: ToolHost tool_host: ToolHost
description: str description: str
parameters: List[ToolParameter] parameters: List[ToolParameter]
@ -58,41 +58,19 @@ class ToolDef(BaseModel):
) )
@json_schema_type
class MCPToolGroupDef(BaseModel):
"""
A tool group that is defined by in a model context protocol server.
Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
"""
type: Literal["model_context_protocol"] = "model_context_protocol"
endpoint: URL
@json_schema_type
class UserDefinedToolGroupDef(BaseModel):
type: Literal["user_defined"] = "user_defined"
tools: List[ToolDef]
ToolGroupDef = register_schema(
Annotated[
Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
],
name="ToolGroupDef",
)
@json_schema_type @json_schema_type
class ToolGroupInput(BaseModel): class ToolGroupInput(BaseModel):
tool_group_id: str toolgroup_id: str
tool_group_def: ToolGroupDef provider_id: str
provider_id: Optional[str] = None args: Optional[Dict[str, Any]] = None
mcp_endpoint: Optional[URL] = None
@json_schema_type @json_schema_type
class ToolGroup(Resource): class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
mcp_endpoint: Optional[URL] = None
args: Optional[Dict[str, Any]] = None
@json_schema_type @json_schema_type
@ -104,6 +82,7 @@ class ToolInvocationResult(BaseModel):
class ToolStore(Protocol): class ToolStore(Protocol):
def get_tool(self, tool_name: str) -> Tool: ... def get_tool(self, tool_name: str) -> Tool: ...
def get_tool_group(self, tool_group_id: str) -> ToolGroup: ...
@runtime_checkable @runtime_checkable
@ -112,9 +91,10 @@ class ToolGroups(Protocol):
@webmethod(route="/toolgroups/register", method="POST") @webmethod(route="/toolgroups/register", method="POST")
async def register_tool_group( async def register_tool_group(
self, self,
tool_group_id: str, toolgroup_id: str,
tool_group_def: ToolGroupDef, provider_id: str,
provider_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None,
args: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
"""Register a tool group""" """Register a tool group"""
... ...
@ -122,7 +102,7 @@ class ToolGroups(Protocol):
@webmethod(route="/toolgroups/get", method="GET") @webmethod(route="/toolgroups/get", method="GET")
async def get_tool_group( async def get_tool_group(
self, self,
tool_group_id: str, toolgroup_id: str,
) -> ToolGroup: ... ) -> ToolGroup: ...
@webmethod(route="/toolgroups/list", method="GET") @webmethod(route="/toolgroups/list", method="GET")
@ -149,8 +129,10 @@ class ToolGroups(Protocol):
class ToolRuntime(Protocol): class ToolRuntime(Protocol):
tool_store: ToolStore tool_store: ToolStore
@webmethod(route="/tool-runtime/discover", method="POST") @webmethod(route="/tool-runtime/list-tools", method="POST")
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ... async def list_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ...
@webmethod(route="/tool-runtime/invoke", method="POST") @webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool( async def invoke_tool(

View file

@ -6,7 +6,7 @@
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import ( from llama_stack.apis.eval import (
AppEvalTaskConfig, AppEvalTaskConfig,
@ -38,7 +38,7 @@ from llama_stack.apis.scoring import (
ScoringFnParams, ScoringFnParams,
) )
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolRuntime from llama_stack.apis.tools import ToolDef, ToolRuntime
from llama_stack.providers.datatypes import RoutingTable from llama_stack.providers.datatypes import RoutingTable
@ -417,7 +417,9 @@ class ToolRuntimeRouter(ToolRuntime):
args=args, args=args,
) )
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: async def list_tools(
return await self.routing_table.get_provider_impl( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
tool_group.name ) -> List[ToolDef]:
).discover_tools(tool_group) return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
tool_group_id, mcp_endpoint
)

View file

@ -26,15 +26,7 @@ from llama_stack.apis.scoring_functions import (
ScoringFunctions, ScoringFunctions,
) )
from llama_stack.apis.shields import Shield, Shields from llama_stack.apis.shields import Shield, Shields
from llama_stack.apis.tools import ( from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost
MCPToolGroupDef,
Tool,
ToolGroup,
ToolGroupDef,
ToolGroups,
ToolHost,
UserDefinedToolGroupDef,
)
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
RoutableObject, RoutableObject,
RoutableObjectWithProvider, RoutableObjectWithProvider,
@ -496,51 +488,38 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]: async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
tools = await self.get_all_with_type("tool") tools = await self.get_all_with_type("tool")
if tool_group_id: if tool_group_id:
tools = [tool for tool in tools if tool.tool_group == tool_group_id] tools = [tool for tool in tools if tool.toolgroup_id == tool_group_id]
return tools return tools
async def list_tool_groups(self) -> List[ToolGroup]: async def list_tool_groups(self) -> List[ToolGroup]:
return await self.get_all_with_type("tool_group") return await self.get_all_with_type("tool_group")
async def get_tool_group(self, tool_group_id: str) -> ToolGroup: async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", tool_group_id) return await self.get_object_by_identifier("tool_group", toolgroup_id)
async def get_tool(self, tool_name: str) -> Tool: async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name) return await self.get_object_by_identifier("tool", tool_name)
async def register_tool_group( async def register_tool_group(
self, self,
tool_group_id: str, toolgroup_id: str,
tool_group_def: ToolGroupDef, provider_id: str,
provider_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None,
args: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
tools = [] tools = []
tool_defs = [] tool_defs = await self.impls_by_provider_id[provider_id].list_tools(
tool_host = ToolHost.distribution toolgroup_id, mcp_endpoint
if provider_id is None: )
if len(self.impls_by_provider_id.keys()) > 1: tool_host = (
raise ValueError( ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
f"No provider_id specified and multiple providers available. Please specify a provider_id. Available providers: {', '.join(self.impls_by_provider_id.keys())}" )
)
provider_id = list(self.impls_by_provider_id.keys())[0]
# parse tool group to the type if dict
tool_group_def = TypeAdapter(ToolGroupDef).validate_python(tool_group_def)
if isinstance(tool_group_def, MCPToolGroupDef):
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
tool_group_def
)
tool_host = ToolHost.model_context_protocol
elif isinstance(tool_group_def, UserDefinedToolGroupDef):
tool_defs = tool_group_def.tools
else:
raise ValueError(f"Unknown tool group: {tool_group_def}")
for tool_def in tool_defs: for tool_def in tool_defs:
tools.append( tools.append(
Tool( Tool(
identifier=tool_def.name, identifier=tool_def.name,
tool_group=tool_group_id, toolgroup_id=toolgroup_id,
description=tool_def.description or "", description=tool_def.description or "",
parameters=tool_def.parameters or [], parameters=tool_def.parameters or [],
provider_id=provider_id, provider_id=provider_id,
@ -565,9 +544,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
await self.dist_registry.register( await self.dist_registry.register(
ToolGroup( ToolGroup(
identifier=tool_group_id, identifier=toolgroup_id,
provider_id=provider_id, provider_id=provider_id,
provider_resource_id=tool_group_id, provider_resource_id=toolgroup_id,
mcp_endpoint=mcp_endpoint,
args=args,
) )
) )

View file

@ -13,7 +13,7 @@ import secrets
import string import string
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
@ -21,8 +21,8 @@ from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDe
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
AgentConfig, AgentConfig,
AgentTool, AgentToolGroup,
AgentToolWithArgs, AgentToolGroupWithArgs,
AgentTurnCreateRequest, AgentTurnCreateRequest,
AgentTurnResponseEvent, AgentTurnResponseEvent,
AgentTurnResponseEventType, AgentTurnResponseEventType,
@ -76,6 +76,10 @@ def make_random_string(length: int = 8):
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_TOOL_GROUP_ID = "builtin::memory"
MEMORY_QUERY_TOOL = "query_memory"
CODE_INTERPRETER_TOOL = "code_interpreter"
WEB_SEARCH_TOOL = "web_search"
class ChatAgent(ShieldRunnerMixin): class ChatAgent(ShieldRunnerMixin):
@ -192,7 +196,7 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params=self.agent_config.sampling_params, sampling_params=self.agent_config.sampling_params,
stream=request.stream, stream=request.stream,
documents=request.documents, documents=request.documents,
tools_for_turn=request.tools, toolgroups_for_turn=request.toolgroups,
): ):
if isinstance(chunk, CompletionMessage): if isinstance(chunk, CompletionMessage):
log.info( log.info(
@ -243,7 +247,7 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params: SamplingParams, sampling_params: SamplingParams,
stream: bool = False, stream: bool = False,
documents: Optional[List[Document]] = None, documents: Optional[List[Document]] = None,
tools_for_turn: Optional[List[AgentTool]] = None, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to # Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot # streaming. However, it also makes things complicated here because AsyncGenerators cannot
@ -266,7 +270,7 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params, sampling_params,
stream, stream,
documents, documents,
tools_for_turn, toolgroups_for_turn,
): ):
if isinstance(res, bool): if isinstance(res, bool):
return return
@ -362,21 +366,24 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params: SamplingParams, sampling_params: SamplingParams,
stream: bool = False, stream: bool = False,
documents: Optional[List[Document]] = None, documents: Optional[List[Document]] = None,
tools_for_turn: Optional[List[AgentTool]] = None, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
tool_args = {} toolgroup_args = {}
if tools_for_turn: for toolgroup in self.agent_config.toolgroups:
for tool in tools_for_turn: if isinstance(toolgroup, AgentToolGroupWithArgs):
if isinstance(tool, AgentToolWithArgs): toolgroup_args[toolgroup.name] = toolgroup.args
tool_args[tool.name] = tool.args if toolgroups_for_turn:
for toolgroup in toolgroups_for_turn:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroup_args[toolgroup.name] = toolgroup.args
tool_defs = await self._get_tool_defs(tools_for_turn) tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
if documents: if documents:
await self.handle_documents( await self.handle_documents(
session_id, documents, input_messages, tool_defs session_id, documents, input_messages, tool_defs
) )
if "memory" in tool_defs and len(input_messages) > 0: if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
with tracing.span("memory_tool") as span: with tracing.span(MEMORY_QUERY_TOOL) as span:
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
@ -386,18 +393,16 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
) )
extra_args = tool_args.get("memory", {}) query_args = {
tool_args = { "messages": [msg.content for msg in input_messages],
# Query memory with the last message's content **toolgroup_args.get(MEMORY_TOOL_GROUP_ID, {}),
"query": input_messages[-1],
**extra_args,
} }
session_info = await self.storage.get_session_info(session_id) session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it # if the session has a memory bank id, let the memory tool use it
if session_info.memory_bank_id: if session_info.memory_bank_id:
tool_args["memory_bank_id"] = session_info.memory_bank_id query_args["memory_bank_id"] = session_info.memory_bank_id
serialized_args = tracing.serialize_value(tool_args) serialized_args = tracing.serialize_value(query_args)
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload( payload=AgentTurnResponseStepProgressPayload(
@ -415,8 +420,8 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
result = await self.tool_runtime_api.invoke_tool( result = await self.tool_runtime_api.invoke_tool(
tool_name="memory", tool_name=MEMORY_QUERY_TOOL,
args=tool_args, args=query_args,
) )
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
@ -485,7 +490,8 @@ class ChatAgent(ShieldRunnerMixin):
tools=[ tools=[
tool tool
for tool in tool_defs.values() for tool in tool_defs.values()
if tool.tool_name != "memory" if tool_to_group.get(tool.tool_name, None)
!= MEMORY_TOOL_GROUP_ID
], ],
tool_prompt_format=self.agent_config.tool_prompt_format, tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True, stream=True,
@ -632,6 +638,8 @@ class ChatAgent(ShieldRunnerMixin):
self.tool_runtime_api, self.tool_runtime_api,
session_id, session_id,
[message], [message],
toolgroup_args,
tool_to_group,
) )
assert ( assert (
len(result_messages) == 1 len(result_messages) == 1
@ -690,26 +698,37 @@ class ChatAgent(ShieldRunnerMixin):
n_iter += 1 n_iter += 1
async def _get_tool_defs( async def _get_tool_defs(
self, tools_for_turn: Optional[List[AgentTool]] self, toolgroups_for_turn: Optional[List[AgentToolGroup]]
) -> Dict[str, ToolDefinition]: ) -> Dict[str, ToolDefinition]:
# Determine which tools to include # Determine which tools to include
agent_config_tools = set( agent_config_toolgroups = set(
tool.name if isinstance(tool, AgentToolWithArgs) else tool (
for tool in self.agent_config.tools toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
for toolgroup in self.agent_config.toolgroups
) )
tools_for_turn_set = ( toolgroups_for_turn_set = (
agent_config_tools agent_config_toolgroups
if tools_for_turn is None if toolgroups_for_turn is None
else { else {
tool.name if isinstance(tool, AgentToolWithArgs) else tool (
for tool in tools_for_turn toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
for toolgroup in toolgroups_for_turn
} }
) )
ret = {} tool_def_map = {}
tool_to_group = {}
for tool_def in self.agent_config.client_tools: for tool_def in self.agent_config.client_tools:
ret[tool_def.name] = ToolDefinition( if tool_def_map.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
tool_def_map[tool_def.name] = ToolDefinition(
tool_name=tool_def.name, tool_name=tool_def.name,
description=tool_def.description, description=tool_def.description,
parameters={ parameters={
@ -722,41 +741,42 @@ class ChatAgent(ShieldRunnerMixin):
for param in tool_def.parameters for param in tool_def.parameters
}, },
) )
tool_to_group[tool_def.name] = "__client_tools__"
for tool_name in agent_config_tools: for toolgroup_name in agent_config_toolgroups:
if tool_name not in tools_for_turn_set: if toolgroup_name not in toolgroups_for_turn_set:
continue continue
tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name)
for tool_def in tools:
if tool_def.built_in_type:
if tool_def_map.get(tool_def.built_in_type, None):
raise ValueError(
f"Tool {tool_def.built_in_type} already exists"
)
tool_def = await self.tool_groups_api.get_tool(tool_name) tool_def_map[tool_def.built_in_type] = ToolDefinition(
if tool_def is None: tool_name=tool_def.built_in_type
raise ValueError(f"Tool {tool_name} not found")
if tool_def.identifier.startswith("builtin::"):
built_in_type = tool_def.identifier[len("builtin::") :]
if built_in_type == "web_search":
built_in_type = "brave_search"
if built_in_type not in BuiltinTool.__members__:
raise ValueError(f"Unknown built-in tool: {built_in_type}")
ret[built_in_type] = ToolDefinition(
tool_name=BuiltinTool(built_in_type)
)
continue
ret[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
) )
for param in tool_def.parameters tool_to_group[tool_def.built_in_type] = tool_def.toolgroup_id
}, continue
)
return ret if tool_def_map.get(tool_def.identifier, None):
raise ValueError(f"Tool {tool_def.identifier} already exists")
tool_def_map[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
return tool_def_map, tool_to_group
async def handle_documents( async def handle_documents(
self, self,
@ -765,8 +785,8 @@ class ChatAgent(ShieldRunnerMixin):
input_messages: List[Message], input_messages: List[Message],
tool_defs: Dict[str, ToolDefinition], tool_defs: Dict[str, ToolDefinition],
) -> None: ) -> None:
memory_tool = tool_defs.get("memory", None) memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
code_interpreter_tool = tool_defs.get("code_interpreter", None) code_interpreter_tool = tool_defs.get(CODE_INTERPRETER_TOOL, None)
content_items = [] content_items = []
url_items = [] url_items = []
pattern = re.compile("^(https?://|file://|data:)") pattern = re.compile("^(https?://|file://|data:)")
@ -903,7 +923,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
async def execute_tool_call_maybe( async def execute_tool_call_maybe(
tool_runtime_api: ToolRuntime, session_id: str, messages: List[CompletionMessage] tool_runtime_api: ToolRuntime,
session_id: str,
messages: List[CompletionMessage],
toolgroup_args: Dict[str, Dict[str, Any]],
tool_to_group: Dict[str, str],
) -> List[ToolResponseMessage]: ) -> List[ToolResponseMessage]:
# While Tools.run interface takes a list of messages, # While Tools.run interface takes a list of messages,
# All tools currently only run on a single message # All tools currently only run on a single message
@ -915,18 +939,26 @@ async def execute_tool_call_maybe(
tool_call = message.tool_calls[0] tool_call = message.tool_calls[0]
name = tool_call.tool_name name = tool_call.tool_name
group_name = tool_to_group.get(name, None)
if group_name is None:
raise ValueError(f"Tool {name} not found in any tool group")
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
tool_call_args = tool_call.arguments
tool_call_args.update(toolgroup_args.get(group_name, {}))
if isinstance(name, BuiltinTool): if isinstance(name, BuiltinTool):
if name == BuiltinTool.brave_search: if name == BuiltinTool.brave_search:
name = "builtin::web_search" name = WEB_SEARCH_TOOL
else: else:
name = "builtin::" + name.value name = name.value
result = await tool_runtime_api.invoke_tool( result = await tool_runtime_api.invoke_tool(
tool_name=name, tool_name=name,
args=dict( args=dict(
session_id=session_id, session_id=session_id,
**tool_call.arguments, **tool_call_args,
), ),
) )
return [ return [
ToolResponseMessage( ToolResponseMessage(
call_id=tool_call.call_id, call_id=tool_call.call_id,

View file

@ -19,7 +19,7 @@ from llama_stack.apis.agents import (
Agents, Agents,
AgentSessionCreateResponse, AgentSessionCreateResponse,
AgentStepResponse, AgentStepResponse,
AgentTool, AgentToolGroup,
AgentTurnCreateRequest, AgentTurnCreateRequest,
Document, Document,
Session, Session,
@ -147,7 +147,7 @@ class MetaReferenceAgentsImpl(Agents):
ToolResponseMessage, ToolResponseMessage,
] ]
], ],
tools: Optional[List[AgentTool]] = None, tools: Optional[List[AgentToolGroup]] = None,
documents: Optional[List[Document]] = None, documents: Optional[List[Document]] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
) -> AsyncGenerator: ) -> AsyncGenerator:

View file

@ -7,9 +7,16 @@
import logging import logging
import tempfile import tempfile
from typing import Any, Dict, List from typing import Any, Dict, List, Optional
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor
@ -35,8 +42,22 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def unregister_tool(self, tool_id: str) -> None: async def unregister_tool(self, tool_id: str) -> None:
return return
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: async def list_tools(
raise NotImplementedError("Code interpreter tool group not supported") self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="code_interpreter",
description="Execute code",
parameters=[
ToolParameter(
name="code",
description="The code to execute",
parameter_type="string",
),
],
)
]
async def invoke_tool( async def invoke_tool(
self, tool_name: str, args: Dict[str, Any] self, tool_name: str, args: Dict[str, Any]

View file

@ -5,6 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
from typing import List
from jinja2 import Template from jinja2 import Template
from llama_stack.apis.inference import Message, UserMessage from llama_stack.apis.inference import Message, UserMessage
@ -22,7 +24,7 @@ from .config import (
async def generate_rag_query( async def generate_rag_query(
config: MemoryQueryGeneratorConfig, config: MemoryQueryGeneratorConfig,
message: Message, messages: List[Message],
**kwargs, **kwargs,
): ):
""" """
@ -30,9 +32,9 @@ async def generate_rag_query(
retrieving relevant information from the memory bank. retrieving relevant information from the memory bank.
""" """
if config.type == MemoryQueryGenerator.default.value: if config.type == MemoryQueryGenerator.default.value:
query = await default_rag_query_generator(config, message, **kwargs) query = await default_rag_query_generator(config, messages, **kwargs)
elif config.type == MemoryQueryGenerator.llm.value: elif config.type == MemoryQueryGenerator.llm.value:
query = await llm_rag_query_generator(config, message, **kwargs) query = await llm_rag_query_generator(config, messages, **kwargs)
else: else:
raise NotImplementedError(f"Unsupported memory query generator {config.type}") raise NotImplementedError(f"Unsupported memory query generator {config.type}")
return query return query
@ -40,21 +42,21 @@ async def generate_rag_query(
async def default_rag_query_generator( async def default_rag_query_generator(
config: DefaultMemoryQueryGeneratorConfig, config: DefaultMemoryQueryGeneratorConfig,
message: Message, messages: List[Message],
**kwargs, **kwargs,
): ):
return interleaved_content_as_str(message.content) return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
async def llm_rag_query_generator( async def llm_rag_query_generator(
config: LLMMemoryQueryGeneratorConfig, config: LLMMemoryQueryGeneratorConfig,
message: Message, messages: List[Message],
**kwargs, **kwargs,
): ):
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api" assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
inference_api = kwargs["inference_api"] inference_api = kwargs["inference_api"]
m_dict = {"messages": [message.model_dump()]} m_dict = {"messages": [message.model_dump() for message in messages]}
template = Template(config.template) template = Template(config.template)
content = template.render(m_dict) content = template.render(m_dict)

View file

@ -10,13 +10,14 @@ import secrets
import string import string
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from llama_stack.apis.inference import Inference, InterleavedContent, Message from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.memory import Memory, QueryDocumentsResponse from llama_stack.apis.memory import Memory, QueryDocumentsResponse
from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ToolDef, ToolDef,
ToolGroupDef,
ToolInvocationResult, ToolInvocationResult,
ToolParameter,
ToolRuntime, ToolRuntime,
) )
from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -50,17 +51,31 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def initialize(self): async def initialize(self):
pass pass
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: async def list_tools(
return [] self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="memory",
description="Retrieve context from memory",
parameters=[
ToolParameter(
name="input_messages",
description="The input messages to search for",
parameter_type="array",
),
],
)
]
async def _retrieve_context( async def _retrieve_context(
self, message: Message, bank_ids: List[str] self, input_messages: List[str], bank_ids: List[str]
) -> Optional[List[InterleavedContent]]: ) -> Optional[List[InterleavedContent]]:
if not bank_ids: if not bank_ids:
return None return None
query = await generate_rag_query( query = await generate_rag_query(
self.config.query_generator_config, self.config.query_generator_config,
message, input_messages,
inference_api=self.inference_api, inference_api=self.inference_api,
) )
tasks = [ tasks = [
@ -106,17 +121,22 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
self, tool_name: str, args: Dict[str, Any] self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult: ) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name) tool = await self.tool_store.get_tool(tool_name)
tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
final_args = tool_group.args or {}
final_args.update(args)
config = MemoryToolConfig() config = MemoryToolConfig()
if tool.metadata.get("config") is not None: if tool.metadata and tool.metadata.get("config") is not None:
config = MemoryToolConfig(**tool.metadata["config"]) config = MemoryToolConfig(**tool.metadata["config"])
if "memory_bank_id" in args: if "memory_bank_ids" in final_args:
bank_ids = [args["memory_bank_id"]] bank_ids = final_args["memory_bank_ids"]
else: else:
bank_ids = [ bank_ids = [
bank_config.bank_id for bank_config in config.memory_bank_configs bank_config.bank_id for bank_config in config.memory_bank_configs
] ]
if "messages" not in final_args:
raise ValueError("messages are required")
context = await self._retrieve_context( context = await self._retrieve_context(
args["query"], final_args["messages"],
bank_ids, bank_ids,
) )
if context is None: if context is None:

View file

@ -4,11 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List from typing import Any, Dict, List, Optional
import requests import requests
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -41,8 +48,22 @@ class BraveSearchToolRuntimeImpl(
) )
return provider_data.api_key return provider_data.api_key
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: async def list_tools(
raise NotImplementedError("Brave search tool group not supported") self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="web_search",
description="Search the web for information",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
)
]
async def invoke_tool( async def invoke_tool(
self, tool_name: str, args: Dict[str, Any] self, tool_name: str, args: Dict[str, Any]

View file

@ -4,20 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List from typing import Any, Dict, List, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
from mcp import ClientSession from mcp import ClientSession
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
MCPToolGroupDef,
ToolDef, ToolDef,
ToolGroupDef,
ToolInvocationResult, ToolInvocationResult,
ToolParameter, ToolParameter,
ToolRuntime, ToolRuntime,
UserDefinedToolDef,
) )
from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -31,12 +29,14 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def initialize(self): async def initialize(self):
pass pass
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: async def list_tools(
if not isinstance(tool_group, MCPToolGroupDef): self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
raise ValueError(f"Unsupported tool group type: {type(tool_group)}") ) -> List[ToolDef]:
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
tools = [] tools = []
async with sse_client(tool_group.endpoint.uri) as streams: async with sse_client(mcp_endpoint.uri) as streams:
async with ClientSession(*streams) as session: async with ClientSession(*streams) as session:
await session.initialize() await session.initialize()
tools_result = await session.list_tools() tools_result = await session.list_tools()
@ -53,12 +53,12 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
) )
) )
tools.append( tools.append(
UserDefinedToolDef( ToolDef(
name=tool.name, name=tool.name,
description=tool.description, description=tool.description,
parameters=parameters, parameters=parameters,
metadata={ metadata={
"endpoint": tool_group.endpoint.uri, "endpoint": mcp_endpoint.uri,
}, },
) )
) )

View file

@ -5,11 +5,18 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
from typing import Any, Dict, List from typing import Any, Dict, List, Optional
import requests import requests
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -42,8 +49,22 @@ class TavilySearchToolRuntimeImpl(
) )
return provider_data.api_key return provider_data.api_key
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: async def list_tools(
raise NotImplementedError("Tavily search tool group not supported") self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="web_search",
description="Search the web for information",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
)
]
async def invoke_tool( async def invoke_tool(
self, tool_name: str, args: Dict[str, Any] self, tool_name: str, args: Dict[str, Any]

View file

@ -45,8 +45,7 @@ def common_params(inference_model):
sampling_params=SamplingParams(temperature=0.7, top_p=0.95), sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
input_shields=[], input_shields=[],
output_shields=[], output_shields=[],
available_tools=[], toolgroups=[],
preprocessing_tools=[],
max_infer_iters=5, max_infer_iters=5,
) )
@ -83,27 +82,27 @@ def query_attachment_messages():
] ]
async def create_agent_turn_with_search_tool( async def create_agent_turn_with_toolgroup(
agents_stack: Dict[str, object], agents_stack: Dict[str, object],
search_query_messages: List[object], search_query_messages: List[object],
common_params: Dict[str, str], common_params: Dict[str, str],
tool_name: str, toolgroup_name: str,
) -> None: ) -> None:
""" """
Create an agent turn with a search tool. Create an agent turn with a toolgroup.
Args: Args:
agents_stack (Dict[str, object]): The agents stack. agents_stack (Dict[str, object]): The agents stack.
search_query_messages (List[object]): The search query messages. search_query_messages (List[object]): The search query messages.
common_params (Dict[str, str]): The common parameters. common_params (Dict[str, str]): The common parameters.
search_tool_definition (SearchToolDefinition): The search tool definition. toolgroup_name (str): The name of the toolgroup.
""" """
# Create an agent with the search tool # Create an agent with the toolgroup
agent_config = AgentConfig( agent_config = AgentConfig(
**{ **{
**common_params, **common_params,
"tools": [tool_name], "toolgroups": [toolgroup_name],
} }
) )
@ -249,7 +248,7 @@ class TestAgents:
agent_config = AgentConfig( agent_config = AgentConfig(
**{ **{
**common_params, **common_params,
"tools": ["memory"], "toolgroups": ["builtin::memory"],
"tool_choice": ToolChoice.auto, "tool_choice": ToolChoice.auto,
} }
) )
@ -289,13 +288,58 @@ class TestAgents:
if "TAVILY_SEARCH_API_KEY" not in os.environ: if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
await create_agent_turn_with_search_tool( # Create an agent with the toolgroup
agents_stack, agent_config = AgentConfig(
search_query_messages, **{
common_params, **common_params,
"brave_search", "toolgroups": ["builtin::web_search"],
}
) )
agent_id, session_id = await create_agent_session(
agents_stack.impls[Api.agents], agent_config
)
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=search_query_messages,
stream=True,
)
turn_response = [
chunk
async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
**turn_request
)
]
assert len(turn_response) > 0
assert all(
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
)
check_event_types(turn_response)
# Check for tool execution events
tool_execution_events = [
chunk
for chunk in turn_response
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
and chunk.event.payload.step_details.step_type
== StepType.tool_execution.value
]
assert len(tool_execution_events) > 0, "No tool execution events found"
# Check the tool execution details
tool_execution = tool_execution_events[0].event.payload.step_details
assert isinstance(tool_execution, ToolExecutionStep)
assert len(tool_execution.tool_calls) > 0
actual_tool_name = tool_execution.tool_calls[0].tool_name
assert actual_tool_name == "web_search"
assert len(tool_execution.tool_responses) > 0
check_turn_complete_event(turn_response, session_id, search_query_messages)
def check_event_types(turn_response): def check_event_types(turn_response):
event_types = [chunk.event.payload.event_type for chunk in turn_response] event_types = [chunk.event.payload.event_type for chunk in turn_response]

View file

@ -8,16 +8,9 @@ import os
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.models import ModelInput, ModelType from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.apis.tools import ( from llama_stack.apis.tools import ToolGroupInput
BuiltInToolDef,
ToolGroupInput,
ToolParameter,
UserDefinedToolDef,
UserDefinedToolGroupDef,
)
from llama_stack.distribution.datatypes import Api, Provider from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.tests.resolver import construct_stack_for_test
@ -47,30 +40,7 @@ def tool_runtime_memory_and_search() -> ProviderFixture:
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def tool_group_input_memory() -> ToolGroupInput: def tool_group_input_memory() -> ToolGroupInput:
return ToolGroupInput( return ToolGroupInput(
tool_group_id="memory_group", toolgroup_id="builtin::memory",
tool_group=UserDefinedToolGroupDef(
tools=[
UserDefinedToolDef(
name="memory",
description="Query the memory bank",
parameters=[
ToolParameter(
name="input_messages",
description="The input messages to search for in memory",
parameter_type="list",
required=True,
),
],
metadata={
"config": {
"memory_bank_configs": [
{"bank_id": "test_bank", "type": "vector"}
]
}
},
)
],
),
provider_id="memory-runtime", provider_id="memory-runtime",
) )
@ -78,10 +48,7 @@ def tool_group_input_memory() -> ToolGroupInput:
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def tool_group_input_tavily_search() -> ToolGroupInput: def tool_group_input_tavily_search() -> ToolGroupInput:
return ToolGroupInput( return ToolGroupInput(
tool_group_id="tavily_search_group", toolgroup_id="builtin::web_search",
tool_group=UserDefinedToolGroupDef(
tools=[BuiltInToolDef(built_in_type=BuiltinTool.brave_search, metadata={})],
),
provider_id="tavily-search", provider_id="tavily-search",
) )

View file

@ -43,8 +43,8 @@ def sample_documents():
class TestTools: class TestTools:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_brave_search_tool(self, tools_stack, sample_search_query): async def test_web_search_tool(self, tools_stack, sample_search_query):
"""Test the Brave search tool functionality.""" """Test the web search tool functionality."""
if "TAVILY_SEARCH_API_KEY" not in os.environ: if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
@ -52,7 +52,7 @@ class TestTools:
# Execute the tool # Execute the tool
response = await tools_impl.invoke_tool( response = await tools_impl.invoke_tool(
tool_name="brave_search", args={"query": sample_search_query} tool_name="web_search", args={"query": sample_search_query}
) )
# Verify the response # Verify the response
@ -89,11 +89,12 @@ class TestTools:
response = await tools_impl.invoke_tool( response = await tools_impl.invoke_tool(
tool_name="memory", tool_name="memory",
args={ args={
"input_messages": [ "messages": [
UserMessage( UserMessage(
content="What are the main topics covered in the documentation?", content="What are the main topics covered in the documentation?",
) )
], ],
"memory_bank_ids": ["test_bank"],
}, },
) )