mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 02:09:40 +00:00
simplify toolgroups registration
This commit is contained in:
parent
ba242c04cc
commit
f9a98c278a
15 changed files with 350 additions and 256 deletions
|
|
@ -13,7 +13,7 @@ import secrets
|
|||
import string
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
|
@ -21,8 +21,8 @@ from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDe
|
|||
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentTool,
|
||||
AgentToolWithArgs,
|
||||
AgentToolGroup,
|
||||
AgentToolGroupWithArgs,
|
||||
AgentTurnCreateRequest,
|
||||
AgentTurnResponseEvent,
|
||||
AgentTurnResponseEventType,
|
||||
|
|
@ -76,6 +76,10 @@ def make_random_string(length: int = 8):
|
|||
|
||||
|
||||
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||
MEMORY_TOOL_GROUP_ID = "builtin::memory"
|
||||
MEMORY_QUERY_TOOL = "query_memory"
|
||||
CODE_INTERPRETER_TOOL = "code_interpreter"
|
||||
WEB_SEARCH_TOOL = "web_search"
|
||||
|
||||
|
||||
class ChatAgent(ShieldRunnerMixin):
|
||||
|
|
@ -192,7 +196,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
documents=request.documents,
|
||||
tools_for_turn=request.tools,
|
||||
toolgroups_for_turn=request.toolgroups,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
log.info(
|
||||
|
|
@ -243,7 +247,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
tools_for_turn: Optional[List[AgentTool]] = None,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> AsyncGenerator:
|
||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
||||
|
|
@ -266,7 +270,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params,
|
||||
stream,
|
||||
documents,
|
||||
tools_for_turn,
|
||||
toolgroups_for_turn,
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
|
@ -362,21 +366,24 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
tools_for_turn: Optional[List[AgentTool]] = None,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> AsyncGenerator:
|
||||
tool_args = {}
|
||||
if tools_for_turn:
|
||||
for tool in tools_for_turn:
|
||||
if isinstance(tool, AgentToolWithArgs):
|
||||
tool_args[tool.name] = tool.args
|
||||
toolgroup_args = {}
|
||||
for toolgroup in self.agent_config.toolgroups:
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
toolgroup_args[toolgroup.name] = toolgroup.args
|
||||
if toolgroups_for_turn:
|
||||
for toolgroup in toolgroups_for_turn:
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
toolgroup_args[toolgroup.name] = toolgroup.args
|
||||
|
||||
tool_defs = await self._get_tool_defs(tools_for_turn)
|
||||
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
||||
if documents:
|
||||
await self.handle_documents(
|
||||
session_id, documents, input_messages, tool_defs
|
||||
)
|
||||
if "memory" in tool_defs and len(input_messages) > 0:
|
||||
with tracing.span("memory_tool") as span:
|
||||
if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
|
||||
with tracing.span(MEMORY_QUERY_TOOL) as span:
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
|
@ -386,18 +393,16 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
extra_args = tool_args.get("memory", {})
|
||||
tool_args = {
|
||||
# Query memory with the last message's content
|
||||
"query": input_messages[-1],
|
||||
**extra_args,
|
||||
query_args = {
|
||||
"messages": [msg.content for msg in input_messages],
|
||||
**toolgroup_args.get(MEMORY_TOOL_GROUP_ID, {}),
|
||||
}
|
||||
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
# if the session has a memory bank id, let the memory tool use it
|
||||
if session_info.memory_bank_id:
|
||||
tool_args["memory_bank_id"] = session_info.memory_bank_id
|
||||
serialized_args = tracing.serialize_value(tool_args)
|
||||
query_args["memory_bank_id"] = session_info.memory_bank_id
|
||||
serialized_args = tracing.serialize_value(query_args)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
|
|
@ -415,8 +420,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name="memory",
|
||||
args=tool_args,
|
||||
tool_name=MEMORY_QUERY_TOOL,
|
||||
args=query_args,
|
||||
)
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
|
@ -485,7 +490,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tools=[
|
||||
tool
|
||||
for tool in tool_defs.values()
|
||||
if tool.tool_name != "memory"
|
||||
if tool_to_group.get(tool.tool_name, None)
|
||||
!= MEMORY_TOOL_GROUP_ID
|
||||
],
|
||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||
stream=True,
|
||||
|
|
@ -632,6 +638,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self.tool_runtime_api,
|
||||
session_id,
|
||||
[message],
|
||||
toolgroup_args,
|
||||
tool_to_group,
|
||||
)
|
||||
assert (
|
||||
len(result_messages) == 1
|
||||
|
|
@ -690,26 +698,37 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
n_iter += 1
|
||||
|
||||
async def _get_tool_defs(
|
||||
self, tools_for_turn: Optional[List[AgentTool]]
|
||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]]
|
||||
) -> Dict[str, ToolDefinition]:
|
||||
# Determine which tools to include
|
||||
agent_config_tools = set(
|
||||
tool.name if isinstance(tool, AgentToolWithArgs) else tool
|
||||
for tool in self.agent_config.tools
|
||||
agent_config_toolgroups = set(
|
||||
(
|
||||
toolgroup.name
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs)
|
||||
else toolgroup
|
||||
)
|
||||
for toolgroup in self.agent_config.toolgroups
|
||||
)
|
||||
tools_for_turn_set = (
|
||||
agent_config_tools
|
||||
if tools_for_turn is None
|
||||
toolgroups_for_turn_set = (
|
||||
agent_config_toolgroups
|
||||
if toolgroups_for_turn is None
|
||||
else {
|
||||
tool.name if isinstance(tool, AgentToolWithArgs) else tool
|
||||
for tool in tools_for_turn
|
||||
(
|
||||
toolgroup.name
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs)
|
||||
else toolgroup
|
||||
)
|
||||
for toolgroup in toolgroups_for_turn
|
||||
}
|
||||
)
|
||||
|
||||
ret = {}
|
||||
tool_def_map = {}
|
||||
tool_to_group = {}
|
||||
|
||||
for tool_def in self.agent_config.client_tools:
|
||||
ret[tool_def.name] = ToolDefinition(
|
||||
if tool_def_map.get(tool_def.name, None):
|
||||
raise ValueError(f"Tool {tool_def.name} already exists")
|
||||
tool_def_map[tool_def.name] = ToolDefinition(
|
||||
tool_name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
|
|
@ -722,41 +741,42 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
|
||||
for tool_name in agent_config_tools:
|
||||
if tool_name not in tools_for_turn_set:
|
||||
tool_to_group[tool_def.name] = "__client_tools__"
|
||||
for toolgroup_name in agent_config_toolgroups:
|
||||
if toolgroup_name not in toolgroups_for_turn_set:
|
||||
continue
|
||||
tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name)
|
||||
for tool_def in tools:
|
||||
if tool_def.built_in_type:
|
||||
if tool_def_map.get(tool_def.built_in_type, None):
|
||||
raise ValueError(
|
||||
f"Tool {tool_def.built_in_type} already exists"
|
||||
)
|
||||
|
||||
tool_def = await self.tool_groups_api.get_tool(tool_name)
|
||||
if tool_def is None:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
|
||||
if tool_def.identifier.startswith("builtin::"):
|
||||
built_in_type = tool_def.identifier[len("builtin::") :]
|
||||
if built_in_type == "web_search":
|
||||
built_in_type = "brave_search"
|
||||
if built_in_type not in BuiltinTool.__members__:
|
||||
raise ValueError(f"Unknown built-in tool: {built_in_type}")
|
||||
ret[built_in_type] = ToolDefinition(
|
||||
tool_name=BuiltinTool(built_in_type)
|
||||
)
|
||||
continue
|
||||
|
||||
ret[tool_def.identifier] = ToolDefinition(
|
||||
tool_name=tool_def.identifier,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
tool_def_map[tool_def.built_in_type] = ToolDefinition(
|
||||
tool_name=tool_def.built_in_type
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.built_in_type] = tool_def.toolgroup_id
|
||||
continue
|
||||
|
||||
return ret
|
||||
if tool_def_map.get(tool_def.identifier, None):
|
||||
raise ValueError(f"Tool {tool_def.identifier} already exists")
|
||||
tool_def_map[tool_def.identifier] = ToolDefinition(
|
||||
tool_name=tool_def.identifier,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
|
||||
|
||||
return tool_def_map, tool_to_group
|
||||
|
||||
async def handle_documents(
|
||||
self,
|
||||
|
|
@ -765,8 +785,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages: List[Message],
|
||||
tool_defs: Dict[str, ToolDefinition],
|
||||
) -> None:
|
||||
memory_tool = tool_defs.get("memory", None)
|
||||
code_interpreter_tool = tool_defs.get("code_interpreter", None)
|
||||
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
|
||||
code_interpreter_tool = tool_defs.get(CODE_INTERPRETER_TOOL, None)
|
||||
content_items = []
|
||||
url_items = []
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
|
|
@ -903,7 +923,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
|
||||
|
||||
async def execute_tool_call_maybe(
|
||||
tool_runtime_api: ToolRuntime, session_id: str, messages: List[CompletionMessage]
|
||||
tool_runtime_api: ToolRuntime,
|
||||
session_id: str,
|
||||
messages: List[CompletionMessage],
|
||||
toolgroup_args: Dict[str, Dict[str, Any]],
|
||||
tool_to_group: Dict[str, str],
|
||||
) -> List[ToolResponseMessage]:
|
||||
# While Tools.run interface takes a list of messages,
|
||||
# All tools currently only run on a single message
|
||||
|
|
@ -915,18 +939,26 @@ async def execute_tool_call_maybe(
|
|||
|
||||
tool_call = message.tool_calls[0]
|
||||
name = tool_call.tool_name
|
||||
group_name = tool_to_group.get(name, None)
|
||||
if group_name is None:
|
||||
raise ValueError(f"Tool {name} not found in any tool group")
|
||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||
tool_call_args = tool_call.arguments
|
||||
tool_call_args.update(toolgroup_args.get(group_name, {}))
|
||||
if isinstance(name, BuiltinTool):
|
||||
if name == BuiltinTool.brave_search:
|
||||
name = "builtin::web_search"
|
||||
name = WEB_SEARCH_TOOL
|
||||
else:
|
||||
name = "builtin::" + name.value
|
||||
name = name.value
|
||||
|
||||
result = await tool_runtime_api.invoke_tool(
|
||||
tool_name=name,
|
||||
args=dict(
|
||||
session_id=session_id,
|
||||
**tool_call.arguments,
|
||||
**tool_call_args,
|
||||
),
|
||||
)
|
||||
|
||||
return [
|
||||
ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from llama_stack.apis.agents import (
|
|||
Agents,
|
||||
AgentSessionCreateResponse,
|
||||
AgentStepResponse,
|
||||
AgentTool,
|
||||
AgentToolGroup,
|
||||
AgentTurnCreateRequest,
|
||||
Document,
|
||||
Session,
|
||||
|
|
@ -147,7 +147,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
ToolResponseMessage,
|
||||
]
|
||||
],
|
||||
tools: Optional[List[AgentTool]] = None,
|
||||
tools: Optional[List[AgentToolGroup]] = None,
|
||||
documents: Optional[List[Document]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
|
|
|
|||
|
|
@ -7,9 +7,16 @@
|
|||
|
||||
import logging
|
||||
import tempfile
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor
|
||||
|
|
@ -35,8 +42,22 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
return
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
|
||||
raise NotImplementedError("Code interpreter tool group not supported")
|
||||
async def list_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
name="code_interpreter",
|
||||
description="Execute code",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="code",
|
||||
description="The code to execute",
|
||||
parameter_type="string",
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import List
|
||||
|
||||
from jinja2 import Template
|
||||
|
||||
from llama_stack.apis.inference import Message, UserMessage
|
||||
|
|
@ -22,7 +24,7 @@ from .config import (
|
|||
|
||||
async def generate_rag_query(
|
||||
config: MemoryQueryGeneratorConfig,
|
||||
message: Message,
|
||||
messages: List[Message],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
|
@ -30,9 +32,9 @@ async def generate_rag_query(
|
|||
retrieving relevant information from the memory bank.
|
||||
"""
|
||||
if config.type == MemoryQueryGenerator.default.value:
|
||||
query = await default_rag_query_generator(config, message, **kwargs)
|
||||
query = await default_rag_query_generator(config, messages, **kwargs)
|
||||
elif config.type == MemoryQueryGenerator.llm.value:
|
||||
query = await llm_rag_query_generator(config, message, **kwargs)
|
||||
query = await llm_rag_query_generator(config, messages, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
|
||||
return query
|
||||
|
|
@ -40,21 +42,21 @@ async def generate_rag_query(
|
|||
|
||||
async def default_rag_query_generator(
|
||||
config: DefaultMemoryQueryGeneratorConfig,
|
||||
message: Message,
|
||||
messages: List[Message],
|
||||
**kwargs,
|
||||
):
|
||||
return interleaved_content_as_str(message.content)
|
||||
return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
|
||||
|
||||
|
||||
async def llm_rag_query_generator(
|
||||
config: LLMMemoryQueryGeneratorConfig,
|
||||
message: Message,
|
||||
messages: List[Message],
|
||||
**kwargs,
|
||||
):
|
||||
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
|
||||
inference_api = kwargs["inference_api"]
|
||||
|
||||
m_dict = {"messages": [message.model_dump()]}
|
||||
m_dict = {"messages": [message.model_dump() for message in messages]}
|
||||
|
||||
template = Template(config.template)
|
||||
content = template.render(m_dict)
|
||||
|
|
|
|||
|
|
@ -10,13 +10,14 @@ import secrets
|
|||
import string
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent, Message
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.tools import (
|
||||
ToolDef,
|
||||
ToolGroupDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
|
@ -50,17 +51,31 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
|
||||
return []
|
||||
async def list_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
name="memory",
|
||||
description="Retrieve context from memory",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="input_messages",
|
||||
description="The input messages to search for",
|
||||
parameter_type="array",
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
async def _retrieve_context(
|
||||
self, message: Message, bank_ids: List[str]
|
||||
self, input_messages: List[str], bank_ids: List[str]
|
||||
) -> Optional[List[InterleavedContent]]:
|
||||
if not bank_ids:
|
||||
return None
|
||||
query = await generate_rag_query(
|
||||
self.config.query_generator_config,
|
||||
message,
|
||||
input_messages,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
tasks = [
|
||||
|
|
@ -106,17 +121,22 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
self, tool_name: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
tool = await self.tool_store.get_tool(tool_name)
|
||||
tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
|
||||
final_args = tool_group.args or {}
|
||||
final_args.update(args)
|
||||
config = MemoryToolConfig()
|
||||
if tool.metadata.get("config") is not None:
|
||||
if tool.metadata and tool.metadata.get("config") is not None:
|
||||
config = MemoryToolConfig(**tool.metadata["config"])
|
||||
if "memory_bank_id" in args:
|
||||
bank_ids = [args["memory_bank_id"]]
|
||||
if "memory_bank_ids" in final_args:
|
||||
bank_ids = final_args["memory_bank_ids"]
|
||||
else:
|
||||
bank_ids = [
|
||||
bank_config.bank_id for bank_config in config.memory_bank_configs
|
||||
]
|
||||
if "messages" not in final_args:
|
||||
raise ValueError("messages are required")
|
||||
context = await self._retrieve_context(
|
||||
args["query"],
|
||||
final_args["messages"],
|
||||
bank_ids,
|
||||
)
|
||||
if context is None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue