diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index fc597d0f7..1c21df57f 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -301,6 +301,7 @@ class ChatAgent(ShieldRunnerMixin):
                 return
 
             step_id = str(uuid.uuid4())
+            shield_call_start_time = datetime.now()
             try:
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
@@ -323,6 +324,8 @@ class ChatAgent(ShieldRunnerMixin):
                             step_id=step_id,
                             turn_id=turn_id,
                             violation=e.violation,
+                            started_at=shield_call_start_time,
+                            completed_at=datetime.now(),
                         ),
                     )
                 )
@@ -344,6 +347,8 @@ class ChatAgent(ShieldRunnerMixin):
                             step_id=step_id,
                             turn_id=turn_id,
                             violation=None,
+                            started_at=shield_call_start_time,
+                            completed_at=datetime.now(),
                         ),
                     )
                 )
@@ -476,6 +481,7 @@ class ChatAgent(ShieldRunnerMixin):
                 client_tools[tool.name] = tool
         while True:
             step_id = str(uuid.uuid4())
+            inference_start_time = datetime.now()
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepStartPayload(
@@ -574,6 +580,8 @@ class ChatAgent(ShieldRunnerMixin):
                             step_id=step_id,
                             turn_id=turn_id,
                             model_response=copy.deepcopy(message),
+                            started_at=inference_start_time,
+                            completed_at=datetime.now(),
                         ),
                     )
                 )
@@ -641,6 +649,7 @@ class ChatAgent(ShieldRunnerMixin):
                     "input": message.model_dump_json(),
                 },
             ) as span:
+                tool_execution_start_time = datetime.now()
                 result_messages = await execute_tool_call_maybe(
                     self.tool_runtime_api,
                     session_id,
@@ -668,6 +677,8 @@ class ChatAgent(ShieldRunnerMixin):
                                 content=result_message.content,
                             )
                         ],
+                        started_at=tool_execution_start_time,
+                        completed_at=datetime.now(),
                     ),
                 )
             )
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index e5c20c3a5..0369f325b 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -545,7 +545,7 @@ def test_create_turn_response(llama_stack_client, agent_config):
         messages=[
             {
                 "role": "user",
-                "content": "What is the boiling point of polyjuice?",
+                "content": "Call get_boiling_point and answer What is the boiling point of polyjuice?",
            },
        ],
        session_id=session_id,
@@ -557,3 +557,12 @@ def test_create_turn_response(llama_stack_client, agent_config):
    assert steps[1].step_type == "tool_execution"
    assert steps[1].tool_calls[0].tool_name == "get_boiling_point"
    assert steps[2].step_type == "inference"
+
+    last_step_completed_at = None
+    for step in steps:
+        if last_step_completed_at is None:
+            last_step_completed_at = step.completed_at
+        else:
+            assert last_step_completed_at < step.started_at
+        assert step.started_at < step.completed_at
+        last_step_completed_at = step.completed_at