forked from phoenix-oss/llama-stack-mirror
chore: refactor Agent toolgroup processing (#1381)
Summary:
Refactoring only. Centralize the toolgroup preprocessing logic in one place.

Test Plan:
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/api/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1381).
* #1384
* __->__ #1381
This commit is contained in:
parent
99bbe0e70b
commit
41c9bca1aa
1 changed files with 55 additions and 65 deletions
|
@ -12,7 +12,7 @@ import secrets
|
||||||
import string
|
import string
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
@ -181,6 +181,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||||
|
await self._initialize_tools(request.toolgroups)
|
||||||
async with tracing.span("create_and_execute_turn") as span:
|
async with tracing.span("create_and_execute_turn") as span:
|
||||||
span.set_attribute("session_id", request.session_id)
|
span.set_attribute("session_id", request.session_id)
|
||||||
span.set_attribute("agent_id", self.agent_id)
|
span.set_attribute("agent_id", self.agent_id)
|
||||||
|
@ -191,6 +192,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||||
|
await self._initialize_tools()
|
||||||
async with tracing.span("resume_turn") as span:
|
async with tracing.span("resume_turn") as span:
|
||||||
span.set_attribute("agent_id", self.agent_id)
|
span.set_attribute("agent_id", self.agent_id)
|
||||||
span.set_attribute("session_id", request.session_id)
|
span.set_attribute("session_id", request.session_id)
|
||||||
|
@ -275,7 +277,6 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
sampling_params=self.agent_config.sampling_params,
|
sampling_params=self.agent_config.sampling_params,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
documents=request.documents if not is_resume else None,
|
documents=request.documents if not is_resume else None,
|
||||||
toolgroups_for_turn=request.toolgroups if not is_resume else None,
|
|
||||||
):
|
):
|
||||||
if isinstance(chunk, CompletionMessage):
|
if isinstance(chunk, CompletionMessage):
|
||||||
output_message = chunk
|
output_message = chunk
|
||||||
|
@ -327,7 +328,6 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
documents: Optional[List[Document]] = None,
|
documents: Optional[List[Document]] = None,
|
||||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||||
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
||||||
|
@ -350,7 +350,6 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
sampling_params,
|
sampling_params,
|
||||||
stream,
|
stream,
|
||||||
documents,
|
documents,
|
||||||
toolgroups_for_turn,
|
|
||||||
):
|
):
|
||||||
if isinstance(res, bool):
|
if isinstance(res, bool):
|
||||||
return
|
return
|
||||||
|
@ -451,30 +450,17 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
documents: Optional[List[Document]] = None,
|
documents: Optional[List[Document]] = None,
|
||||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
# TODO: simplify all of this code, it can be simpler
|
|
||||||
toolgroup_args = {}
|
|
||||||
toolgroups = set()
|
|
||||||
for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
|
|
||||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
|
||||||
tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name)
|
|
||||||
toolgroups.add(tool_group_name)
|
|
||||||
toolgroup_args[tool_group_name] = toolgroup.args
|
|
||||||
else:
|
|
||||||
toolgroups.add(toolgroup)
|
|
||||||
|
|
||||||
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
|
||||||
if documents:
|
if documents:
|
||||||
await self.handle_documents(session_id, documents, input_messages, tool_defs)
|
await self.handle_documents(session_id, documents, input_messages)
|
||||||
|
|
||||||
session_info = await self.storage.get_session_info(session_id)
|
session_info = await self.storage.get_session_info(session_id)
|
||||||
# if the session has a memory bank id, let the memory tool use it
|
# if the session has a memory bank id, let the memory tool use it
|
||||||
if session_info and session_info.vector_db_id:
|
if session_info and session_info.vector_db_id:
|
||||||
if RAG_TOOL_GROUP not in toolgroup_args:
|
if RAG_TOOL_GROUP not in self.toolgroup_to_args:
|
||||||
toolgroup_args[RAG_TOOL_GROUP] = {"vector_db_ids": [session_info.vector_db_id]}
|
self.toolgroup_to_args[RAG_TOOL_GROUP] = {"vector_db_ids": [session_info.vector_db_id]}
|
||||||
else:
|
else:
|
||||||
toolgroup_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)
|
self.toolgroup_to_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)
|
||||||
|
|
||||||
output_attachments = []
|
output_attachments = []
|
||||||
|
|
||||||
|
@ -504,7 +490,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
async for chunk in await self.inference_api.chat_completion(
|
async for chunk in await self.inference_api.chat_completion(
|
||||||
self.agent_config.model,
|
self.agent_config.model,
|
||||||
input_messages,
|
input_messages,
|
||||||
tools=tool_defs,
|
tools=self.tool_defs,
|
||||||
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
|
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
|
||||||
response_format=self.agent_config.response_format,
|
response_format=self.agent_config.response_format,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
@ -686,12 +672,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
) as span:
|
) as span:
|
||||||
tool_execution_start_time = datetime.now().astimezone().isoformat()
|
tool_execution_start_time = datetime.now().astimezone().isoformat()
|
||||||
tool_call = message.tool_calls[0]
|
tool_call = message.tool_calls[0]
|
||||||
tool_result = await execute_tool_call_maybe(
|
tool_result = await self.execute_tool_call_maybe(
|
||||||
self.tool_runtime_api,
|
|
||||||
session_id,
|
session_id,
|
||||||
tool_call,
|
tool_call,
|
||||||
toolgroup_args,
|
|
||||||
tool_to_group,
|
|
||||||
)
|
)
|
||||||
if tool_result.content is None:
|
if tool_result.content is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -744,6 +727,15 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
input_messages = input_messages + [message, result_message]
|
input_messages = input_messages + [message, result_message]
|
||||||
|
|
||||||
|
async def _initialize_tools(self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None) -> None:
    """Resolve the tool state for a turn and cache it on the agent instance.

    Sets three attributes:

    - ``self.toolgroup_to_args``: tool group name -> argument overrides, collected from
      every ``AgentToolGroupWithArgs`` entry in ``self.agent_config.toolgroups`` followed
      by ``toolgroups_for_turn`` (a later entry for the same group replaces an earlier one).
    - ``self.tool_defs`` and ``self.tool_name_to_group_id``: the pair returned by
      ``self._get_tool_defs(toolgroups_for_turn)``.

    Args:
        toolgroups_for_turn: additional toolgroups for this turn only; ``None`` is
            treated as "no additional toolgroups".
    """
    import copy  # local import keeps this method self-contained

    self.toolgroup_to_args = {}
    for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
        if isinstance(toolgroup, AgentToolGroupWithArgs):
            # Only the group part of "group" / "group/tool" style names keys the overrides.
            tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
            # Deep copy on purpose: code elsewhere in this class appends to the RAG
            # group's "vector_db_ids" list in place. Storing `toolgroup.args` by
            # reference would let that append leak into the agent config / request
            # object and accumulate across turns.
            self.toolgroup_to_args[tool_group_name] = copy.deepcopy(toolgroup.args)

    self.tool_defs, self.tool_name_to_group_id = await self._get_tool_defs(toolgroups_for_turn)
|
||||||
|
|
||||||
async def _get_tool_defs(
|
async def _get_tool_defs(
|
||||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
||||||
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
|
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
|
||||||
|
@ -756,7 +748,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
agent_config_toolgroups.append(name)
|
agent_config_toolgroups.append(name)
|
||||||
|
|
||||||
tool_name_to_def = {}
|
tool_name_to_def = {}
|
||||||
tool_to_group = {}
|
tool_name_to_group_id = {}
|
||||||
|
|
||||||
for tool_def in self.agent_config.client_tools:
|
for tool_def in self.agent_config.client_tools:
|
||||||
if tool_name_to_def.get(tool_def.name, None):
|
if tool_name_to_def.get(tool_def.name, None):
|
||||||
|
@ -774,7 +766,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
for param in tool_def.parameters
|
for param in tool_def.parameters
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
tool_to_group[tool_def.name] = "__client_tools__"
|
tool_name_to_group_id[tool_def.name] = "__client_tools__"
|
||||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||||
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||||
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
||||||
|
@ -813,7 +805,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
for param in tool_def.parameters
|
for param in tool_def.parameters
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
tool_to_group[built_in_type] = tool_def.toolgroup_id
|
tool_name_to_group_id[built_in_type] = tool_def.toolgroup_id
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if tool_name_to_def.get(tool_def.identifier, None):
|
if tool_name_to_def.get(tool_def.identifier, None):
|
||||||
|
@ -832,9 +824,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
for param in tool_def.parameters
|
for param in tool_def.parameters
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
|
tool_name_to_group_id[tool_def.identifier] = tool_def.toolgroup_id
|
||||||
|
|
||||||
return list(tool_name_to_def.values()), tool_to_group
|
return list(tool_name_to_def.values()), tool_name_to_group_id
|
||||||
|
|
||||||
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
|
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
|
||||||
"""Parse a toolgroup name into its components.
|
"""Parse a toolgroup name into its components.
|
||||||
|
@ -853,15 +845,44 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
tool_group, tool_name = split_names[0], None
|
tool_group, tool_name = split_names[0], None
|
||||||
return tool_group, tool_name
|
return tool_group, tool_name
|
||||||
|
|
||||||
|
async def execute_tool_call_maybe(
    self,
    session_id: str,
    tool_call: ToolCall,
) -> ToolInvocationResult:
    """Invoke the tool named by ``tool_call`` through ``self.tool_runtime_api``.

    The tool's group is looked up in ``self.tool_name_to_group_id``; the invocation
    kwargs are ``session_id``, then the model-generated ``tool_call.arguments``, then
    the group's overrides from ``self.toolgroup_to_args`` (later keys win).

    Args:
        session_id: forwarded to the tool as the ``session_id`` keyword argument.
        tool_call: the call to execute; ``tool_name`` and ``arguments`` are read.

    Returns:
        Whatever ``self.tool_runtime_api.invoke_tool`` returns.

    Raises:
        ValueError: if ``tool_call.tool_name`` has no entry in
            ``self.tool_name_to_group_id``.
    """
    name = tool_call.tool_name
    group_name = self.tool_name_to_group_id.get(name, None)
    if group_name is None:
        # Keys are not guaranteed to be plain strings (built-in tools are looked up
        # by their enum member), and str.join() raises TypeError on non-str items,
        # which would mask this ValueError. Render enum-like keys by their value.
        available = ", ".join(str(getattr(key, "value", key)) for key in self.tool_name_to_group_id)
        raise ValueError(f"Tool {name} not found in any tool group, available tools: {available}")
    if isinstance(name, BuiltinTool):
        # The runtime is addressed by string name: brave_search maps to the
        # WEB_SEARCH_TOOL constant, every other built-in to its enum value.
        if name == BuiltinTool.brave_search:
            name = WEB_SEARCH_TOOL
        else:
            name = name.value

    logger.info(f"executing tool call: {name} with args: {tool_call.arguments}")
    result = await self.tool_runtime_api.invoke_tool(
        tool_name=name,
        kwargs={
            # NOTE(review): model-generated arguments are spread after this key, so a
            # "session_id" entry in tool_call.arguments would replace it - confirm
            # whether that is intended.
            "session_id": session_id,
            # get the arguments generated by the model and augment with toolgroup arg overrides for the agent
            **tool_call.arguments,
            **self.toolgroup_to_args.get(group_name, {}),
        },
    )
    logger.debug(f"tool call {name} completed with result: {result}")
    return result
|
||||||
|
|
||||||
async def handle_documents(
|
async def handle_documents(
|
||||||
self,
|
self,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
documents: List[Document],
|
documents: List[Document],
|
||||||
input_messages: List[Message],
|
input_messages: List[Message],
|
||||||
tool_defs: Dict[str, ToolDefinition],
|
|
||||||
) -> None:
|
) -> None:
|
||||||
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs)
|
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs)
|
||||||
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs)
|
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in self.tool_defs)
|
||||||
content_items = []
|
content_items = []
|
||||||
url_items = []
|
url_items = []
|
||||||
pattern = re.compile("^(https?://|file://|data:)")
|
pattern = re.compile("^(https?://|file://|data:)")
|
||||||
|
@ -994,37 +1015,6 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def execute_tool_call_maybe(
    tool_runtime_api: ToolRuntime,
    session_id: str,
    tool_call: ToolCall,
    toolgroup_args: Dict[str, Dict[str, Any]],
    tool_to_group: Dict[str, str],
) -> ToolInvocationResult:
    """Run ``tool_call`` via ``tool_runtime_api`` and return the invocation result.

    Args:
        tool_runtime_api: runtime whose ``invoke_tool`` performs the call.
        session_id: passed to the tool as the ``session_id`` keyword argument.
        tool_call: the requested call; ``tool_name`` and ``arguments`` are read.
        toolgroup_args: per-group argument overrides, applied on top of the
            model-generated arguments.
        tool_to_group: maps a tool name to the name of its tool group.

    Raises:
        ValueError: if the requested tool is missing from ``tool_to_group``.
    """
    requested = tool_call.tool_name
    owning_group = tool_to_group.get(requested, None)
    if owning_group is None:
        raise ValueError(f"Tool {requested} not found in any tool group")

    # Built-in tools are invoked by string name: brave_search goes through the
    # WEB_SEARCH_TOOL constant, the others use their enum value.
    runtime_name = requested
    if isinstance(requested, BuiltinTool):
        runtime_name = WEB_SEARCH_TOOL if requested == BuiltinTool.brave_search else requested.value

    logger.info(f"executing tool call: {runtime_name} with args: {tool_call.arguments}")

    # Layering order matters: session id first, then the arguments generated by the
    # model, then the toolgroup arg overrides for the agent (later keys win).
    invocation_kwargs = {"session_id": session_id}
    invocation_kwargs.update(tool_call.arguments)
    invocation_kwargs.update(toolgroup_args.get(owning_group, {}))

    outcome = await tool_runtime_api.invoke_tool(tool_name=runtime_name, kwargs=invocation_kwargs)
    logger.info(f"tool call {runtime_name} completed with result: {outcome}")
    return outcome
|
|
||||||
|
|
||||||
|
|
||||||
def _interpret_content_as_attachment(
|
def _interpret_content_as_attachment(
|
||||||
content: str,
|
content: str,
|
||||||
) -> Optional[Attachment]:
|
) -> Optional[Attachment]:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue