precommit

This commit is contained in:
Xi Yan 2025-03-22 15:33:52 -07:00
parent 4da8378ff8
commit 4837a6fa0e

View file

@ -6,14 +6,12 @@
import copy import copy
import json import json
import os
import re import re
import secrets import secrets
import string import string
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import AsyncGenerator, List, Optional, Union from typing import AsyncGenerator, List, Optional, Union
from urllib.parse import urlparse
import httpx import httpx
@ -40,10 +38,10 @@ from llama_stack.apis.agents import (
Turn, Turn,
) )
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
URL,
TextContentItem, TextContentItem,
ToolCallDelta, ToolCallDelta,
ToolCallParseStatus, ToolCallParseStatus,
URL,
) )
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponseEventType, ChatCompletionResponseEventType,
@ -60,7 +58,6 @@ from llama_stack.apis.inference import (
) )
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
RAGDocument,
ToolGroups, ToolGroups,
ToolInvocationResult, ToolInvocationResult,
ToolRuntime, ToolRuntime,
@ -80,9 +77,7 @@ from .safety import SafetyException, ShieldRunnerMixin
def make_random_string(length: int = 8): def make_random_string(length: int = 8):
return "".join( return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
)
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
@ -181,9 +176,7 @@ class ChatAgent(ShieldRunnerMixin):
messages.extend(self.turn_to_messages(turn)) messages.extend(self.turn_to_messages(turn))
return messages return messages
async def create_and_execute_turn( async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
await self._initialize_tools(request.toolgroups) await self._initialize_tools(request.toolgroups)
async with tracing.span("create_and_execute_turn") as span: async with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id) span.set_attribute("session_id", request.session_id)
@ -224,16 +217,13 @@ class ChatAgent(ShieldRunnerMixin):
messages = await self.get_messages_from_turns(turns) messages = await self.get_messages_from_turns(turns)
if is_resume: if is_resume:
tool_response_messages = [ tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, content=x.content) ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
for x in request.tool_responses
] ]
messages.extend(tool_response_messages) messages.extend(tool_response_messages)
last_turn = turns[-1] last_turn = turns[-1]
last_turn_messages = self.turn_to_messages(last_turn) last_turn_messages = self.turn_to_messages(last_turn)
last_turn_messages = [ last_turn_messages = [
x x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
for x in last_turn_messages
if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
] ]
last_turn_messages.extend(tool_response_messages) last_turn_messages.extend(tool_response_messages)
@ -243,31 +233,17 @@ class ChatAgent(ShieldRunnerMixin):
# mark tool execution step as complete # mark tool execution step as complete
# if there's no tool execution in progress step (due to storage, or tool call parsing on client), # if there's no tool execution in progress step (due to storage, or tool call parsing on client),
# we'll create a new tool execution step with current time # we'll create a new tool execution step with current time
in_progress_tool_call_step = ( in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
await self.storage.get_in_progress_tool_call_step( request.session_id, request.turn_id
request.session_id, request.turn_id
)
) )
now = datetime.now(timezone.utc).isoformat() now = datetime.now(timezone.utc).isoformat()
tool_execution_step = ToolExecutionStep( tool_execution_step = ToolExecutionStep(
step_id=( step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
in_progress_tool_call_step.step_id
if in_progress_tool_call_step
else str(uuid.uuid4())
),
turn_id=request.turn_id, turn_id=request.turn_id,
tool_calls=( tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
in_progress_tool_call_step.tool_calls
if in_progress_tool_call_step
else []
),
tool_responses=request.tool_responses, tool_responses=request.tool_responses,
completed_at=now, completed_at=now,
started_at=( started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
in_progress_tool_call_step.started_at
if in_progress_tool_call_step
else now
),
) )
steps.append(tool_execution_step) steps.append(tool_execution_step)
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
@ -301,14 +277,9 @@ class ChatAgent(ShieldRunnerMixin):
output_message = chunk output_message = chunk
continue continue
assert isinstance( assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
chunk, AgentTurnResponseStreamChunk
), f"Unexpected type {type(chunk)}"
event = chunk.event event = chunk.event
if ( if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
event.payload.event_type
== AgentTurnResponseEventType.step_complete.value
):
steps.append(event.payload.step_details) steps.append(event.payload.step_details)
yield chunk yield chunk
@ -490,19 +461,13 @@ class ChatAgent(ShieldRunnerMixin):
for tool_name in self.tool_name_to_args.keys(): for tool_name in self.tool_name_to_args.keys():
if tool_name == MEMORY_QUERY_TOOL: if tool_name == MEMORY_QUERY_TOOL:
if "vector_db_ids" not in self.tool_name_to_args[tool_name]: if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
self.tool_name_to_args[tool_name]["vector_db_ids"] = [ self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id]
session_info.vector_db_id
]
else: else:
self.tool_name_to_args[tool_name]["vector_db_ids"].append( self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
session_info.vector_db_id
)
output_attachments = [] output_attachments = []
n_iter = ( n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
)
# Build a map of custom tools to their definitions for faster lookup # Build a map of custom tools to their definitions for faster lookup
client_tools = {} client_tools = {}
@ -583,16 +548,12 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("stop_reason", stop_reason) span.set_attribute("stop_reason", stop_reason)
span.set_attribute( span.set_attribute(
"input", "input",
json.dumps( json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
[json.loads(m.model_dump_json()) for m in input_messages]
),
) )
output_attr = json.dumps( output_attr = json.dumps(
{ {
"content": content, "content": content,
"tool_calls": [ "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
json.loads(t.model_dump_json()) for t in tool_calls
],
} }
) )
span.set_attribute("output", output_attr) span.set_attribute("output", output_attr)
@ -656,9 +617,7 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + output_attachments message.content = [message.content] + output_attachments
yield message yield message
else: else:
logger.debug( logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
f"completion message with EOM (iter: {n_iter}): {str(message)}"
)
input_messages = input_messages + [message] input_messages = input_messages + [message]
else: else:
input_messages = input_messages + [message] input_messages = input_messages + [message]
@ -707,9 +666,7 @@ class ChatAgent(ShieldRunnerMixin):
"input": message.model_dump_json(), "input": message.model_dump_json(),
}, },
) as span: ) as span:
tool_execution_start_time = datetime.now( tool_execution_start_time = datetime.now(timezone.utc).isoformat()
timezone.utc
).isoformat()
tool_result = await self.execute_tool_call_maybe( tool_result = await self.execute_tool_call_maybe(
session_id, session_id,
tool_call, tool_call,
@ -758,9 +715,7 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: add tool-input touchpoint and a "start" event for this step also # TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially # but that needs a lot more refactoring of Tool code potentially
if (type(result_message.content) is str) and ( if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment( out_attachment := _interpret_content_as_attachment(result_message.content)
result_message.content
)
): ):
# NOTE: when we push this message back to the model, the model may ignore the # NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message # attached file path etc. since the model is trained to only provide a user message
@ -797,24 +752,16 @@ class ChatAgent(ShieldRunnerMixin):
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> None: ) -> None:
toolgroup_to_args = {} toolgroup_to_args = {}
for toolgroup in (self.agent_config.toolgroups or []) + ( for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
toolgroups_for_turn or []
):
if isinstance(toolgroup, AgentToolGroupWithArgs): if isinstance(toolgroup, AgentToolGroupWithArgs):
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name) tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
toolgroup_to_args[tool_group_name] = toolgroup.args toolgroup_to_args[tool_group_name] = toolgroup.args
# Determine which tools to include # Determine which tools to include
tool_groups_to_include = ( tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
toolgroups_for_turn or self.agent_config.toolgroups or []
)
agent_config_toolgroups = [] agent_config_toolgroups = []
for toolgroup in tool_groups_to_include: for toolgroup in tool_groups_to_include:
name = ( name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
if name not in agent_config_toolgroups: if name not in agent_config_toolgroups:
agent_config_toolgroups.append(name) agent_config_toolgroups.append(name)
@ -840,32 +787,20 @@ class ChatAgent(ShieldRunnerMixin):
}, },
) )
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, input_tool_name = self._parse_toolgroup_name( toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
toolgroup_name_with_maybe_tool_name
)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name) tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
if not tools.data: if not tools.data:
available_tool_groups = ", ".join( available_tool_groups = ", ".join(
[ [t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
t.identifier
for t in (await self.tool_groups_api.list_tool_groups()).data
]
) )
raise ValueError( raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}" if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
)
if input_tool_name is not None and not any(
tool.identifier == input_tool_name for tool in tools.data
):
raise ValueError( raise ValueError(
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}" f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
) )
for tool_def in tools.data: for tool_def in tools.data:
if ( if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
toolgroup_name.startswith("builtin")
and toolgroup_name != RAG_TOOL_GROUP
):
identifier: str | BuiltinTool | None = tool_def.identifier identifier: str | BuiltinTool | None = tool_def.identifier
if identifier == "web_search": if identifier == "web_search":
identifier = BuiltinTool.brave_search identifier = BuiltinTool.brave_search
@ -894,18 +829,14 @@ class ChatAgent(ShieldRunnerMixin):
for param in tool_def.parameters for param in tool_def.parameters
}, },
) )
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get( tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
toolgroup_name, {}
)
self.tool_defs, self.tool_name_to_args = ( self.tool_defs, self.tool_name_to_args = (
list(tool_name_to_def.values()), list(tool_name_to_def.values()),
tool_name_to_args, tool_name_to_args,
) )
def _parse_toolgroup_name( def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
self, toolgroup_name_with_maybe_tool_name: str
) -> tuple[str, Optional[str]]:
"""Parse a toolgroup name into its components. """Parse a toolgroup name into its components.
Args: Args:
@ -941,9 +872,7 @@ class ChatAgent(ShieldRunnerMixin):
else: else:
tool_name_str = tool_name tool_name_str = tool_name
logger.info( logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
f"executing tool call: {tool_name_str} with args: {tool_call.arguments}"
)
result = await self.tool_runtime_api.invoke_tool( result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name_str, tool_name=tool_name_str,
kwargs={ kwargs={
@ -965,6 +894,7 @@ async def load_data_from_url(url: str) -> str:
return resp return resp
return "" return ""
async def get_raw_document_text(document: Document) -> str: async def get_raw_document_text(document: Document) -> str:
if isinstance(document.content, URL): if isinstance(document.content, URL):
return await load_data_from_url(document.content.uri) return await load_data_from_url(document.content.uri)
@ -976,9 +906,7 @@ async def get_raw_document_text(document: Document) -> str:
elif isinstance(document.content, TextContentItem): elif isinstance(document.content, TextContentItem):
return document.content.text return document.content.text
else: else:
raise ValueError( raise ValueError(f"Unexpected document content type: {type(document.content)}")
f"Unexpected document content type: {type(document.content)}"
)
def _interpret_content_as_attachment( def _interpret_content_as_attachment(
@ -993,4 +921,4 @@ def _interpret_content_as_attachment(
mime_type=data["mimetype"], mime_type=data["mimetype"],
) )
return None return None