diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index b403b9203..95225b730 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -239,13 +239,14 @@ class ChatAgent(ShieldRunnerMixin):
         # return a "final value" for the `yield from` statement. we simulate that by yielding a
         # final boolean (to see whether an exception happened) and then explicitly testing for it.

-        async for res in self.run_multiple_shields_wrapper(
-            turn_id, input_messages, self.input_shields, "user-input"
-        ):
-            if isinstance(res, bool):
-                return
-            else:
-                yield res
+        if len(self.input_shields) > 0:
+            async for res in self.run_multiple_shields_wrapper(
+                turn_id, input_messages, self.input_shields, "user-input"
+            ):
+                if isinstance(res, bool):
+                    return
+                else:
+                    yield res

         async for res in self._run(
             session_id, turn_id, input_messages, attachments, sampling_params, stream
@@ -262,13 +263,14 @@ class ChatAgent(ShieldRunnerMixin):
         # for output shields run on the full input and output combination
         messages = input_messages + [final_response]

-        async for res in self.run_multiple_shields_wrapper(
-            turn_id, messages, self.output_shields, "assistant-output"
-        ):
-            if isinstance(res, bool):
-                return
-            else:
-                yield res
+        if len(self.output_shields) > 0:
+            async for res in self.run_multiple_shields_wrapper(
+                turn_id, messages, self.output_shields, "assistant-output"
+            ):
+                if isinstance(res, bool):
+                    return
+                else:
+                    yield res

         yield final_response

@@ -531,106 +533,72 @@ class ChatAgent(ShieldRunnerMixin):
             input_messages = input_messages + [message]
         else:
             log.info(f"{str(message)}")
-            try:
-                tool_call = message.tool_calls[0]
+            tool_call = message.tool_calls[0]

-                name = tool_call.tool_name
-                if not isinstance(name, BuiltinTool):
-                    yield message
-                    return
-
-                step_id = str(uuid.uuid4())
-                yield AgentTurnResponseStreamChunk(
-                    event=AgentTurnResponseEvent(
-                        payload=AgentTurnResponseStepStartPayload(
-                            step_type=StepType.tool_execution.value,
-                            step_id=step_id,
-                        )
-                    )
-                )
-                yield AgentTurnResponseStreamChunk(
-                    event=AgentTurnResponseEvent(
-                        payload=AgentTurnResponseStepProgressPayload(
-                            step_type=StepType.tool_execution.value,
-                            step_id=step_id,
-                            tool_call=tool_call,
-                        )
-                    )
-                )
-
-                with tracing.span(
-                    "tool_execution",
-                    {
-                        "tool_name": tool_call.tool_name,
-                        "input": message.model_dump_json(),
-                    },
-                ) as span:
-                    result_messages = await execute_tool_call_maybe(
-                        self.tools_dict,
-                        [message],
-                    )
-                    assert (
-                        len(result_messages) == 1
-                    ), "Currently not supporting multiple messages"
-                    result_message = result_messages[0]
-                    span.set_attribute("output", result_message.model_dump_json())
-
-                yield AgentTurnResponseStreamChunk(
-                    event=AgentTurnResponseEvent(
-                        payload=AgentTurnResponseStepCompletePayload(
-                            step_type=StepType.tool_execution.value,
-                            step_details=ToolExecutionStep(
-                                step_id=step_id,
-                                turn_id=turn_id,
-                                tool_calls=[tool_call],
-                                tool_responses=[
-                                    ToolResponse(
-                                        call_id=result_message.call_id,
-                                        tool_name=result_message.tool_name,
-                                        content=result_message.content,
-                                    )
-                                ],
-                            ),
-                        )
-                    )
-                )
-
-                # TODO: add tool-input touchpoint and a "start" event for this step also
-                # but that needs a lot more refactoring of Tool code potentially
-                yield AgentTurnResponseStreamChunk(
-                    event=AgentTurnResponseEvent(
-                        payload=AgentTurnResponseStepCompletePayload(
-                            step_type=StepType.shield_call.value,
-                            step_details=ShieldCallStep(
-                                step_id=str(uuid.uuid4()),
-                                turn_id=turn_id,
-                                violation=None,
-                            ),
-                        )
-                    )
-                )
-
-            except SafetyException as e:
-                yield AgentTurnResponseStreamChunk(
-                    event=AgentTurnResponseEvent(
-                        payload=AgentTurnResponseStepCompletePayload(
-                            step_type=StepType.shield_call.value,
-                            step_details=ShieldCallStep(
-                                step_id=str(uuid.uuid4()),
-                                turn_id=turn_id,
-                                violation=e.violation,
-                            ),
-                        )
-                    )
-                )
-
-                yield CompletionMessage(
-                    content=str(e),
-                    stop_reason=StopReason.end_of_turn,
-                )
-                yield False
+            name = tool_call.tool_name
+            if not isinstance(name, BuiltinTool):
+                yield message
                 return

+            step_id = str(uuid.uuid4())
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepStartPayload(
+                        step_type=StepType.tool_execution.value,
+                        step_id=step_id,
+                    )
+                )
+            )
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepProgressPayload(
+                        step_type=StepType.tool_execution.value,
+                        step_id=step_id,
+                        tool_call=tool_call,
+                    )
+                )
+            )
+
+            with tracing.span(
+                "tool_execution",
+                {
+                    "tool_name": tool_call.tool_name,
+                    "input": message.model_dump_json(),
+                },
+            ) as span:
+                result_messages = await execute_tool_call_maybe(
+                    self.tools_dict,
+                    [message],
+                )
+                assert (
+                    len(result_messages) == 1
+                ), "Currently not supporting multiple messages"
+                result_message = result_messages[0]
+                span.set_attribute("output", result_message.model_dump_json())
+
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepCompletePayload(
+                        step_type=StepType.tool_execution.value,
+                        step_details=ToolExecutionStep(
+                            step_id=step_id,
+                            turn_id=turn_id,
+                            tool_calls=[tool_call],
+                            tool_responses=[
+                                ToolResponse(
+                                    call_id=result_message.call_id,
+                                    tool_name=result_message.tool_name,
+                                    content=result_message.content,
+                                )
+                            ],
+                        ),
+                    )
+                )
+            )
+
+            # TODO: add tool-input touchpoint and a "start" event for this step also
+            # but that needs a lot more refactoring of Tool code potentially
+
             if out_attachment := interpret_content_as_attachment(
                 result_message.content
             ):
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index d5565dd62..e5ad14195 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -7,6 +7,7 @@
 from typing import *  # noqa: F403

 import json
 import uuid
+
 from botocore.client import BaseClient
 from llama_models.datatypes import CoreModelId
diff --git a/llama_stack/templates/bedrock/bedrock.py b/llama_stack/templates/bedrock/bedrock.py
index 8911d159d..0b5b7d90d 100644
--- a/llama_stack/templates/bedrock/bedrock.py
+++ b/llama_stack/templates/bedrock/bedrock.py
@@ -7,12 +7,14 @@
 from pathlib import Path

 from llama_models.sku_list import all_registered_models
+
+from llama_stack.apis.models import ModelInput
 from llama_stack.distribution.datatypes import Provider
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
-from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
 from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES
-from llama_stack.apis.models import ModelInput
+from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
+
+
 def get_distribution_template() -> DistributionTemplate:
     providers = {