change context to content

This commit is contained in:
Xi Yan 2025-03-29 11:55:50 -07:00
parent d8a8a734b5
commit e84d3966c4

View file

@ -38,10 +38,10 @@ from llama_stack.apis.agents import (
Turn, Turn,
) )
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
URL,
TextContentItem, TextContentItem,
ToolCallDelta, ToolCallDelta,
ToolCallParseStatus, ToolCallParseStatus,
URL,
) )
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponseEventType, ChatCompletionResponseEventType,
@ -57,11 +57,7 @@ from llama_stack.apis.inference import (
UserMessage, UserMessage,
) )
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ( from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
ToolGroups,
ToolInvocationResult,
ToolRuntime,
)
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ( from llama_stack.models.llama.datatypes import (
@ -77,7 +73,9 @@ from .safety import SafetyException, ShieldRunnerMixin
def make_random_string(length: int = 8): def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length)) return "".join(
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
)
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
@ -176,7 +174,9 @@ class ChatAgent(ShieldRunnerMixin):
messages.extend(self.turn_to_messages(turn)) messages.extend(self.turn_to_messages(turn))
return messages return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
span = tracing.get_current_span() span = tracing.get_current_span()
if span: if span:
span.set_attribute("session_id", request.session_id) span.set_attribute("session_id", request.session_id)
@ -221,13 +221,16 @@ class ChatAgent(ShieldRunnerMixin):
messages = await self.get_messages_from_turns(turns) messages = await self.get_messages_from_turns(turns)
if is_resume: if is_resume:
tool_response_messages = [ tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses ToolResponseMessage(call_id=x.call_id, content=x.content)
for x in request.tool_responses
] ]
messages.extend(tool_response_messages) messages.extend(tool_response_messages)
last_turn = turns[-1] last_turn = turns[-1]
last_turn_messages = self.turn_to_messages(last_turn) last_turn_messages = self.turn_to_messages(last_turn)
last_turn_messages = [ last_turn_messages = [
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) x
for x in last_turn_messages
if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
] ]
last_turn_messages.extend(tool_response_messages) last_turn_messages.extend(tool_response_messages)
@ -237,17 +240,31 @@ class ChatAgent(ShieldRunnerMixin):
# mark tool execution step as complete # mark tool execution step as complete
# if there's no tool execution in progress step (due to storage, or tool call parsing on client), # if there's no tool execution in progress step (due to storage, or tool call parsing on client),
# we'll create a new tool execution step with current time # we'll create a new tool execution step with current time
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step( in_progress_tool_call_step = (
request.session_id, request.turn_id await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
) )
now = datetime.now(timezone.utc).isoformat() now = datetime.now(timezone.utc).isoformat()
tool_execution_step = ToolExecutionStep( tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), step_id=(
in_progress_tool_call_step.step_id
if in_progress_tool_call_step
else str(uuid.uuid4())
),
turn_id=request.turn_id, turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), tool_calls=(
in_progress_tool_call_step.tool_calls
if in_progress_tool_call_step
else []
),
tool_responses=request.tool_responses, tool_responses=request.tool_responses,
completed_at=now, completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), started_at=(
in_progress_tool_call_step.started_at
if in_progress_tool_call_step
else now
),
) )
steps.append(tool_execution_step) steps.append(tool_execution_step)
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
@ -281,9 +298,14 @@ class ChatAgent(ShieldRunnerMixin):
output_message = chunk output_message = chunk
continue continue
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" assert isinstance(
chunk, AgentTurnResponseStreamChunk
), f"Unexpected type {type(chunk)}"
event = chunk.event event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: if (
event.payload.event_type
== AgentTurnResponseEventType.step_complete.value
):
steps.append(event.payload.step_details) steps.append(event.payload.step_details)
yield chunk yield chunk
@ -459,7 +481,7 @@ class ChatAgent(ShieldRunnerMixin):
contexts.append(raw_document_text) contexts.append(raw_document_text)
attached_context = "\n".join(contexts) attached_context = "\n".join(contexts)
input_messages[-1].context = attached_context input_messages[-1].content += attached_context
session_info = await self.storage.get_session_info(session_id) session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it # if the session has a memory bank id, let the memory tool use it
@ -467,13 +489,19 @@ class ChatAgent(ShieldRunnerMixin):
for tool_name in self.tool_name_to_args.keys(): for tool_name in self.tool_name_to_args.keys():
if tool_name == MEMORY_QUERY_TOOL: if tool_name == MEMORY_QUERY_TOOL:
if "vector_db_ids" not in self.tool_name_to_args[tool_name]: if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id] self.tool_name_to_args[tool_name]["vector_db_ids"] = [
session_info.vector_db_id
]
else: else:
self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id) self.tool_name_to_args[tool_name]["vector_db_ids"].append(
session_info.vector_db_id
)
output_attachments = [] output_attachments = []
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0 n_iter = (
await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
)
# Build a map of custom tools to their definitions for faster lookup # Build a map of custom tools to their definitions for faster lookup
client_tools = {} client_tools = {}
@ -551,12 +579,16 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("stop_reason", stop_reason) span.set_attribute("stop_reason", stop_reason)
span.set_attribute( span.set_attribute(
"input", "input",
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]), json.dumps(
[json.loads(m.model_dump_json()) for m in input_messages]
),
) )
output_attr = json.dumps( output_attr = json.dumps(
{ {
"content": content, "content": content,
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls], "tool_calls": [
json.loads(t.model_dump_json()) for t in tool_calls
],
} }
) )
span.set_attribute("output", output_attr) span.set_attribute("output", output_attr)
@ -620,7 +652,9 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + output_attachments message.content = [message.content] + output_attachments
yield message yield message
else: else:
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}") logger.debug(
f"completion message with EOM (iter: {n_iter}): {str(message)}"
)
input_messages = input_messages + [message] input_messages = input_messages + [message]
else: else:
input_messages = input_messages + [message] input_messages = input_messages + [message]
@ -669,7 +703,9 @@ class ChatAgent(ShieldRunnerMixin):
"input": message.model_dump_json(), "input": message.model_dump_json(),
}, },
) as span: ) as span:
tool_execution_start_time = datetime.now(timezone.utc).isoformat() tool_execution_start_time = datetime.now(
timezone.utc
).isoformat()
tool_result = await self.execute_tool_call_maybe( tool_result = await self.execute_tool_call_maybe(
session_id, session_id,
tool_call, tool_call,
@ -718,7 +754,9 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: add tool-input touchpoint and a "start" event for this step also # TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially # but that needs a lot more refactoring of Tool code potentially
if (type(result_message.content) is str) and ( if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment(result_message.content) out_attachment := _interpret_content_as_attachment(
result_message.content
)
): ):
# NOTE: when we push this message back to the model, the model may ignore the # NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message # attached file path etc. since the model is trained to only provide a user message
@ -755,16 +793,24 @@ class ChatAgent(ShieldRunnerMixin):
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> None: ) -> None:
toolgroup_to_args = {} toolgroup_to_args = {}
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []): for toolgroup in (self.agent_config.toolgroups or []) + (
toolgroups_for_turn or []
):
if isinstance(toolgroup, AgentToolGroupWithArgs): if isinstance(toolgroup, AgentToolGroupWithArgs):
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name) tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
toolgroup_to_args[tool_group_name] = toolgroup.args toolgroup_to_args[tool_group_name] = toolgroup.args
# Determine which tools to include # Determine which tools to include
tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or [] tool_groups_to_include = (
toolgroups_for_turn or self.agent_config.toolgroups or []
)
agent_config_toolgroups = [] agent_config_toolgroups = []
for toolgroup in tool_groups_to_include: for toolgroup in tool_groups_to_include:
name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup name = (
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
if name not in agent_config_toolgroups: if name not in agent_config_toolgroups:
agent_config_toolgroups.append(name) agent_config_toolgroups.append(name)
@ -790,20 +836,32 @@ class ChatAgent(ShieldRunnerMixin):
}, },
) )
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) toolgroup_name, input_tool_name = self._parse_toolgroup_name(
toolgroup_name_with_maybe_tool_name
)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name) tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
if not tools.data: if not tools.data:
available_tool_groups = ", ".join( available_tool_groups = ", ".join(
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data] [
t.identifier
for t in (await self.tool_groups_api.list_tool_groups()).data
]
) )
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}") raise ValueError(
if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data): f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}"
)
if input_tool_name is not None and not any(
tool.identifier == input_tool_name for tool in tools.data
):
raise ValueError( raise ValueError(
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}" f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
) )
for tool_def in tools.data: for tool_def in tools.data:
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP: if (
toolgroup_name.startswith("builtin")
and toolgroup_name != RAG_TOOL_GROUP
):
identifier: str | BuiltinTool | None = tool_def.identifier identifier: str | BuiltinTool | None = tool_def.identifier
if identifier == "web_search": if identifier == "web_search":
identifier = BuiltinTool.brave_search identifier = BuiltinTool.brave_search
@ -832,14 +890,18 @@ class ChatAgent(ShieldRunnerMixin):
for param in tool_def.parameters for param in tool_def.parameters
}, },
) )
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {}) tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(
toolgroup_name, {}
)
self.tool_defs, self.tool_name_to_args = ( self.tool_defs, self.tool_name_to_args = (
list(tool_name_to_def.values()), list(tool_name_to_def.values()),
tool_name_to_args, tool_name_to_args,
) )
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]: def _parse_toolgroup_name(
self, toolgroup_name_with_maybe_tool_name: str
) -> tuple[str, Optional[str]]:
"""Parse a toolgroup name into its components. """Parse a toolgroup name into its components.
Args: Args:
@ -875,7 +937,9 @@ class ChatAgent(ShieldRunnerMixin):
else: else:
tool_name_str = tool_name tool_name_str = tool_name
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}") logger.info(
f"executing tool call: {tool_name_str} with args: {tool_call.arguments}"
)
result = await self.tool_runtime_api.invoke_tool( result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name_str, tool_name=tool_name_str,
kwargs={ kwargs={