Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 12:07:34 +00:00)
improve agent metrics integration test and cleanup fixtures
- simplified test to use telemetry.query_metrics for verification
- test now validates actual queryable metrics data
- verified via the query_metrics functionality added in #3074
parent 69b692af91
commit 8f0413e743
5 changed files with 406 additions and 208 deletions
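For context, the verification style the commit message describes looks roughly like the sketch below. The client fixture, the turn-running helper, and the exact query_metrics arguments are illustrative assumptions, not the actual test code:

# Rough sketch only: assumes a llama-stack client exposing the
# telemetry.query_metrics endpoint added in #3074; run_one_agent_turn
# is a hypothetical helper standing in for the real agent fixture.
import time

def test_agent_step_metrics_are_queryable(llama_stack_client):
    start_time = int(time.time())

    run_one_agent_turn(llama_stack_client)  # triggers _track_step() on the server

    response = llama_stack_client.telemetry.query_metrics(
        metric_name="llama_stack_agent_steps_total",
        start_time=start_time,
    )
    assert response.data, "expected at least one queryable series for the counter"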
@@ -63,7 +63,7 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.telemetry import MetricEvent, Telemetry
+from llama_stack.apis.telemetry import MetricEvent, MetricType, Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.datatypes import AccessRule
@@ -124,6 +124,9 @@ class ChatAgent(ShieldRunnerMixin):
             output_shields=agent_config.output_shields,
         )
 
+        # Initialize workflow start time to None
+        self._workflow_start_time: float | None = None
+
     def turn_to_messages(self, turn: Turn) -> list[Message]:
         messages = []
 
@@ -174,14 +177,23 @@ class ChatAgent(ShieldRunnerMixin):
         return await self.storage.create_session(name)
 
     def _emit_metric(
-        self, metric_name: str, value: int | float, unit: str, attributes: dict[str, str] | None = None
+        self,
+        metric_name: str,
+        value: int | float,
+        unit: str,
+        attributes: dict[str, str] | None = None,
+        metric_type: MetricType | None = None,
     ) -> None:
         """Emit a single metric event"""
+        logger.info(f"_emit_metric called: {metric_name} = {value} {unit}")
+
         if not self.telemetry_api:
+            logger.warning(f"No telemetry_api available for metric {metric_name}")
             return
 
         span = get_current_span()
         if not span:
+            logger.warning(f"No current span available for metric {metric_name}")
             return
 
         context = span.get_span_context()
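For readers unfamiliar with the span plumbing above: _emit_metric only emits when there is an active span to correlate the metric with. The equivalent lookup with the plain OpenTelemetry API (llama-stack's own get_current_span helper differs in its details) is a few lines:

# Sketch using only the standard opentelemetry-api package.
from opentelemetry import trace

def current_trace_ids() -> tuple[str, str] | None:
    span = trace.get_current_span()
    ctx = span.get_span_context()
    if not ctx.is_valid:  # no active span, so nothing to attach the metric to
        return None
    # ids are stored as ints; hex-format them the way exporters do
    return format(ctx.trace_id, "032x"), format(ctx.span_id, "016x")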
@@ -193,22 +205,42 @@ class ChatAgent(ShieldRunnerMixin):
             timestamp=time.time(),
             unit=unit,
             attributes={"agent_id": self.agent_id, **(attributes or {})},
+            metric_type=metric_type,
         )
 
-        # Create task with name for better debugging and potential cleanup
+        # Create task with name for better debugging and capture any async errors
        task_name = f"metric-{metric_name}-{self.agent_id}"
-        asyncio.create_task(self.telemetry_api.log_event(metric), name=task_name)
+        logger.info(f"Creating telemetry task: {task_name}")
+        task = asyncio.create_task(self.telemetry_api.log_event(metric), name=task_name)
+
+        def _on_metric_task_done(t: asyncio.Task) -> None:
+            try:
+                exc = t.exception()
+            except asyncio.CancelledError:
+                logger.debug("Metric task %s was cancelled", task_name)
+                return
+            if exc is not None:
+                logger.warning("Metric task %s failed: %s", task_name, exc)
+
+        # Only add callback if task creation succeeded (not None from mocking)
+        if task is not None:
+            task.add_done_callback(_on_metric_task_done)
 
     def _track_step(self):
-        self._emit_metric("llama_stack_agent_steps_total", 1, "1")
+        logger.info("_track_step called")
+        self._emit_metric("llama_stack_agent_steps_total", 1, "1", metric_type=MetricType.COUNTER)
 
     def _track_workflow(self, status: str, duration: float):
-        self._emit_metric("llama_stack_agent_workflows_total", 1, "1", {"status": status})
-        self._emit_metric("llama_stack_agent_workflow_duration_seconds", duration, "s")
+        logger.info(f"_track_workflow called: status={status}, duration={duration:.2f}s")
+        self._emit_metric("llama_stack_agent_workflows_total", 1, "1", {"status": status}, MetricType.COUNTER)
+        self._emit_metric(
+            "llama_stack_agent_workflow_duration_seconds", duration, "s", metric_type=MetricType.HISTOGRAM
+        )
 
     def _track_tool(self, tool_name: str):
+        logger.info(f"_track_tool called: {tool_name}")
         normalized_name = "rag" if tool_name == "knowledge_search" else tool_name
-        self._emit_metric("llama_stack_agent_tool_calls_total", 1, "1", {"tool": normalized_name})
+        self._emit_metric("llama_stack_agent_tool_calls_total", 1, "1", {"tool": normalized_name}, MetricType.COUNTER)
 
     async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
         messages = []
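The done-callback added here is the standard fix for fire-and-forget asyncio tasks silently swallowing exceptions: without it, a failed log_event would only surface as a "Task exception was never retrieved" message at garbage collection. A self-contained sketch of the pattern:

# Minimal, runnable sketch of the fire-and-forget-with-callback pattern.
import asyncio
import logging

logger = logging.getLogger(__name__)

async def log_event() -> None:
    raise RuntimeError("telemetry sink unavailable")  # simulate a failing emit

def _on_done(task: asyncio.Task) -> None:
    try:
        exc = task.exception()
    except asyncio.CancelledError:
        logger.debug("task %s cancelled", task.get_name())
        return
    if exc is not None:
        logger.warning("task %s failed: %s", task.get_name(), exc)

async def main() -> None:
    task = asyncio.create_task(log_event(), name="metric-demo")
    task.add_done_callback(_on_done)
    await asyncio.sleep(0.1)  # demo only; the real caller never awaits the task

asyncio.run(main())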
@@ -244,6 +276,9 @@ class ChatAgent(ShieldRunnerMixin):
             if self.agent_config.name:
                 span.set_attribute("agent_name", self.agent_config.name)
 
+            # Set workflow start time for resume operations
+            self._workflow_start_time = time.time()
+
             await self._initialize_tools()
             async for chunk in self._run_turn(request):
                 yield chunk
@@ -255,6 +290,9 @@ class ChatAgent(ShieldRunnerMixin):
     ) -> AsyncGenerator:
         assert request.stream is True, "Non-streaming not supported"
 
+        # Track workflow start time for metrics
+        self._workflow_start_time = time.time()
+
         is_resume = isinstance(request, AgentTurnResumeRequest)
         session_info = await self.storage.get_session_info(request.session_id)
         if session_info is None:
@@ -356,6 +394,10 @@ class ChatAgent(ShieldRunnerMixin):
                 )
             )
         else:
+            # Track workflow completion when turn is actually complete
+            workflow_duration = time.time() - (self._workflow_start_time or time.time())
+            self._track_workflow("completed", workflow_duration)
+
             chunk = AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseTurnCompletePayload(
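The (self._workflow_start_time or time.time()) fallback means a turn that somehow completes without a recorded start time reports a duration of roughly zero instead of raising on None. In miniature:

# Miniature version of the fallback, with hypothetical local names.
import time

start: float | None = None  # start time was never recorded
duration = time.time() - (start or time.time())
assert abs(duration) < 1e-3  # two back-to-back clock reads: effectively zero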
@@ -771,6 +813,7 @@ class ChatAgent(ShieldRunnerMixin):
 
                     # Track step execution metric
                     self._track_step()
+                    self._track_tool(tool_call.tool_name)
 
                     # Add the result message to input_messages for the next iteration
                     input_messages.append(result_message)
 
@@ -12,7 +12,9 @@ from opentelemetry import metrics, trace
 from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
 from opentelemetry.sdk.metrics import MeterProvider
+from opentelemetry.sdk.metrics._internal.aggregation import ExplicitBucketHistogramAggregation
 from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
+from opentelemetry.sdk.metrics.view import View
 from opentelemetry.sdk.resources import Resource
 from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import BatchSpanProcessor
@@ -110,7 +112,17 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
 
         if TelemetrySink.OTEL_METRIC in self.config.sinks:
             metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
-            metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
+
+            # decent default buckets for agent workflow timings
+            hist_buckets = [0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 25.0, 50.0, 100.0]
+            views = [
+                View(
+                    instrument_type=metrics.Histogram,
+                    aggregation=ExplicitBucketHistogramAggregation(boundaries=hist_buckets),
+                )
+            ]
+
+            metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader], views=views)
             metrics.set_meter_provider(metric_provider)
 
         if TelemetrySink.SQLITE in self.config.sinks:
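The same View configuration works standalone. The sketch below uses the public import path (opentelemetry.sdk.metrics.view) for ExplicitBucketHistogramAggregation rather than the _internal module the diff imports from, and a console exporter so it runs without a collector:

# Standalone sketch: route every histogram instrument through explicit buckets.
from opentelemetry import metrics
from opentelemetry.sdk.metrics import Histogram, MeterProvider
from opentelemetry.sdk.metrics.export import ConsoleMetricExporter, PeriodicExportingMetricReader
from opentelemetry.sdk.metrics.view import ExplicitBucketHistogramAggregation, View

reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
view = View(
    instrument_type=Histogram,  # match all histograms by instrument type
    aggregation=ExplicitBucketHistogramAggregation(
        boundaries=[0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 25.0, 50.0, 100.0]
    ),
)
metrics.set_meter_provider(MeterProvider(metric_readers=[reader], views=[view]))

meter = metrics.get_meter("agent-demo")
duration = meter.create_histogram("llama_stack_agent_workflow_duration_seconds", unit="s")
duration.record(3.2)  # lands in the (2.5, 5.0] bucket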
@@ -140,8 +152,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             self._log_metric(event)
         elif isinstance(event, StructuredLogEvent):
             self._log_structured(event, ttl_seconds)
         else:
             raise ValueError(f"Unknown event type: {event}")
 
     async def query_metrics(
         self,
@@ -211,7 +221,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             _GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
                 name=name,
                 unit=unit,
-                description=f"Counter for {name}",
+                description=name.replace("_", " "),
             )
         return _GLOBAL_STORAGE["counters"][name]
 
@@ -221,7 +231,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
                 name=name,
                 unit=unit,
-                description=f"Gauge for {name}",
+                description=name.replace("_", " "),
             )
         return _GLOBAL_STORAGE["gauges"][name]
 
@@ -265,7 +275,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             histogram = self._get_or_create_histogram(
                 event.metric,
                 event.unit,
-                [0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 25.0, 50.0, 100.0],
             )
             histogram.record(event.value, attributes=event.attributes)
         elif event.metric_type == MetricType.UP_DOWN_COUNTER:
@@ -281,17 +290,17 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             _GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
                 name=name,
                 unit=unit,
-                description=f"UpDownCounter for {name}",
+                description=name.replace("_", " "),
             )
         return _GLOBAL_STORAGE["up_down_counters"][name]
 
-    def _get_or_create_histogram(self, name: str, unit: str, buckets: list[float] | None = None) -> metrics.Histogram:
+    def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
         assert self.meter is not None
         if name not in _GLOBAL_STORAGE["histograms"]:
             _GLOBAL_STORAGE["histograms"][name] = self.meter.create_histogram(
                 name=name,
                 unit=unit,
-                description=f"Histogram for {name}",
+                description=name.replace("_", " "),
             )
         return _GLOBAL_STORAGE["histograms"][name]
 
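All four description changes apply the same convention (underscores become spaces) across the instrument caches. The get-or-create pattern itself, in isolation:

# Sketch of the instrument cache: create each named instrument once, since
# re-creating it per event would trigger duplicate-instrument warnings.
from opentelemetry import metrics

_counters: dict[str, metrics.Counter] = {}

def get_or_create_counter(meter: metrics.Meter, name: str, unit: str) -> metrics.Counter:
    if name not in _counters:
        _counters[name] = meter.create_counter(
            name=name,
            unit=unit,
            description=name.replace("_", " "),  # e.g. "llama stack agent steps total"
        )
    return _counters[name]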