improve agent metrics integration test and cleanup fixtures

- simplified the test to use telemetry.query_metrics for verification
- the test now validates actual, queryable metrics data
- verified via the query_metrics functionality added in #3074 (a rough sketch of the pattern follows)
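The verification pattern the reworked test leans on looks roughly like the sketch below: run one agent turn, then ask the telemetry API whether the agent's counters are actually queryable. The llama_stack_client fixture, the run_single_agent_turn helper, the query_metrics parameters, and the response shape are illustrative assumptions, not code taken from this commit.

import time

def test_agent_metrics_are_queryable(llama_stack_client, agent_id):
    # Bracket the turn we are about to run so the query window covers it.
    start_time = int(time.time())

    run_single_agent_turn(llama_stack_client, agent_id)  # hypothetical helper: executes one turn

    # Ask telemetry for a counter the agent should have emitted.
    # The query_metrics parameters and response shape are assumptions for illustration.
    response = llama_stack_client.telemetry.query_metrics(
        metric_name="llama_stack_agent_steps_total",
        start_time=start_time,
    )

    assert response.data, "expected queryable data for llama_stack_agent_steps_total"
    total = sum(point.value for series in response.data for point in series.values)
    assert total >= 1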
Author: skamenan7
Date:   2025-09-19 10:47:16 -04:00
parent  69b692af91
commit  8f0413e743
5 changed files with 406 additions and 208 deletions


@@ -63,7 +63,7 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.telemetry import MetricEvent, Telemetry
+from llama_stack.apis.telemetry import MetricEvent, MetricType, Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.datatypes import AccessRule
@@ -124,6 +124,9 @@ class ChatAgent(ShieldRunnerMixin):
             output_shields=agent_config.output_shields,
         )

+        # Initialize workflow start time to None
+        self._workflow_start_time: float | None = None
+
     def turn_to_messages(self, turn: Turn) -> list[Message]:
         messages = []
@@ -174,14 +177,23 @@ class ChatAgent(ShieldRunnerMixin):
         return await self.storage.create_session(name)

     def _emit_metric(
-        self, metric_name: str, value: int | float, unit: str, attributes: dict[str, str] | None = None
+        self,
+        metric_name: str,
+        value: int | float,
+        unit: str,
+        attributes: dict[str, str] | None = None,
+        metric_type: MetricType | None = None,
     ) -> None:
         """Emit a single metric event"""
+        logger.info(f"_emit_metric called: {metric_name} = {value} {unit}")
         if not self.telemetry_api:
+            logger.warning(f"No telemetry_api available for metric {metric_name}")
             return

         span = get_current_span()
         if not span:
+            logger.warning(f"No current span available for metric {metric_name}")
             return

         context = span.get_span_context()
@@ -193,22 +205,42 @@ class ChatAgent(ShieldRunnerMixin):
             timestamp=time.time(),
             unit=unit,
             attributes={"agent_id": self.agent_id, **(attributes or {})},
+            metric_type=metric_type,
         )

-        # Create task with name for better debugging and potential cleanup
+        # Create task with name for better debugging and capture any async errors
         task_name = f"metric-{metric_name}-{self.agent_id}"
-        asyncio.create_task(self.telemetry_api.log_event(metric), name=task_name)
+        logger.info(f"Creating telemetry task: {task_name}")
+        task = asyncio.create_task(self.telemetry_api.log_event(metric), name=task_name)
+
+        def _on_metric_task_done(t: asyncio.Task) -> None:
+            try:
+                exc = t.exception()
+            except asyncio.CancelledError:
+                logger.debug("Metric task %s was cancelled", task_name)
+                return
+            if exc is not None:
+                logger.warning("Metric task %s failed: %s", task_name, exc)
+
+        # Only add callback if task creation succeeded (not None from mocking)
+        if task is not None:
+            task.add_done_callback(_on_metric_task_done)

     def _track_step(self):
-        self._emit_metric("llama_stack_agent_steps_total", 1, "1")
+        logger.info("_track_step called")
+        self._emit_metric("llama_stack_agent_steps_total", 1, "1", metric_type=MetricType.COUNTER)

     def _track_workflow(self, status: str, duration: float):
-        self._emit_metric("llama_stack_agent_workflows_total", 1, "1", {"status": status})
-        self._emit_metric("llama_stack_agent_workflow_duration_seconds", duration, "s")
+        logger.info(f"_track_workflow called: status={status}, duration={duration:.2f}s")
+        self._emit_metric("llama_stack_agent_workflows_total", 1, "1", {"status": status}, MetricType.COUNTER)
+        self._emit_metric(
+            "llama_stack_agent_workflow_duration_seconds", duration, "s", metric_type=MetricType.HISTOGRAM
+        )

     def _track_tool(self, tool_name: str):
+        logger.info(f"_track_tool called: {tool_name}")
         normalized_name = "rag" if tool_name == "knowledge_search" else tool_name
-        self._emit_metric("llama_stack_agent_tool_calls_total", 1, "1", {"tool": normalized_name})
+        self._emit_metric("llama_stack_agent_tool_calls_total", 1, "1", {"tool": normalized_name}, MetricType.COUNTER)

     async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
         messages = []
@@ -244,6 +276,9 @@ class ChatAgent(ShieldRunnerMixin):
             if self.agent_config.name:
                 span.set_attribute("agent_name", self.agent_config.name)

+            # Set workflow start time for resume operations
+            self._workflow_start_time = time.time()
+
             await self._initialize_tools()
             async for chunk in self._run_turn(request):
                 yield chunk
@@ -255,6 +290,9 @@ class ChatAgent(ShieldRunnerMixin):
     ) -> AsyncGenerator:
         assert request.stream is True, "Non-streaming not supported"

+        # Track workflow start time for metrics
+        self._workflow_start_time = time.time()
+
         is_resume = isinstance(request, AgentTurnResumeRequest)
         session_info = await self.storage.get_session_info(request.session_id)
         if session_info is None:
@@ -356,6 +394,10 @@ class ChatAgent(ShieldRunnerMixin):
                 )
             )
         else:
+            # Track workflow completion when turn is actually complete
+            workflow_duration = time.time() - (self._workflow_start_time or time.time())
+            self._track_workflow("completed", workflow_duration)
+
             chunk = AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseTurnCompletePayload(
@@ -771,6 +813,7 @@ class ChatAgent(ShieldRunnerMixin):
                 # Track step execution metric
                 self._track_step()
+                self._track_tool(tool_call.tool_name)

                 # Add the result message to input_messages for the next iteration
                 input_messages.append(result_message)
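Taken out of the diff, the done-callback wiring in _emit_metric is a general asyncio pattern: a bare asyncio.create_task call is fire-and-forget, so an exception raised inside the task is only reported as "Task exception was never retrieved" at garbage-collection time. A minimal standalone sketch of the same idea follows; the log_event stand-in and the metric dict are illustrative, not the agent's telemetry API.

import asyncio
import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

async def log_event(metric: dict) -> None:
    # Stand-in for telemetry_api.log_event(); raises so the error path is visible.
    if "value" not in metric:
        raise ValueError(f"metric {metric['name']} has no value")

def emit_fire_and_forget(metric: dict) -> None:
    # Named task plus a done-callback, mirroring the pattern added in _emit_metric.
    task = asyncio.create_task(log_event(metric), name=f"metric-{metric['name']}")

    def _on_done(t: asyncio.Task) -> None:
        try:
            exc = t.exception()
        except asyncio.CancelledError:
            logger.debug("metric task %s was cancelled", t.get_name())
            return
        if exc is not None:
            logger.warning("metric task %s failed: %s", t.get_name(), exc)

    task.add_done_callback(_on_done)

async def main() -> None:
    emit_fire_and_forget({"name": "steps_total", "value": 1})  # completes silently
    emit_fire_and_forget({"name": "broken"})                   # failure is logged, not lost
    await asyncio.sleep(0.1)  # give the tasks and their callbacks a chance to run

asyncio.run(main())

Without the done-callback, the second call above would swallow its ValueError until interpreter shutdown; with it, the failure shows up immediately in the logs, which is the behavior the diff adds to the agent's metric emission.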