feat: add agent workflow metrics collection

Add comprehensive OpenTelemetry-based metrics for agent observability:

- Workflow completion/failure tracking with duration measurements
- Step execution counters for performance monitoring
- Tool usage tracking with normalized tool names
- Non-blocking telemetry emission with named async tasks
- Unit and integration tests covering the new metrics
- Graceful handling when telemetry is disabled
This commit is contained in:
skamenan7 2025-08-06 17:08:03 -04:00
parent 4c2fcb6b51
commit 69b692af91
13 changed files with 701 additions and 11 deletions

View file

@ -90,6 +90,21 @@ class EventType(Enum):
METRIC = "metric"
@json_schema_type
class MetricType(Enum):
    """The type of metric being recorded.

    Carried on MetricEvent.metric_type so the telemetry sink can route a
    measurement to the matching OpenTelemetry instrument kind.

    :cvar COUNTER: A counter metric that only increases (e.g., requests_total)
    :cvar UP_DOWN_COUNTER: A counter that can increase or decrease (e.g., active_connections)
    :cvar HISTOGRAM: A histogram metric for measuring distributions (e.g., request_duration_seconds)
    :cvar GAUGE: A gauge metric for point-in-time values (e.g., cpu_usage_percent)
    """

    COUNTER = "counter"
    UP_DOWN_COUNTER = "up_down_counter"
    HISTOGRAM = "histogram"
    GAUGE = "gauge"
@json_schema_type
class LogSeverity(Enum):
"""The severity level of a log message.
@ -143,12 +158,14 @@ class MetricEvent(EventCommon):
:param metric: The name of the metric being measured
:param value: The numeric value of the metric measurement
:param unit: The unit of measurement for the metric value
:param metric_type: The type of metric (optional, inferred if not provided for backwards compatibility)
"""
type: Literal[EventType.METRIC] = EventType.METRIC
metric: str # this would be an enum
value: int | float
unit: str
metric_type: MetricType | None = None
@json_schema_type

View file

@ -4,17 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import copy
import json
import re
import secrets
import string
import time
import uuid
import warnings
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
import httpx
from opentelemetry.trace import get_current_span
from llama_stack.apis.agents import (
AgentConfig,
@ -60,6 +63,7 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.telemetry import MetricEvent, Telemetry
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule
@ -97,6 +101,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
vector_io_api: VectorIO,
telemetry_api: Telemetry | None,
persistence_store: KVStore,
created_at: str,
policy: list[AccessRule],
@ -106,6 +111,7 @@ class ChatAgent(ShieldRunnerMixin):
self.inference_api = inference_api
self.safety_api = safety_api
self.vector_io_api = vector_io_api
self.telemetry_api = telemetry_api
self.storage = AgentPersistence(agent_id, persistence_store, policy)
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
@ -167,6 +173,43 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
def _emit_metric(
    self,
    metric_name: str,
    value: int | float,
    unit: str,
    attributes: dict[str, str] | None = None,
    metric_type: MetricType | None = None,
) -> None:
    """Emit a single metric event without blocking the caller.

    The event is logged on a fire-and-forget asyncio task so agent
    execution is never delayed by telemetry. No-ops when telemetry is
    disabled or when there is no current span to attach the metric to.

    :param metric_name: Name of the metric (e.g. llama_stack_agent_steps_total).
    :param value: Numeric measurement to record.
    :param unit: Unit of measurement ("1" for counts, "s" for seconds).
    :param attributes: Extra attributes; agent_id is always included.
    :param metric_type: Optional explicit metric type; when None the
        telemetry sink infers one (optional for backwards compatibility).
    """
    if not self.telemetry_api:
        return

    span = get_current_span()
    if not span:
        return

    context = span.get_span_context()
    metric = MetricEvent(
        trace_id=format(context.trace_id, "x"),
        span_id=format(context.span_id, "x"),
        metric=metric_name,
        value=value,
        timestamp=time.time(),
        unit=unit,
        attributes={"agent_id": self.agent_id, **(attributes or {})},
        metric_type=metric_type,
    )

    # asyncio keeps only a weak reference to tasks, so an unreferenced
    # fire-and-forget task can be garbage-collected before it runs. Hold
    # a strong reference until completion, then drop it (per the
    # asyncio.create_task documentation). Lazily create the holder so
    # this helper does not depend on __init__ changes.
    tasks: set[asyncio.Task] = getattr(self, "_metric_tasks", None) or set()
    self._metric_tasks = tasks

    # Name the task for better debugging and potential cleanup.
    task_name = f"metric-{metric_name}-{self.agent_id}"
    task = asyncio.create_task(self.telemetry_api.log_event(metric), name=task_name)
    tasks.add(task)
    task.add_done_callback(tasks.discard)
def _track_step(self) -> None:
    """Count one executed agent step (llama_stack_agent_steps_total)."""
    self._emit_metric("llama_stack_agent_steps_total", 1, "1")
def _track_workflow(self, status: str, duration: float) -> None:
    """Record one finished agent workflow.

    Emits a completion counter tagged with the terminal status, plus the
    wall-clock duration in seconds.

    :param status: Terminal status label for the workflow — presumably
        values like "completed"/"failed"; confirm against callers.
    :param duration: Workflow duration in seconds.
    """
    self._emit_metric("llama_stack_agent_workflows_total", 1, "1", {"status": status})
    self._emit_metric("llama_stack_agent_workflow_duration_seconds", duration, "s")
def _track_tool(self, tool_name: str) -> None:
    """Count one tool invocation (llama_stack_agent_tool_calls_total).

    The built-in knowledge_search tool is reported under the label "rag";
    every other tool is reported under its own name.
    """
    if tool_name == "knowledge_search":
        tool_label = "rag"
    else:
        tool_label = tool_name
    self._emit_metric("llama_stack_agent_tool_calls_total", 1, "1", {"tool": tool_label})
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
messages = []
if self.agent_config.instructions != "":
@ -726,6 +769,9 @@ class ChatAgent(ShieldRunnerMixin):
)
)
# Track step execution metric
self._track_step()
# Add the result message to input_messages for the next iteration
input_messages.append(result_message)
@ -900,6 +946,7 @@ class ChatAgent(ShieldRunnerMixin):
},
)
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
return result

View file

@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule
@ -64,6 +65,7 @@ class MetaReferenceAgentsImpl(Agents):
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
policy: list[AccessRule],
telemetry_api: Telemetry | None = None,
):
self.config = config
self.inference_api = inference_api
@ -71,6 +73,7 @@ class MetaReferenceAgentsImpl(Agents):
self.safety_api = safety_api
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.telemetry_api = telemetry_api
self.in_memory_store = InmemoryKVStoreImpl()
self.openai_responses_impl: OpenAIResponsesImpl | None = None
@ -130,6 +133,7 @@ class MetaReferenceAgentsImpl(Agents):
vector_io_api=self.vector_io_api,
tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api,
telemetry_api=self.telemetry_api,
persistence_store=(
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
),

View file

@ -24,6 +24,7 @@ from llama_stack.apis.telemetry import (
MetricEvent,
MetricLabelMatcher,
MetricQueryType,
MetricType,
QueryCondition,
QueryMetricsResponse,
QuerySpanTreeResponse,
@ -56,6 +57,7 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
"counters": {},
"gauges": {},
"up_down_counters": {},
"histograms": {},
}
_global_lock = threading.Lock()
_TRACER_PROVIDER = None
@ -258,12 +260,20 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
# Log to OpenTelemetry meter if available
if self.meter is None:
return
if isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)
elif isinstance(event.value, float):
if event.metric_type == MetricType.HISTOGRAM:
histogram = self._get_or_create_histogram(
event.metric,
event.unit,
[0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 25.0, 50.0, 100.0],
)
histogram.record(event.value, attributes=event.attributes)
elif event.metric_type == MetricType.UP_DOWN_COUNTER:
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
up_down_counter.add(event.value, attributes=event.attributes)
else:
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
assert self.meter is not None
@ -275,6 +285,16 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
)
return _GLOBAL_STORAGE["up_down_counters"][name]
def _get_or_create_histogram(self, name: str, unit: str, buckets: list[float] | None = None) -> metrics.Histogram:
    """Return a process-wide histogram instrument, creating it on first use.

    Instruments are cached in _GLOBAL_STORAGE keyed by name, so repeated
    metric events reuse one instrument instead of re-registering it with
    the meter.

    :param name: Metric name, also used as the cache key.
    :param unit: Unit of measurement for recorded values.
    :param buckets: Optional explicit bucket boundaries. Previously this
        argument was accepted but silently ignored; it is now forwarded to
        the SDK as an advisory, with a fallback for SDKs that predate it.
    """
    assert self.meter is not None
    if name not in _GLOBAL_STORAGE["histograms"]:
        create_kwargs = {
            "name": name,
            "unit": unit,
            "description": f"Histogram for {name}",
        }
        if buckets is not None:
            # Bucket-boundary advisories were added in opentelemetry 1.23;
            # older SDKs reject the keyword, so fall back without it.
            create_kwargs["explicit_bucket_boundaries_advisory"] = buckets
        try:
            histogram = self.meter.create_histogram(**create_kwargs)
        except TypeError:
            create_kwargs.pop("explicit_bucket_boundaries_advisory", None)
            histogram = self.meter.create_histogram(**create_kwargs)
        _GLOBAL_STORAGE["histograms"][name] = histogram
    return _GLOBAL_STORAGE["histograms"][name]
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
span_id = int(event.span_id, 16)

View file

@ -35,6 +35,7 @@ def available_providers() -> list[ProviderSpec]:
Api.vector_dbs,
Api.tool_runtime,
Api.tool_groups,
Api.telemetry,
],
description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
),