diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 6d85d1a72..df6c9bb36 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -38,10 +38,10 @@ from llama_stack.apis.agents import ( Turn, ) from llama_stack.apis.common.content_types import ( - URL, TextContentItem, ToolCallDelta, ToolCallParseStatus, + URL, ) from llama_stack.apis.inference import ( ChatCompletionResponseEventType, @@ -57,11 +57,7 @@ from llama_stack.apis.inference import ( UserMessage, ) from llama_stack.apis.safety import Safety -from llama_stack.apis.tools import ( - ToolGroups, - ToolInvocationResult, - ToolRuntime, -) +from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( @@ -77,7 +73,9 @@ from .safety import SafetyException, ShieldRunnerMixin def make_random_string(length: int = 8): - return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length)) + return "".join( + secrets.choice(string.ascii_letters + string.digits) for _ in range(length) + ) TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") @@ -176,7 +174,9 @@ class ChatAgent(ShieldRunnerMixin): messages.extend(self.turn_to_messages(turn)) return messages - async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: + async def create_and_execute_turn( + self, request: AgentTurnCreateRequest + ) -> AsyncGenerator: span = tracing.get_current_span() if span: span.set_attribute("session_id", request.session_id) @@ -221,13 +221,16 @@ class ChatAgent(ShieldRunnerMixin): messages = await self.get_messages_from_turns(turns) if is_resume: tool_response_messages = [ - ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses + ToolResponseMessage(call_id=x.call_id, content=x.content) + for x in request.tool_responses ] messages.extend(tool_response_messages) last_turn = turns[-1] last_turn_messages = self.turn_to_messages(last_turn) last_turn_messages = [ - x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) + x + for x in last_turn_messages + if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) ] last_turn_messages.extend(tool_response_messages) @@ -237,17 +240,31 @@ class ChatAgent(ShieldRunnerMixin): # mark tool execution step as complete # if there's no tool execution in progress step (due to storage, or tool call parsing on client), # we'll create a new tool execution step with current time - in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step( - request.session_id, request.turn_id + in_progress_tool_call_step = ( + await self.storage.get_in_progress_tool_call_step( + request.session_id, request.turn_id + ) ) now = datetime.now(timezone.utc).isoformat() tool_execution_step = ToolExecutionStep( - step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), + step_id=( + in_progress_tool_call_step.step_id + if in_progress_tool_call_step + else str(uuid.uuid4()) + ), turn_id=request.turn_id, - tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), + tool_calls=( + in_progress_tool_call_step.tool_calls + if in_progress_tool_call_step + else [] + ), tool_responses=request.tool_responses, completed_at=now, - started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), + started_at=( + in_progress_tool_call_step.started_at + if in_progress_tool_call_step + else now + ), ) steps.append(tool_execution_step) yield AgentTurnResponseStreamChunk( @@ -281,9 +298,14 @@ class ChatAgent(ShieldRunnerMixin): output_message = chunk continue - assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" + assert isinstance( + chunk, AgentTurnResponseStreamChunk + ), f"Unexpected type {type(chunk)}" event = chunk.event - if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: + if ( + event.payload.event_type + == AgentTurnResponseEventType.step_complete.value + ): steps.append(event.payload.step_details) yield chunk @@ -459,7 +481,7 @@ class ChatAgent(ShieldRunnerMixin): contexts.append(raw_document_text) attached_context = "\n".join(contexts) - input_messages[-1].context = attached_context + input_messages[-1].content += attached_context session_info = await self.storage.get_session_info(session_id) # if the session has a memory bank id, let the memory tool use it @@ -467,13 +489,19 @@ class ChatAgent(ShieldRunnerMixin): for tool_name in self.tool_name_to_args.keys(): if tool_name == MEMORY_QUERY_TOOL: if "vector_db_ids" not in self.tool_name_to_args[tool_name]: - self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id] + self.tool_name_to_args[tool_name]["vector_db_ids"] = [ + session_info.vector_db_id + ] else: - self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id) + self.tool_name_to_args[tool_name]["vector_db_ids"].append( + session_info.vector_db_id + ) output_attachments = [] - n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0 + n_iter = ( + await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0 + ) # Build a map of custom tools to their definitions for faster lookup client_tools = {} @@ -551,12 +579,16 @@ class ChatAgent(ShieldRunnerMixin): span.set_attribute("stop_reason", stop_reason) span.set_attribute( "input", - json.dumps([json.loads(m.model_dump_json()) for m in input_messages]), + json.dumps( + [json.loads(m.model_dump_json()) for m in input_messages] + ), ) output_attr = json.dumps( { "content": content, - "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls], + "tool_calls": [ + json.loads(t.model_dump_json()) for t in tool_calls + ], } ) span.set_attribute("output", output_attr) @@ -620,7 +652,9 @@ class ChatAgent(ShieldRunnerMixin): message.content = [message.content] + output_attachments yield message else: - logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}") + logger.debug( + f"completion message with EOM (iter: {n_iter}): {str(message)}" + ) input_messages = input_messages + [message] else: input_messages = input_messages + [message] @@ -669,7 +703,9 @@ class ChatAgent(ShieldRunnerMixin): "input": message.model_dump_json(), }, ) as span: - tool_execution_start_time = datetime.now(timezone.utc).isoformat() + tool_execution_start_time = datetime.now( + timezone.utc + ).isoformat() tool_result = await self.execute_tool_call_maybe( session_id, tool_call, @@ -718,7 +754,9 @@ class ChatAgent(ShieldRunnerMixin): # TODO: add tool-input touchpoint and a "start" event for this step also # but that needs a lot more refactoring of Tool code potentially if (type(result_message.content) is str) and ( - out_attachment := _interpret_content_as_attachment(result_message.content) + out_attachment := _interpret_content_as_attachment( + result_message.content + ) ): # NOTE: when we push this message back to the model, the model may ignore the # attached file path etc. since the model is trained to only provide a user message @@ -755,16 +793,24 @@ class ChatAgent(ShieldRunnerMixin): toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, ) -> None: toolgroup_to_args = {} - for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []): + for toolgroup in (self.agent_config.toolgroups or []) + ( + toolgroups_for_turn or [] + ): if isinstance(toolgroup, AgentToolGroupWithArgs): tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name) toolgroup_to_args[tool_group_name] = toolgroup.args # Determine which tools to include - tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or [] + tool_groups_to_include = ( + toolgroups_for_turn or self.agent_config.toolgroups or [] + ) agent_config_toolgroups = [] for toolgroup in tool_groups_to_include: - name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup + name = ( + toolgroup.name + if isinstance(toolgroup, AgentToolGroupWithArgs) + else toolgroup + ) if name not in agent_config_toolgroups: agent_config_toolgroups.append(name) @@ -790,20 +836,32 @@ class ChatAgent(ShieldRunnerMixin): }, ) for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: - toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) + toolgroup_name, input_tool_name = self._parse_toolgroup_name( + toolgroup_name_with_maybe_tool_name + ) tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name) if not tools.data: available_tool_groups = ", ".join( - [t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data] + [ + t.identifier + for t in (await self.tool_groups_api.list_tool_groups()).data + ] ) - raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}") - if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data): + raise ValueError( + f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}" + ) + if input_tool_name is not None and not any( + tool.identifier == input_tool_name for tool in tools.data + ): raise ValueError( f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}" ) for tool_def in tools.data: - if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP: + if ( + toolgroup_name.startswith("builtin") + and toolgroup_name != RAG_TOOL_GROUP + ): identifier: str | BuiltinTool | None = tool_def.identifier if identifier == "web_search": identifier = BuiltinTool.brave_search @@ -832,14 +890,18 @@ class ChatAgent(ShieldRunnerMixin): for param in tool_def.parameters }, ) - tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {}) + tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get( + toolgroup_name, {} + ) self.tool_defs, self.tool_name_to_args = ( list(tool_name_to_def.values()), tool_name_to_args, ) - def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]: + def _parse_toolgroup_name( + self, toolgroup_name_with_maybe_tool_name: str + ) -> tuple[str, Optional[str]]: """Parse a toolgroup name into its components. Args: @@ -875,7 +937,9 @@ class ChatAgent(ShieldRunnerMixin): else: tool_name_str = tool_name - logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}") + logger.info( + f"executing tool call: {tool_name_str} with args: {tool_call.arguments}" + ) result = await self.tool_runtime_api.invoke_tool( tool_name=tool_name_str, kwargs={