feat: log start, complete time to Agent steps (#1116)

This commit is contained in:
ehhuang 2025-02-14 17:48:06 -08:00 committed by GitHub
parent 8dc1cac333
commit ab2b46e528
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 21 additions and 1 deletions

View file

@ -301,6 +301,7 @@ class ChatAgent(ShieldRunnerMixin):
return return
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
shield_call_start_time = datetime.now()
try: try:
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
@ -323,6 +324,8 @@ class ChatAgent(ShieldRunnerMixin):
step_id=step_id, step_id=step_id,
turn_id=turn_id, turn_id=turn_id,
violation=e.violation, violation=e.violation,
started_at=shield_call_start_time,
completed_at=datetime.now(),
), ),
) )
) )
@ -344,6 +347,8 @@ class ChatAgent(ShieldRunnerMixin):
step_id=step_id, step_id=step_id,
turn_id=turn_id, turn_id=turn_id,
violation=None, violation=None,
started_at=shield_call_start_time,
completed_at=datetime.now(),
), ),
) )
) )
@ -476,6 +481,7 @@ class ChatAgent(ShieldRunnerMixin):
client_tools[tool.name] = tool client_tools[tool.name] = tool
while True: while True:
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
inference_start_time = datetime.now()
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload( payload=AgentTurnResponseStepStartPayload(
@ -574,6 +580,8 @@ class ChatAgent(ShieldRunnerMixin):
step_id=step_id, step_id=step_id,
turn_id=turn_id, turn_id=turn_id,
model_response=copy.deepcopy(message), model_response=copy.deepcopy(message),
started_at=inference_start_time,
completed_at=datetime.now(),
), ),
) )
) )
@ -641,6 +649,7 @@ class ChatAgent(ShieldRunnerMixin):
"input": message.model_dump_json(), "input": message.model_dump_json(),
}, },
) as span: ) as span:
tool_execution_start_time = datetime.now()
result_messages = await execute_tool_call_maybe( result_messages = await execute_tool_call_maybe(
self.tool_runtime_api, self.tool_runtime_api,
session_id, session_id,
@ -668,6 +677,8 @@ class ChatAgent(ShieldRunnerMixin):
content=result_message.content, content=result_message.content,
) )
], ],
started_at=tool_execution_start_time,
completed_at=datetime.now(),
), ),
) )
) )

View file

@ -545,7 +545,7 @@ def test_create_turn_response(llama_stack_client, agent_config):
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": "What is the boiling point of polyjuice?", "content": "Call get_boiling_point and answer What is the boiling point of polyjuice?",
}, },
], ],
session_id=session_id, session_id=session_id,
@ -557,3 +557,12 @@ def test_create_turn_response(llama_stack_client, agent_config):
assert steps[1].step_type == "tool_execution" assert steps[1].step_type == "tool_execution"
assert steps[1].tool_calls[0].tool_name == "get_boiling_point" assert steps[1].tool_calls[0].tool_name == "get_boiling_point"
assert steps[2].step_type == "inference" assert steps[2].step_type == "inference"
last_step_completed_at = None
for step in steps:
if last_step_completed_at is None:
last_step_completed_at = step.completed_at
else:
assert last_step_completed_at < step.started_at
assert step.started_at < step.completed_at
last_step_completed_at = step.completed_at