From f9a98c278a325df0e12c8a982c46faea49c5f6c5 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 7 Jan 2025 15:37:52 -0800 Subject: [PATCH] simplify toolgroups registration --- llama_stack/apis/agents/agents.py | 12 +- llama_stack/apis/tools/tools.py | 58 ++---- llama_stack/distribution/routers/routers.py | 14 +- .../distribution/routers/routing_tables.py | 57 ++---- .../agents/meta_reference/agent_instance.py | 178 +++++++++++------- .../inline/agents/meta_reference/agents.py | 4 +- .../code_interpreter/code_interpreter.py | 29 ++- .../tool_runtime/memory/context_retriever.py | 16 +- .../inline/tool_runtime/memory/memory.py | 40 +++- .../tool_runtime/brave_search/brave_search.py | 29 ++- .../model_context_protocol.py | 20 +- .../tavily_search/tavily_search.py | 29 ++- .../providers/tests/agents/test_agents.py | 72 +++++-- llama_stack/providers/tests/tools/fixtures.py | 39 +--- .../providers/tests/tools/test_tools.py | 9 +- 15 files changed, 350 insertions(+), 256 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index db0e3ab3b..f5fbcb9c4 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -137,15 +137,15 @@ class Session(BaseModel): memory_bank: Optional[MemoryBank] = None -class AgentToolWithArgs(BaseModel): +class AgentToolGroupWithArgs(BaseModel): name: str args: Dict[str, Any] -AgentTool = register_schema( +AgentToolGroup = register_schema( Union[ str, - AgentToolWithArgs, + AgentToolGroupWithArgs, ], name="AgentTool", ) @@ -156,7 +156,7 @@ class AgentConfigCommon(BaseModel): input_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list) - tools: Optional[List[AgentTool]] = Field(default_factory=list) + toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list) client_tools: Optional[List[ToolDef]] = Field(default_factory=list) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_prompt_format: Optional[ToolPromptFormat] = Field( @@ -278,7 +278,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): ] documents: Optional[List[Document]] = None - tools: Optional[List[AgentTool]] = None + toolgroups: Optional[List[AgentToolGroup]] = None stream: Optional[bool] = False @@ -317,7 +317,7 @@ class Agents(Protocol): ], stream: Optional[bool] = False, documents: Optional[List[Document]] = None, - tools: Optional[List[AgentTool]] = None, + tools: Optional[List[AgentToolGroup]] = None, ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... @webmethod(route="/agents/turn/get") diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index bc19a8a02..24845e101 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -5,10 +5,10 @@ # the root directory of this source tree. from enum import Enum -from typing import Annotated, Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional from llama_models.llama3.api.datatypes import BuiltinTool, ToolPromptFormat -from llama_models.schema_utils import json_schema_type, register_schema, webmethod +from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Protocol, runtime_checkable @@ -22,7 +22,7 @@ class ToolParameter(BaseModel): name: str parameter_type: str description: str - required: bool + required: bool = Field(default=True) default: Optional[Any] = None @@ -36,7 +36,7 @@ class ToolHost(Enum): @json_schema_type class Tool(Resource): type: Literal[ResourceType.tool.value] = ResourceType.tool.value - tool_group: str + toolgroup_id: str tool_host: ToolHost description: str parameters: List[ToolParameter] @@ -58,41 +58,19 @@ class ToolDef(BaseModel): ) -@json_schema_type -class MCPToolGroupDef(BaseModel): - """ - A tool group that is defined by in a model context protocol server. - Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information. - """ - - type: Literal["model_context_protocol"] = "model_context_protocol" - endpoint: URL - - -@json_schema_type -class UserDefinedToolGroupDef(BaseModel): - type: Literal["user_defined"] = "user_defined" - tools: List[ToolDef] - - -ToolGroupDef = register_schema( - Annotated[ - Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type") - ], - name="ToolGroupDef", -) - - @json_schema_type class ToolGroupInput(BaseModel): - tool_group_id: str - tool_group_def: ToolGroupDef - provider_id: Optional[str] = None + toolgroup_id: str + provider_id: str + args: Optional[Dict[str, Any]] = None + mcp_endpoint: Optional[URL] = None @json_schema_type class ToolGroup(Resource): type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value + mcp_endpoint: Optional[URL] = None + args: Optional[Dict[str, Any]] = None @json_schema_type @@ -104,6 +82,7 @@ class ToolInvocationResult(BaseModel): class ToolStore(Protocol): def get_tool(self, tool_name: str) -> Tool: ... + def get_tool_group(self, tool_group_id: str) -> ToolGroup: ... @runtime_checkable @@ -112,9 +91,10 @@ class ToolGroups(Protocol): @webmethod(route="/toolgroups/register", method="POST") async def register_tool_group( self, - tool_group_id: str, - tool_group_def: ToolGroupDef, - provider_id: Optional[str] = None, + toolgroup_id: str, + provider_id: str, + mcp_endpoint: Optional[URL] = None, + args: Optional[Dict[str, Any]] = None, ) -> None: """Register a tool group""" ... @@ -122,7 +102,7 @@ class ToolGroups(Protocol): @webmethod(route="/toolgroups/get", method="GET") async def get_tool_group( self, - tool_group_id: str, + toolgroup_id: str, ) -> ToolGroup: ... @webmethod(route="/toolgroups/list", method="GET") @@ -149,8 +129,10 @@ class ToolGroups(Protocol): class ToolRuntime(Protocol): tool_store: ToolStore - @webmethod(route="/tool-runtime/discover", method="POST") - async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ... + @webmethod(route="/tool-runtime/list-tools", method="POST") + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: ... @webmethod(route="/tool-runtime/invoke", method="POST") async def invoke_tool( diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 84ef467eb..230feea71 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional -from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.content_types import InterleavedContent, URL from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.eval import ( AppEvalTaskConfig, @@ -38,7 +38,7 @@ from llama_stack.apis.scoring import ( ScoringFnParams, ) from llama_stack.apis.shields import Shield -from llama_stack.apis.tools import Tool, ToolGroupDef, ToolRuntime +from llama_stack.apis.tools import ToolDef, ToolRuntime from llama_stack.providers.datatypes import RoutingTable @@ -417,7 +417,9 @@ class ToolRuntimeRouter(ToolRuntime): args=args, ) - async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: - return await self.routing_table.get_provider_impl( - tool_group.name - ).discover_tools(tool_group) + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + return await self.routing_table.get_provider_impl(tool_group_id).list_tools( + tool_group_id, mcp_endpoint + ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index b51de8fef..4ed932807 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -26,15 +26,7 @@ from llama_stack.apis.scoring_functions import ( ScoringFunctions, ) from llama_stack.apis.shields import Shield, Shields -from llama_stack.apis.tools import ( - MCPToolGroupDef, - Tool, - ToolGroup, - ToolGroupDef, - ToolGroups, - ToolHost, - UserDefinedToolGroupDef, -) +from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost from llama_stack.distribution.datatypes import ( RoutableObject, RoutableObjectWithProvider, @@ -496,51 +488,38 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]: tools = await self.get_all_with_type("tool") if tool_group_id: - tools = [tool for tool in tools if tool.tool_group == tool_group_id] + tools = [tool for tool in tools if tool.toolgroup_id == tool_group_id] return tools async def list_tool_groups(self) -> List[ToolGroup]: return await self.get_all_with_type("tool_group") - async def get_tool_group(self, tool_group_id: str) -> ToolGroup: - return await self.get_object_by_identifier("tool_group", tool_group_id) + async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: + return await self.get_object_by_identifier("tool_group", toolgroup_id) async def get_tool(self, tool_name: str) -> Tool: return await self.get_object_by_identifier("tool", tool_name) async def register_tool_group( self, - tool_group_id: str, - tool_group_def: ToolGroupDef, - provider_id: Optional[str] = None, + toolgroup_id: str, + provider_id: str, + mcp_endpoint: Optional[URL] = None, + args: Optional[Dict[str, Any]] = None, ) -> None: tools = [] - tool_defs = [] - tool_host = ToolHost.distribution - if provider_id is None: - if len(self.impls_by_provider_id.keys()) > 1: - raise ValueError( - f"No provider_id specified and multiple providers available. Please specify a provider_id. Available providers: {', '.join(self.impls_by_provider_id.keys())}" - ) - provider_id = list(self.impls_by_provider_id.keys())[0] - - # parse tool group to the type if dict - tool_group_def = TypeAdapter(ToolGroupDef).validate_python(tool_group_def) - if isinstance(tool_group_def, MCPToolGroupDef): - tool_defs = await self.impls_by_provider_id[provider_id].discover_tools( - tool_group_def - ) - tool_host = ToolHost.model_context_protocol - elif isinstance(tool_group_def, UserDefinedToolGroupDef): - tool_defs = tool_group_def.tools - else: - raise ValueError(f"Unknown tool group: {tool_group_def}") + tool_defs = await self.impls_by_provider_id[provider_id].list_tools( + toolgroup_id, mcp_endpoint + ) + tool_host = ( + ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution + ) for tool_def in tool_defs: tools.append( Tool( identifier=tool_def.name, - tool_group=tool_group_id, + toolgroup_id=toolgroup_id, description=tool_def.description or "", parameters=tool_def.parameters or [], provider_id=provider_id, @@ -565,9 +544,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): await self.dist_registry.register( ToolGroup( - identifier=tool_group_id, + identifier=toolgroup_id, provider_id=provider_id, - provider_resource_id=tool_group_id, + provider_resource_id=toolgroup_id, + mcp_endpoint=mcp_endpoint, + args=args, ) ) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index ceb764ffe..cfe839dad 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -13,7 +13,7 @@ import secrets import string import uuid from datetime import datetime -from typing import AsyncGenerator, Dict, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional from urllib.parse import urlparse import httpx @@ -21,8 +21,8 @@ from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDe from llama_stack.apis.agents import ( AgentConfig, - AgentTool, - AgentToolWithArgs, + AgentToolGroup, + AgentToolGroupWithArgs, AgentTurnCreateRequest, AgentTurnResponseEvent, AgentTurnResponseEventType, @@ -76,6 +76,10 @@ def make_random_string(length: int = 8): TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") +MEMORY_TOOL_GROUP_ID = "builtin::memory" +MEMORY_QUERY_TOOL = "query_memory" +CODE_INTERPRETER_TOOL = "code_interpreter" +WEB_SEARCH_TOOL = "web_search" class ChatAgent(ShieldRunnerMixin): @@ -192,7 +196,7 @@ class ChatAgent(ShieldRunnerMixin): sampling_params=self.agent_config.sampling_params, stream=request.stream, documents=request.documents, - tools_for_turn=request.tools, + toolgroups_for_turn=request.toolgroups, ): if isinstance(chunk, CompletionMessage): log.info( @@ -243,7 +247,7 @@ class ChatAgent(ShieldRunnerMixin): sampling_params: SamplingParams, stream: bool = False, documents: Optional[List[Document]] = None, - tools_for_turn: Optional[List[AgentTool]] = None, + toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, ) -> AsyncGenerator: # Doing async generators makes downstream code much simpler and everything amenable to # streaming. However, it also makes things complicated here because AsyncGenerators cannot @@ -266,7 +270,7 @@ class ChatAgent(ShieldRunnerMixin): sampling_params, stream, documents, - tools_for_turn, + toolgroups_for_turn, ): if isinstance(res, bool): return @@ -362,21 +366,24 @@ class ChatAgent(ShieldRunnerMixin): sampling_params: SamplingParams, stream: bool = False, documents: Optional[List[Document]] = None, - tools_for_turn: Optional[List[AgentTool]] = None, + toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, ) -> AsyncGenerator: - tool_args = {} - if tools_for_turn: - for tool in tools_for_turn: - if isinstance(tool, AgentToolWithArgs): - tool_args[tool.name] = tool.args + toolgroup_args = {} + for toolgroup in self.agent_config.toolgroups: + if isinstance(toolgroup, AgentToolGroupWithArgs): + toolgroup_args[toolgroup.name] = toolgroup.args + if toolgroups_for_turn: + for toolgroup in toolgroups_for_turn: + if isinstance(toolgroup, AgentToolGroupWithArgs): + toolgroup_args[toolgroup.name] = toolgroup.args - tool_defs = await self._get_tool_defs(tools_for_turn) + tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn) if documents: await self.handle_documents( session_id, documents, input_messages, tool_defs ) - if "memory" in tool_defs and len(input_messages) > 0: - with tracing.span("memory_tool") as span: + if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0: + with tracing.span(MEMORY_QUERY_TOOL) as span: step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -386,18 +393,16 @@ class ChatAgent(ShieldRunnerMixin): ) ) ) - extra_args = tool_args.get("memory", {}) - tool_args = { - # Query memory with the last message's content - "query": input_messages[-1], - **extra_args, + query_args = { + "messages": [msg.content for msg in input_messages], + **toolgroup_args.get(MEMORY_TOOL_GROUP_ID, {}), } session_info = await self.storage.get_session_info(session_id) # if the session has a memory bank id, let the memory tool use it if session_info.memory_bank_id: - tool_args["memory_bank_id"] = session_info.memory_bank_id - serialized_args = tracing.serialize_value(tool_args) + query_args["memory_bank_id"] = session_info.memory_bank_id + serialized_args = tracing.serialize_value(query_args) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( @@ -415,8 +420,8 @@ class ChatAgent(ShieldRunnerMixin): ) ) result = await self.tool_runtime_api.invoke_tool( - tool_name="memory", - args=tool_args, + tool_name=MEMORY_QUERY_TOOL, + args=query_args, ) yield AgentTurnResponseStreamChunk( @@ -485,7 +490,8 @@ class ChatAgent(ShieldRunnerMixin): tools=[ tool for tool in tool_defs.values() - if tool.tool_name != "memory" + if tool_to_group.get(tool.tool_name, None) + != MEMORY_TOOL_GROUP_ID ], tool_prompt_format=self.agent_config.tool_prompt_format, stream=True, @@ -632,6 +638,8 @@ class ChatAgent(ShieldRunnerMixin): self.tool_runtime_api, session_id, [message], + toolgroup_args, + tool_to_group, ) assert ( len(result_messages) == 1 @@ -690,26 +698,37 @@ class ChatAgent(ShieldRunnerMixin): n_iter += 1 async def _get_tool_defs( - self, tools_for_turn: Optional[List[AgentTool]] + self, toolgroups_for_turn: Optional[List[AgentToolGroup]] ) -> Dict[str, ToolDefinition]: # Determine which tools to include - agent_config_tools = set( - tool.name if isinstance(tool, AgentToolWithArgs) else tool - for tool in self.agent_config.tools + agent_config_toolgroups = set( + ( + toolgroup.name + if isinstance(toolgroup, AgentToolGroupWithArgs) + else toolgroup + ) + for toolgroup in self.agent_config.toolgroups ) - tools_for_turn_set = ( - agent_config_tools - if tools_for_turn is None + toolgroups_for_turn_set = ( + agent_config_toolgroups + if toolgroups_for_turn is None else { - tool.name if isinstance(tool, AgentToolWithArgs) else tool - for tool in tools_for_turn + ( + toolgroup.name + if isinstance(toolgroup, AgentToolGroupWithArgs) + else toolgroup + ) + for toolgroup in toolgroups_for_turn } ) - ret = {} + tool_def_map = {} + tool_to_group = {} for tool_def in self.agent_config.client_tools: - ret[tool_def.name] = ToolDefinition( + if tool_def_map.get(tool_def.name, None): + raise ValueError(f"Tool {tool_def.name} already exists") + tool_def_map[tool_def.name] = ToolDefinition( tool_name=tool_def.name, description=tool_def.description, parameters={ @@ -722,41 +741,42 @@ class ChatAgent(ShieldRunnerMixin): for param in tool_def.parameters }, ) - - for tool_name in agent_config_tools: - if tool_name not in tools_for_turn_set: + tool_to_group[tool_def.name] = "__client_tools__" + for toolgroup_name in agent_config_toolgroups: + if toolgroup_name not in toolgroups_for_turn_set: continue + tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name) + for tool_def in tools: + if tool_def.built_in_type: + if tool_def_map.get(tool_def.built_in_type, None): + raise ValueError( + f"Tool {tool_def.built_in_type} already exists" + ) - tool_def = await self.tool_groups_api.get_tool(tool_name) - if tool_def is None: - raise ValueError(f"Tool {tool_name} not found") - - if tool_def.identifier.startswith("builtin::"): - built_in_type = tool_def.identifier[len("builtin::") :] - if built_in_type == "web_search": - built_in_type = "brave_search" - if built_in_type not in BuiltinTool.__members__: - raise ValueError(f"Unknown built-in tool: {built_in_type}") - ret[built_in_type] = ToolDefinition( - tool_name=BuiltinTool(built_in_type) - ) - continue - - ret[tool_def.identifier] = ToolDefinition( - tool_name=tool_def.identifier, - description=tool_def.description, - parameters={ - param.name: ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, + tool_def_map[tool_def.built_in_type] = ToolDefinition( + tool_name=tool_def.built_in_type ) - for param in tool_def.parameters - }, - ) + tool_to_group[tool_def.built_in_type] = tool_def.toolgroup_id + continue - return ret + if tool_def_map.get(tool_def.identifier, None): + raise ValueError(f"Tool {tool_def.identifier} already exists") + tool_def_map[tool_def.identifier] = ToolDefinition( + tool_name=tool_def.identifier, + description=tool_def.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + for param in tool_def.parameters + }, + ) + tool_to_group[tool_def.identifier] = tool_def.toolgroup_id + + return tool_def_map, tool_to_group async def handle_documents( self, @@ -765,8 +785,8 @@ class ChatAgent(ShieldRunnerMixin): input_messages: List[Message], tool_defs: Dict[str, ToolDefinition], ) -> None: - memory_tool = tool_defs.get("memory", None) - code_interpreter_tool = tool_defs.get("code_interpreter", None) + memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None) + code_interpreter_tool = tool_defs.get(CODE_INTERPRETER_TOOL, None) content_items = [] url_items = [] pattern = re.compile("^(https?://|file://|data:)") @@ -903,7 +923,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa async def execute_tool_call_maybe( - tool_runtime_api: ToolRuntime, session_id: str, messages: List[CompletionMessage] + tool_runtime_api: ToolRuntime, + session_id: str, + messages: List[CompletionMessage], + toolgroup_args: Dict[str, Dict[str, Any]], + tool_to_group: Dict[str, str], ) -> List[ToolResponseMessage]: # While Tools.run interface takes a list of messages, # All tools currently only run on a single message @@ -915,18 +939,26 @@ async def execute_tool_call_maybe( tool_call = message.tool_calls[0] name = tool_call.tool_name + group_name = tool_to_group.get(name, None) + if group_name is None: + raise ValueError(f"Tool {name} not found in any tool group") + # get the arguments generated by the model and augment with toolgroup arg overrides for the agent + tool_call_args = tool_call.arguments + tool_call_args.update(toolgroup_args.get(group_name, {})) if isinstance(name, BuiltinTool): if name == BuiltinTool.brave_search: - name = "builtin::web_search" + name = WEB_SEARCH_TOOL else: - name = "builtin::" + name.value + name = name.value + result = await tool_runtime_api.invoke_tool( tool_name=name, args=dict( session_id=session_id, - **tool_call.arguments, + **tool_call_args, ), ) + return [ ToolResponseMessage( call_id=tool_call.call_id, diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 0181ef609..2ea74300d 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -19,7 +19,7 @@ from llama_stack.apis.agents import ( Agents, AgentSessionCreateResponse, AgentStepResponse, - AgentTool, + AgentToolGroup, AgentTurnCreateRequest, Document, Session, @@ -147,7 +147,7 @@ class MetaReferenceAgentsImpl(Agents): ToolResponseMessage, ] ], - tools: Optional[List[AgentTool]] = None, + tools: Optional[List[AgentToolGroup]] = None, documents: Optional[List[Document]] = None, stream: Optional[bool] = False, ) -> AsyncGenerator: diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py index 0fe0d0243..fc568996d 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py @@ -7,9 +7,16 @@ import logging import tempfile -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional -from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.tools import ( + Tool, + ToolDef, + ToolInvocationResult, + ToolParameter, + ToolRuntime, +) from llama_stack.providers.datatypes import ToolsProtocolPrivate from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor @@ -35,8 +42,22 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): async def unregister_tool(self, tool_id: str) -> None: return - async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: - raise NotImplementedError("Code interpreter tool group not supported") + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + return [ + ToolDef( + name="code_interpreter", + description="Execute code", + parameters=[ + ToolParameter( + name="code", + description="The code to execute", + parameter_type="string", + ), + ], + ) + ] async def invoke_tool( self, tool_name: str, args: Dict[str, Any] diff --git a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py index 7ee751a17..1fb1d0992 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py +++ b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py @@ -5,6 +5,8 @@ # the root directory of this source tree. +from typing import List + from jinja2 import Template from llama_stack.apis.inference import Message, UserMessage @@ -22,7 +24,7 @@ from .config import ( async def generate_rag_query( config: MemoryQueryGeneratorConfig, - message: Message, + messages: List[Message], **kwargs, ): """ @@ -30,9 +32,9 @@ async def generate_rag_query( retrieving relevant information from the memory bank. """ if config.type == MemoryQueryGenerator.default.value: - query = await default_rag_query_generator(config, message, **kwargs) + query = await default_rag_query_generator(config, messages, **kwargs) elif config.type == MemoryQueryGenerator.llm.value: - query = await llm_rag_query_generator(config, message, **kwargs) + query = await llm_rag_query_generator(config, messages, **kwargs) else: raise NotImplementedError(f"Unsupported memory query generator {config.type}") return query @@ -40,21 +42,21 @@ async def generate_rag_query( async def default_rag_query_generator( config: DefaultMemoryQueryGeneratorConfig, - message: Message, + messages: List[Message], **kwargs, ): - return interleaved_content_as_str(message.content) + return config.sep.join(interleaved_content_as_str(m.content) for m in messages) async def llm_rag_query_generator( config: LLMMemoryQueryGeneratorConfig, - message: Message, + messages: List[Message], **kwargs, ): assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api" inference_api = kwargs["inference_api"] - m_dict = {"messages": [message.model_dump()]} + m_dict = {"messages": [message.model_dump() for message in messages]} template = Template(config.template) content = template.render(m_dict) diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py index cad123696..c8c2cc772 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/memory.py +++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py @@ -10,13 +10,14 @@ import secrets import string from typing import Any, Dict, List, Optional -from llama_stack.apis.inference import Inference, InterleavedContent, Message +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.inference import Inference, InterleavedContent from llama_stack.apis.memory import Memory, QueryDocumentsResponse from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.tools import ( ToolDef, - ToolGroupDef, ToolInvocationResult, + ToolParameter, ToolRuntime, ) from llama_stack.providers.datatypes import ToolsProtocolPrivate @@ -50,17 +51,31 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): async def initialize(self): pass - async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: - return [] + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + return [ + ToolDef( + name="memory", + description="Retrieve context from memory", + parameters=[ + ToolParameter( + name="input_messages", + description="The input messages to search for", + parameter_type="array", + ), + ], + ) + ] async def _retrieve_context( - self, message: Message, bank_ids: List[str] + self, input_messages: List[str], bank_ids: List[str] ) -> Optional[List[InterleavedContent]]: if not bank_ids: return None query = await generate_rag_query( self.config.query_generator_config, - message, + input_messages, inference_api=self.inference_api, ) tasks = [ @@ -106,17 +121,22 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): self, tool_name: str, args: Dict[str, Any] ) -> ToolInvocationResult: tool = await self.tool_store.get_tool(tool_name) + tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id) + final_args = tool_group.args or {} + final_args.update(args) config = MemoryToolConfig() - if tool.metadata.get("config") is not None: + if tool.metadata and tool.metadata.get("config") is not None: config = MemoryToolConfig(**tool.metadata["config"]) - if "memory_bank_id" in args: - bank_ids = [args["memory_bank_id"]] + if "memory_bank_ids" in final_args: + bank_ids = final_args["memory_bank_ids"] else: bank_ids = [ bank_config.bank_id for bank_config in config.memory_bank_configs ] + if "messages" not in final_args: + raise ValueError("messages are required") context = await self._retrieve_context( - args["query"], + final_args["messages"], bank_ids, ) if context is None: diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index cd0468d93..162e82d62 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -4,11 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import requests -from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.tools import ( + Tool, + ToolDef, + ToolInvocationResult, + ToolParameter, + ToolRuntime, +) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ToolsProtocolPrivate @@ -41,8 +48,22 @@ class BraveSearchToolRuntimeImpl( ) return provider_data.api_key - async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: - raise NotImplementedError("Brave search tool group not supported") + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + return [ + ToolDef( + name="web_search", + description="Search the web for information", + parameters=[ + ToolParameter( + name="query", + description="The query to search for", + parameter_type="string", + ) + ], + ) + ] async def invoke_tool( self, tool_name: str, args: Dict[str, Any] diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 19ada8457..dd2bb5e5e 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -4,20 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from urllib.parse import urlparse from mcp import ClientSession from mcp.client.sse import sse_client +from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( - MCPToolGroupDef, ToolDef, - ToolGroupDef, ToolInvocationResult, ToolParameter, ToolRuntime, - UserDefinedToolDef, ) from llama_stack.providers.datatypes import ToolsProtocolPrivate @@ -31,12 +29,14 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): async def initialize(self): pass - async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: - if not isinstance(tool_group, MCPToolGroupDef): - raise ValueError(f"Unsupported tool group type: {type(tool_group)}") + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + if mcp_endpoint is None: + raise ValueError("mcp_endpoint is required") tools = [] - async with sse_client(tool_group.endpoint.uri) as streams: + async with sse_client(mcp_endpoint.uri) as streams: async with ClientSession(*streams) as session: await session.initialize() tools_result = await session.list_tools() @@ -53,12 +53,12 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): ) ) tools.append( - UserDefinedToolDef( + ToolDef( name=tool.name, description=tool.description, parameters=parameters, metadata={ - "endpoint": tool_group.endpoint.uri, + "endpoint": mcp_endpoint.uri, }, ) ) diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index f4e980929..6dc515be3 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -5,11 +5,18 @@ # the root directory of this source tree. import json -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import requests -from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.tools import ( + Tool, + ToolDef, + ToolInvocationResult, + ToolParameter, + ToolRuntime, +) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ToolsProtocolPrivate @@ -42,8 +49,22 @@ class TavilySearchToolRuntimeImpl( ) return provider_data.api_key - async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: - raise NotImplementedError("Tavily search tool group not supported") + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + return [ + ToolDef( + name="web_search", + description="Search the web for information", + parameters=[ + ToolParameter( + name="query", + description="The query to search for", + parameter_type="string", + ) + ], + ) + ] async def invoke_tool( self, tool_name: str, args: Dict[str, Any] diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 18dc90420..fb22e976e 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -45,8 +45,7 @@ def common_params(inference_model): sampling_params=SamplingParams(temperature=0.7, top_p=0.95), input_shields=[], output_shields=[], - available_tools=[], - preprocessing_tools=[], + toolgroups=[], max_infer_iters=5, ) @@ -83,27 +82,27 @@ def query_attachment_messages(): ] -async def create_agent_turn_with_search_tool( +async def create_agent_turn_with_toolgroup( agents_stack: Dict[str, object], search_query_messages: List[object], common_params: Dict[str, str], - tool_name: str, + toolgroup_name: str, ) -> None: """ - Create an agent turn with a search tool. + Create an agent turn with a toolgroup. Args: agents_stack (Dict[str, object]): The agents stack. search_query_messages (List[object]): The search query messages. common_params (Dict[str, str]): The common parameters. - search_tool_definition (SearchToolDefinition): The search tool definition. + toolgroup_name (str): The name of the toolgroup. """ - # Create an agent with the search tool + # Create an agent with the toolgroup agent_config = AgentConfig( **{ **common_params, - "tools": [tool_name], + "toolgroups": [toolgroup_name], } ) @@ -249,7 +248,7 @@ class TestAgents: agent_config = AgentConfig( **{ **common_params, - "tools": ["memory"], + "toolgroups": ["builtin::memory"], "tool_choice": ToolChoice.auto, } ) @@ -289,13 +288,58 @@ class TestAgents: if "TAVILY_SEARCH_API_KEY" not in os.environ: pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") - await create_agent_turn_with_search_tool( - agents_stack, - search_query_messages, - common_params, - "brave_search", + # Create an agent with the toolgroup + agent_config = AgentConfig( + **{ + **common_params, + "toolgroups": ["builtin::web_search"], + } ) + agent_id, session_id = await create_agent_session( + agents_stack.impls[Api.agents], agent_config + ) + turn_request = dict( + agent_id=agent_id, + session_id=session_id, + messages=search_query_messages, + stream=True, + ) + + turn_response = [ + chunk + async for chunk in await agents_stack.impls[Api.agents].create_agent_turn( + **turn_request + ) + ] + + assert len(turn_response) > 0 + assert all( + isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response + ) + + check_event_types(turn_response) + + # Check for tool execution events + tool_execution_events = [ + chunk + for chunk in turn_response + if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload) + and chunk.event.payload.step_details.step_type + == StepType.tool_execution.value + ] + assert len(tool_execution_events) > 0, "No tool execution events found" + + # Check the tool execution details + tool_execution = tool_execution_events[0].event.payload.step_details + assert isinstance(tool_execution, ToolExecutionStep) + assert len(tool_execution.tool_calls) > 0 + actual_tool_name = tool_execution.tool_calls[0].tool_name + assert actual_tool_name == "web_search" + assert len(tool_execution.tool_responses) > 0 + + check_turn_complete_event(turn_response, session_id, search_query_messages) + def check_event_types(turn_response): event_types = [chunk.event.payload.event_type for chunk in turn_response] diff --git a/llama_stack/providers/tests/tools/fixtures.py b/llama_stack/providers/tests/tools/fixtures.py index 58defd57d..a9f923c87 100644 --- a/llama_stack/providers/tests/tools/fixtures.py +++ b/llama_stack/providers/tests/tools/fixtures.py @@ -8,16 +8,9 @@ import os import pytest import pytest_asyncio -from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.models import ModelInput, ModelType -from llama_stack.apis.tools import ( - BuiltInToolDef, - ToolGroupInput, - ToolParameter, - UserDefinedToolDef, - UserDefinedToolGroupDef, -) +from llama_stack.apis.tools import ToolGroupInput from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.tests.resolver import construct_stack_for_test @@ -47,30 +40,7 @@ def tool_runtime_memory_and_search() -> ProviderFixture: @pytest.fixture(scope="session") def tool_group_input_memory() -> ToolGroupInput: return ToolGroupInput( - tool_group_id="memory_group", - tool_group=UserDefinedToolGroupDef( - tools=[ - UserDefinedToolDef( - name="memory", - description="Query the memory bank", - parameters=[ - ToolParameter( - name="input_messages", - description="The input messages to search for in memory", - parameter_type="list", - required=True, - ), - ], - metadata={ - "config": { - "memory_bank_configs": [ - {"bank_id": "test_bank", "type": "vector"} - ] - } - }, - ) - ], - ), + toolgroup_id="builtin::memory", provider_id="memory-runtime", ) @@ -78,10 +48,7 @@ def tool_group_input_memory() -> ToolGroupInput: @pytest.fixture(scope="session") def tool_group_input_tavily_search() -> ToolGroupInput: return ToolGroupInput( - tool_group_id="tavily_search_group", - tool_group=UserDefinedToolGroupDef( - tools=[BuiltInToolDef(built_in_type=BuiltinTool.brave_search, metadata={})], - ), + toolgroup_id="builtin::web_search", provider_id="tavily-search", ) diff --git a/llama_stack/providers/tests/tools/test_tools.py b/llama_stack/providers/tests/tools/test_tools.py index f33b4a61d..917db55e1 100644 --- a/llama_stack/providers/tests/tools/test_tools.py +++ b/llama_stack/providers/tests/tools/test_tools.py @@ -43,8 +43,8 @@ def sample_documents(): class TestTools: @pytest.mark.asyncio - async def test_brave_search_tool(self, tools_stack, sample_search_query): - """Test the Brave search tool functionality.""" + async def test_web_search_tool(self, tools_stack, sample_search_query): + """Test the web search tool functionality.""" if "TAVILY_SEARCH_API_KEY" not in os.environ: pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") @@ -52,7 +52,7 @@ class TestTools: # Execute the tool response = await tools_impl.invoke_tool( - tool_name="brave_search", args={"query": sample_search_query} + tool_name="web_search", args={"query": sample_search_query} ) # Verify the response @@ -89,11 +89,12 @@ class TestTools: response = await tools_impl.invoke_tool( tool_name="memory", args={ - "input_messages": [ + "messages": [ UserMessage( content="What are the main topics covered in the documentation?", ) ], + "memory_bank_ids": ["test_bank"], }, )