feat: log start, complete time to Agent steps (#1116)

This commit is contained in:
ehhuang 2025-02-14 17:48:06 -08:00 committed by GitHub
parent 8dc1cac333
commit ab2b46e528
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 21 additions and 1 deletions

View file

@ -301,6 +301,7 @@ class ChatAgent(ShieldRunnerMixin):
return
step_id = str(uuid.uuid4())
shield_call_start_time = datetime.now()
try:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -323,6 +324,8 @@ class ChatAgent(ShieldRunnerMixin):
step_id=step_id,
turn_id=turn_id,
violation=e.violation,
started_at=shield_call_start_time,
completed_at=datetime.now(),
),
)
)
@ -344,6 +347,8 @@ class ChatAgent(ShieldRunnerMixin):
step_id=step_id,
turn_id=turn_id,
violation=None,
started_at=shield_call_start_time,
completed_at=datetime.now(),
),
)
)
@ -476,6 +481,7 @@ class ChatAgent(ShieldRunnerMixin):
client_tools[tool.name] = tool
while True:
step_id = str(uuid.uuid4())
inference_start_time = datetime.now()
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
@ -574,6 +580,8 @@ class ChatAgent(ShieldRunnerMixin):
step_id=step_id,
turn_id=turn_id,
model_response=copy.deepcopy(message),
started_at=inference_start_time,
completed_at=datetime.now(),
),
)
)
@ -641,6 +649,7 @@ class ChatAgent(ShieldRunnerMixin):
"input": message.model_dump_json(),
},
) as span:
tool_execution_start_time = datetime.now()
result_messages = await execute_tool_call_maybe(
self.tool_runtime_api,
session_id,
@ -668,6 +677,8 @@ class ChatAgent(ShieldRunnerMixin):
content=result_message.content,
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now(),
),
)
)

View file

@ -545,7 +545,7 @@ def test_create_turn_response(llama_stack_client, agent_config):
messages=[
{
"role": "user",
"content": "What is the boiling point of polyjuice?",
"content": "Call get_boiling_point and answer What is the boiling point of polyjuice?",
},
],
session_id=session_id,
@ -557,3 +557,12 @@ def test_create_turn_response(llama_stack_client, agent_config):
assert steps[1].step_type == "tool_execution"
assert steps[1].tool_calls[0].tool_name == "get_boiling_point"
assert steps[2].step_type == "inference"
last_step_completed_at = None
for step in steps:
if last_step_completed_at is None:
last_step_completed_at = step.completed_at
else:
assert last_step_completed_at < step.started_at
assert step.started_at < step.completed_at
last_step_completed_at = step.completed_at