precommit: apply automated formatting fixes (line re-wrapping, import cleanup)

This commit is contained in:
Xi Yan 2025-03-22 15:33:52 -07:00
parent 4da8378ff8
commit 4837a6fa0e

View file

@ -6,14 +6,12 @@
import copy
import json
import os
import re
import secrets
import string
import uuid
from datetime import datetime, timezone
from typing import AsyncGenerator, List, Optional, Union
from urllib.parse import urlparse
import httpx
@ -40,10 +38,10 @@ from llama_stack.apis.agents import (
Turn,
)
from llama_stack.apis.common.content_types import (
URL,
TextContentItem,
ToolCallDelta,
ToolCallParseStatus,
URL,
)
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
@ -60,7 +58,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import (
RAGDocument,
ToolGroups,
ToolInvocationResult,
ToolRuntime,
@ -80,9 +77,7 @@ from .safety import SafetyException, ShieldRunnerMixin
def make_random_string(length: int = 8) -> str:
    """Return a random alphanumeric string of ``length`` characters.

    Each character is drawn independently with ``secrets.choice`` from the
    ASCII letters and digits.

    Args:
        length: Number of characters to generate. Zero or a negative value
            yields an empty string, because ``range(length)`` is then empty.

    Returns:
        The generated string.
    """
    # Build the candidate alphabet once instead of once per generated character.
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
@ -181,9 +176,7 @@ class ChatAgent(ShieldRunnerMixin):
messages.extend(self.turn_to_messages(turn))
return messages
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
await self._initialize_tools(request.toolgroups)
async with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
@ -224,16 +217,13 @@ class ChatAgent(ShieldRunnerMixin):
messages = await self.get_messages_from_turns(turns)
if is_resume:
tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, content=x.content)
for x in request.tool_responses
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
]
messages.extend(tool_response_messages)
last_turn = turns[-1]
last_turn_messages = self.turn_to_messages(last_turn)
last_turn_messages = [
x
for x in last_turn_messages
if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
]
last_turn_messages.extend(tool_response_messages)
@ -243,31 +233,17 @@ class ChatAgent(ShieldRunnerMixin):
# mark tool execution step as complete
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
# we'll create a new tool execution step with current time
in_progress_tool_call_step = (
await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
now = datetime.now(timezone.utc).isoformat()
tool_execution_step = ToolExecutionStep(
step_id=(
in_progress_tool_call_step.step_id
if in_progress_tool_call_step
else str(uuid.uuid4())
),
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
tool_calls=(
in_progress_tool_call_step.tool_calls
if in_progress_tool_call_step
else []
),
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=request.tool_responses,
completed_at=now,
started_at=(
in_progress_tool_call_step.started_at
if in_progress_tool_call_step
else now
),
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
)
steps.append(tool_execution_step)
yield AgentTurnResponseStreamChunk(
@ -301,14 +277,9 @@ class ChatAgent(ShieldRunnerMixin):
output_message = chunk
continue
assert isinstance(
chunk, AgentTurnResponseStreamChunk
), f"Unexpected type {type(chunk)}"
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if (
event.payload.event_type
== AgentTurnResponseEventType.step_complete.value
):
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
yield chunk
@ -490,19 +461,13 @@ class ChatAgent(ShieldRunnerMixin):
for tool_name in self.tool_name_to_args.keys():
if tool_name == MEMORY_QUERY_TOOL:
if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
self.tool_name_to_args[tool_name]["vector_db_ids"] = [
session_info.vector_db_id
]
self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id]
else:
self.tool_name_to_args[tool_name]["vector_db_ids"].append(
session_info.vector_db_id
)
self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
output_attachments = []
n_iter = (
await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
)
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
# Build a map of custom tools to their definitions for faster lookup
client_tools = {}
@ -583,16 +548,12 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("stop_reason", stop_reason)
span.set_attribute(
"input",
json.dumps(
[json.loads(m.model_dump_json()) for m in input_messages]
),
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
)
output_attr = json.dumps(
{
"content": content,
"tool_calls": [
json.loads(t.model_dump_json()) for t in tool_calls
],
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
}
)
span.set_attribute("output", output_attr)
@ -656,9 +617,7 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + output_attachments
yield message
else:
logger.debug(
f"completion message with EOM (iter: {n_iter}): {str(message)}"
)
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
input_messages = input_messages + [message]
else:
input_messages = input_messages + [message]
@ -707,9 +666,7 @@ class ChatAgent(ShieldRunnerMixin):
"input": message.model_dump_json(),
},
) as span:
tool_execution_start_time = datetime.now(
timezone.utc
).isoformat()
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
tool_result = await self.execute_tool_call_maybe(
session_id,
tool_call,
@ -758,9 +715,7 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment(
result_message.content
)
out_attachment := _interpret_content_as_attachment(result_message.content)
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
@ -797,24 +752,16 @@ class ChatAgent(ShieldRunnerMixin):
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> None:
toolgroup_to_args = {}
for toolgroup in (self.agent_config.toolgroups or []) + (
toolgroups_for_turn or []
):
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
if isinstance(toolgroup, AgentToolGroupWithArgs):
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
toolgroup_to_args[tool_group_name] = toolgroup.args
# Determine which tools to include
tool_groups_to_include = (
toolgroups_for_turn or self.agent_config.toolgroups or []
)
tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
agent_config_toolgroups = []
for toolgroup in tool_groups_to_include:
name = (
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
if name not in agent_config_toolgroups:
agent_config_toolgroups.append(name)
@ -840,32 +787,20 @@ class ChatAgent(ShieldRunnerMixin):
},
)
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, input_tool_name = self._parse_toolgroup_name(
toolgroup_name_with_maybe_tool_name
)
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
if not tools.data:
available_tool_groups = ", ".join(
[
t.identifier
for t in (await self.tool_groups_api.list_tool_groups()).data
]
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
)
raise ValueError(
f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}"
)
if input_tool_name is not None and not any(
tool.identifier == input_tool_name for tool in tools.data
):
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
raise ValueError(
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
)
for tool_def in tools.data:
if (
toolgroup_name.startswith("builtin")
and toolgroup_name != RAG_TOOL_GROUP
):
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
identifier: str | BuiltinTool | None = tool_def.identifier
if identifier == "web_search":
identifier = BuiltinTool.brave_search
@ -894,18 +829,14 @@ class ChatAgent(ShieldRunnerMixin):
for param in tool_def.parameters
},
)
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(
toolgroup_name, {}
)
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
self.tool_defs, self.tool_name_to_args = (
list(tool_name_to_def.values()),
tool_name_to_args,
)
def _parse_toolgroup_name(
self, toolgroup_name_with_maybe_tool_name: str
) -> tuple[str, Optional[str]]:
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
"""Parse a toolgroup name into its components.
Args:
@ -941,9 +872,7 @@ class ChatAgent(ShieldRunnerMixin):
else:
tool_name_str = tool_name
logger.info(
f"executing tool call: {tool_name_str} with args: {tool_call.arguments}"
)
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name_str,
kwargs={
@ -965,6 +894,7 @@ async def load_data_from_url(url: str) -> str:
return resp
return ""
async def get_raw_document_text(document: Document) -> str:
if isinstance(document.content, URL):
return await load_data_from_url(document.content.uri)
@ -976,9 +906,7 @@ async def get_raw_document_text(document: Document) -> str:
elif isinstance(document.content, TextContentItem):
return document.content.text
else:
raise ValueError(
f"Unexpected document content type: {type(document.content)}"
)
raise ValueError(f"Unexpected document content type: {type(document.content)}")
def _interpret_content_as_attachment(
@ -993,4 +921,4 @@ def _interpret_content_as_attachment(
mime_type=data["mimetype"],
)
return None
return None