This commit is contained in:
Sumanth Kamenani 2025-09-24 13:24:23 -04:00 committed by GitHub
commit b3271c6c9e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 905 additions and 17 deletions

View file

@ -90,6 +90,21 @@ class EventType(Enum):
METRIC = "metric"
@json_schema_type
class MetricType(Enum):
    """Kinds of metrics that can be recorded.

    :cvar COUNTER: Monotonically increasing count (e.g., requests_total)
    :cvar UP_DOWN_COUNTER: Count that may increase or decrease (e.g., active_connections)
    :cvar HISTOGRAM: Distribution of observed values (e.g., request_duration_seconds)
    :cvar GAUGE: Point-in-time measurement (e.g., cpu_usage_percent)
    """

    COUNTER = "counter"
    UP_DOWN_COUNTER = "up_down_counter"
    HISTOGRAM = "histogram"
    GAUGE = "gauge"
@json_schema_type
class LogSeverity(Enum):
"""The severity level of a log message.
@ -143,12 +158,14 @@ class MetricEvent(EventCommon):
:param metric: The name of the metric being measured
:param value: The numeric value of the metric measurement
:param unit: The unit of measurement for the metric value
:param metric_type: The type of metric (optional, inferred if not provided for backwards compatibility)
"""
type: Literal[EventType.METRIC] = EventType.METRIC
metric: str # this would be an enum
value: int | float
unit: str
metric_type: MetricType | None = None
@json_schema_type

View file

@ -4,17 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import copy
import json
import re
import secrets
import string
import time
import uuid
import warnings
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
import httpx
from opentelemetry.trace import get_current_span
from llama_stack.apis.agents import (
AgentConfig,
@ -60,6 +63,7 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.telemetry import MetricEvent, MetricType, Telemetry
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule
@ -97,6 +101,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
vector_io_api: VectorIO,
telemetry_api: Telemetry | None,
persistence_store: KVStore,
created_at: str,
policy: list[AccessRule],
@ -106,6 +111,7 @@ class ChatAgent(ShieldRunnerMixin):
self.inference_api = inference_api
self.safety_api = safety_api
self.vector_io_api = vector_io_api
self.telemetry_api = telemetry_api
self.storage = AgentPersistence(agent_id, persistence_store, policy)
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
@ -118,6 +124,9 @@ class ChatAgent(ShieldRunnerMixin):
output_shields=agent_config.output_shields,
)
# Initialize workflow start time to None
self._workflow_start_time: float | None = None
def turn_to_messages(self, turn: Turn) -> list[Message]:
messages = []
@ -167,6 +176,72 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
def _emit_metric(
    self,
    metric_name: str,
    value: int | float,
    unit: str,
    attributes: dict[str, str] | None = None,
    metric_type: MetricType | None = None,
) -> None:
    """Emit a single metric event via the telemetry API as a fire-and-forget task.

    :param metric_name: Name of the metric (e.g., llama_stack_agent_steps_total)
    :param value: Numeric measurement to record
    :param unit: Unit of measurement (e.g., "1", "s")
    :param attributes: Extra attributes merged on top of the agent_id attribute
    :param metric_type: Optional MetricType hint passed through to the sink
    """
    if not self.telemetry_api:
        # Telemetry disabled is an expected no-op, not a warning condition.
        logger.debug("No telemetry_api available for metric %s", metric_name)
        return

    span = get_current_span()
    if span is None:
        logger.debug("No current span available for metric %s", metric_name)
        return

    context = span.get_span_context()
    # get_current_span() returns a truthy non-recording INVALID_SPAN when no
    # span is active, so a plain `if not span:` check never fires; validate the
    # span context instead to avoid emitting metrics with trace_id/span_id 0.
    if not context.is_valid:
        logger.debug("No valid span context available for metric %s", metric_name)
        return

    metric = MetricEvent(
        trace_id=format(context.trace_id, "x"),
        span_id=format(context.span_id, "x"),
        metric=metric_name,
        value=value,
        timestamp=time.time(),
        unit=unit,
        attributes={"agent_id": self.agent_id, **(attributes or {})},
        metric_type=metric_type,
    )

    # Name the task for better debugging and capture any async errors.
    task_name = f"metric-{metric_name}-{self.agent_id}"
    task = asyncio.create_task(self.telemetry_api.log_event(metric), name=task_name)

    def _on_metric_task_done(t: asyncio.Task) -> None:
        try:
            exc = t.exception()
        except asyncio.CancelledError:
            logger.debug("Metric task %s was cancelled", task_name)
            return
        if exc is not None:
            logger.warning("Metric task %s failed: %s", task_name, exc)

    # Only add callback if task creation succeeded (not None from mocking).
    if task is not None:
        task.add_done_callback(_on_metric_task_done)
def _track_step(self):
    """Record that one agent step has executed."""
    logger.info("_track_step called")
    self._emit_metric(
        metric_name="llama_stack_agent_steps_total",
        value=1,
        unit="1",
        metric_type=MetricType.COUNTER,
    )
def _track_workflow(self, status: str, duration: float):
    """Record a completed workflow: one counter increment plus its duration."""
    logger.info(f"_track_workflow called: status={status}, duration={duration:.2f}s")
    self._emit_metric(
        metric_name="llama_stack_agent_workflows_total",
        value=1,
        unit="1",
        attributes={"status": status},
        metric_type=MetricType.COUNTER,
    )
    self._emit_metric(
        metric_name="llama_stack_agent_workflow_duration_seconds",
        value=duration,
        unit="s",
        metric_type=MetricType.HISTOGRAM,
    )
def _track_tool(self, tool_name: str):
    """Record a tool invocation, reporting knowledge_search under the name "rag"."""
    logger.info(f"_track_tool called: {tool_name}")
    if tool_name == "knowledge_search":
        normalized_name = "rag"
    else:
        normalized_name = tool_name
    self._emit_metric(
        metric_name="llama_stack_agent_tool_calls_total",
        value=1,
        unit="1",
        attributes={"tool": normalized_name},
        metric_type=MetricType.COUNTER,
    )
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
messages = []
if self.agent_config.instructions != "":
@ -201,6 +276,9 @@ class ChatAgent(ShieldRunnerMixin):
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
# Set workflow start time for resume operations
self._workflow_start_time = time.time()
await self._initialize_tools()
async for chunk in self._run_turn(request):
yield chunk
@ -212,6 +290,9 @@ class ChatAgent(ShieldRunnerMixin):
) -> AsyncGenerator:
assert request.stream is True, "Non-streaming not supported"
# Track workflow start time for metrics
self._workflow_start_time = time.time()
is_resume = isinstance(request, AgentTurnResumeRequest)
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
@ -313,6 +394,10 @@ class ChatAgent(ShieldRunnerMixin):
)
)
else:
# Track workflow completion when turn is actually complete
workflow_duration = time.time() - (self._workflow_start_time or time.time())
self._track_workflow("completed", workflow_duration)
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
@ -726,6 +811,10 @@ class ChatAgent(ShieldRunnerMixin):
)
)
# Track step execution metric
self._track_step()
self._track_tool(tool_call.tool_name)
# Add the result message to input_messages for the next iteration
input_messages.append(result_message)
@ -900,6 +989,7 @@ class ChatAgent(ShieldRunnerMixin):
},
)
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
return result

View file

@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule
@ -64,6 +65,7 @@ class MetaReferenceAgentsImpl(Agents):
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
policy: list[AccessRule],
telemetry_api: Telemetry | None = None,
):
self.config = config
self.inference_api = inference_api
@ -71,6 +73,7 @@ class MetaReferenceAgentsImpl(Agents):
self.safety_api = safety_api
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.telemetry_api = telemetry_api
self.in_memory_store = InmemoryKVStoreImpl()
self.openai_responses_impl: OpenAIResponsesImpl | None = None
@ -130,6 +133,7 @@ class MetaReferenceAgentsImpl(Agents):
vector_io_api=self.vector_io_api,
tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api,
telemetry_api=self.telemetry_api,
persistence_store=(
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
),

View file

@ -12,7 +12,9 @@ from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics._internal.aggregation import ExplicitBucketHistogramAggregation
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.metrics.view import View
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
@ -24,6 +26,7 @@ from llama_stack.apis.telemetry import (
MetricEvent,
MetricLabelMatcher,
MetricQueryType,
MetricType,
QueryCondition,
QueryMetricsResponse,
QuerySpanTreeResponse,
@ -56,6 +59,7 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
"counters": {},
"gauges": {},
"up_down_counters": {},
"histograms": {},
}
_global_lock = threading.Lock()
_TRACER_PROVIDER = None
@ -108,7 +112,17 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
if TelemetrySink.OTEL_METRIC in self.config.sinks:
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
# decent default buckets for agent workflow timings
hist_buckets = [0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 25.0, 50.0, 100.0]
views = [
View(
instrument_type=metrics.Histogram,
aggregation=ExplicitBucketHistogramAggregation(boundaries=hist_buckets),
)
]
metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader], views=views)
metrics.set_meter_provider(metric_provider)
if TelemetrySink.SQLITE in self.config.sinks:
@ -138,8 +152,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
self._log_metric(event)
elif isinstance(event, StructuredLogEvent):
self._log_structured(event, ttl_seconds)
else:
raise ValueError(f"Unknown event type: {event}")
async def query_metrics(
self,
@ -209,7 +221,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
name=name,
unit=unit,
description=f"Counter for {name}",
description=name.replace("_", " "),
)
return _GLOBAL_STORAGE["counters"][name]
@ -219,7 +231,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
name=name,
unit=unit,
description=f"Gauge for {name}",
description=name.replace("_", " "),
)
return _GLOBAL_STORAGE["gauges"][name]
@ -258,12 +270,19 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
# Log to OpenTelemetry meter if available
if self.meter is None:
return
if isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)
elif isinstance(event.value, float):
if event.metric_type == MetricType.HISTOGRAM:
histogram = self._get_or_create_histogram(
event.metric,
event.unit,
)
histogram.record(event.value, attributes=event.attributes)
elif event.metric_type == MetricType.UP_DOWN_COUNTER:
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
up_down_counter.add(event.value, attributes=event.attributes)
else:
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
assert self.meter is not None
@ -271,10 +290,20 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
_GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
name=name,
unit=unit,
description=f"UpDownCounter for {name}",
description=name.replace("_", " "),
)
return _GLOBAL_STORAGE["up_down_counters"][name]
def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
    """Return the cached histogram instrument for *name*, creating it on first use."""
    assert self.meter is not None
    histograms = _GLOBAL_STORAGE["histograms"]
    if name not in histograms:
        # Human-readable description derived from the metric name.
        histograms[name] = self.meter.create_histogram(
            name=name,
            unit=unit,
            description=name.replace("_", " "),
        )
    return histograms[name]
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
span_id = int(event.span_id, 16)

View file

@ -35,6 +35,7 @@ def available_providers() -> list[ProviderSpec]:
Api.vector_dbs,
Api.tool_runtime,
Api.tool_groups,
Api.telemetry,
],
description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
),