Sumanth Kamenani 2025-09-24 13:24:23 -04:00 committed by GitHub
commit b3271c6c9e
14 changed files with 905 additions and 17 deletions


@@ -13616,6 +13616,10 @@
        "unit": {
          "type": "string",
          "description": "The unit of measurement for the metric value"
+        },
+        "metric_type": {
+          "$ref": "#/components/schemas/MetricType",
+          "description": "The type of metric (optional, inferred if not provided for backwards compatibility)"
        }
      },
      "additionalProperties": false,
@@ -13631,6 +13635,17 @@
      "title": "MetricEvent",
      "description": "A metric event containing a measured value."
    },
+    "MetricType": {
+      "type": "string",
+      "enum": [
+        "counter",
+        "up_down_counter",
+        "histogram",
+        "gauge"
+      ],
+      "title": "MetricType",
+      "description": "The type of metric being recorded."
+    },
    "SpanEndPayload": {
      "type": "object",
      "properties": {


@@ -10122,6 +10122,10 @@ components:
          type: string
          description: >-
            The unit of measurement for the metric value
+        metric_type:
+          $ref: '#/components/schemas/MetricType'
+          description: >-
+            The type of metric (optional, inferred if not provided for backwards compatibility)
      additionalProperties: false
      required:
        - trace_id
@@ -10134,6 +10138,15 @@ components:
      title: MetricEvent
      description: >-
        A metric event containing a measured value.
+    MetricType:
+      type: string
+      enum:
+        - counter
+        - up_down_counter
+        - histogram
+        - gauge
+      title: MetricType
+      description: The type of metric being recorded.
    SpanEndPayload:
      type: object
      properties:
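
For reference, a metric event matching the updated schema might look like the dict below. This is an illustrative sketch only: every value is made up, `metric_type` may be omitted entirely (it is then inferred, per the description above), and the timestamp is shown as the Unix float that the agent code later in this commit passes via time.time().

    # Illustrative MetricEvent payload; all values are placeholders.
    example_metric_event = {
        "type": "metric",
        "trace_id": "4bf92f3577b34da6a3ce929d0e0e4736",
        "span_id": "00f067aa0ba902b7",
        "timestamp": 1758734663.0,
        "metric": "llama_stack_agent_workflow_duration_seconds",
        "value": 2.5,
        "unit": "s",
        "attributes": {"agent_id": "agent-123"},
        "metric_type": "histogram",
    }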


@@ -90,6 +90,21 @@ class EventType(Enum):
     METRIC = "metric"

+
+@json_schema_type
+class MetricType(Enum):
+    """The type of metric being recorded.
+
+    :cvar COUNTER: A counter metric that only increases (e.g., requests_total)
+    :cvar UP_DOWN_COUNTER: A counter that can increase or decrease (e.g., active_connections)
+    :cvar HISTOGRAM: A histogram metric for measuring distributions (e.g., request_duration_seconds)
+    :cvar GAUGE: A gauge metric for point-in-time values (e.g., cpu_usage_percent)
+    """
+
+    COUNTER = "counter"
+    UP_DOWN_COUNTER = "up_down_counter"
+    HISTOGRAM = "histogram"
+    GAUGE = "gauge"
+

 @json_schema_type
 class LogSeverity(Enum):
     """The severity level of a log message.
@@ -143,12 +158,14 @@ class MetricEvent(EventCommon):
     :param metric: The name of the metric being measured
     :param value: The numeric value of the metric measurement
     :param unit: The unit of measurement for the metric value
+    :param metric_type: The type of metric (optional, inferred if not provided for backwards compatibility)
     """

     type: Literal[EventType.METRIC] = EventType.METRIC
     metric: str  # this would be an enum
     value: int | float
     unit: str
+    metric_type: MetricType | None = None


 @json_schema_type
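
As a usage sketch, a caller would construct one of these events as follows. This mirrors the construction that `_emit_metric` performs later in this commit; the trace/span IDs and attribute values here are placeholders.

    import time

    from llama_stack.apis.telemetry import MetricEvent, MetricType

    # metric_type is optional; when omitted, the telemetry sink falls back to
    # its legacy behavior for backwards compatibility.
    event = MetricEvent(
        trace_id="4bf92f3577b34da6a3ce929d0e0e4736",  # placeholder hex trace id
        span_id="00f067aa0ba902b7",  # placeholder hex span id
        metric="llama_stack_agent_steps_total",
        value=1,
        timestamp=time.time(),
        unit="1",
        attributes={"agent_id": "agent-123"},
        metric_type=MetricType.COUNTER,
    )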


@@ -4,17 +4,20 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import asyncio
 import copy
 import json
 import re
 import secrets
 import string
+import time
 import uuid
 import warnings
 from collections.abc import AsyncGenerator
 from datetime import UTC, datetime

 import httpx
+from opentelemetry.trace import get_current_span

 from llama_stack.apis.agents import (
     AgentConfig,
@@ -60,6 +63,7 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.apis.safety import Safety
+from llama_stack.apis.telemetry import MetricEvent, MetricType, Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.datatypes import AccessRule
@@ -97,6 +101,7 @@ class ChatAgent(ShieldRunnerMixin):
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
         vector_io_api: VectorIO,
+        telemetry_api: Telemetry | None,
         persistence_store: KVStore,
         created_at: str,
         policy: list[AccessRule],
@@ -106,6 +111,7 @@ class ChatAgent(ShieldRunnerMixin):
         self.inference_api = inference_api
         self.safety_api = safety_api
         self.vector_io_api = vector_io_api
+        self.telemetry_api = telemetry_api
         self.storage = AgentPersistence(agent_id, persistence_store, policy)
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
@@ -118,6 +124,9 @@ class ChatAgent(ShieldRunnerMixin):
             output_shields=agent_config.output_shields,
         )

+        # Initialize workflow start time to None
+        self._workflow_start_time: float | None = None
+
     def turn_to_messages(self, turn: Turn) -> list[Message]:
         messages = []
@@ -167,6 +176,72 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_session(self, name: str) -> str:
         return await self.storage.create_session(name)

+    def _emit_metric(
+        self,
+        metric_name: str,
+        value: int | float,
+        unit: str,
+        attributes: dict[str, str] | None = None,
+        metric_type: MetricType | None = None,
+    ) -> None:
+        """Emit a single metric event"""
+        logger.info(f"_emit_metric called: {metric_name} = {value} {unit}")
+        if not self.telemetry_api:
+            logger.warning(f"No telemetry_api available for metric {metric_name}")
+            return
+
+        span = get_current_span()
+        if not span:
+            logger.warning(f"No current span available for metric {metric_name}")
+            return
+
+        context = span.get_span_context()
+        metric = MetricEvent(
+            trace_id=format(context.trace_id, "x"),
+            span_id=format(context.span_id, "x"),
+            metric=metric_name,
+            value=value,
+            timestamp=time.time(),
+            unit=unit,
+            attributes={"agent_id": self.agent_id, **(attributes or {})},
+            metric_type=metric_type,
+        )
+
+        # Create task with name for better debugging and capture any async errors
+        task_name = f"metric-{metric_name}-{self.agent_id}"
+        logger.info(f"Creating telemetry task: {task_name}")
+        task = asyncio.create_task(self.telemetry_api.log_event(metric), name=task_name)
+
+        def _on_metric_task_done(t: asyncio.Task) -> None:
+            try:
+                exc = t.exception()
+            except asyncio.CancelledError:
+                logger.debug("Metric task %s was cancelled", task_name)
+                return
+            if exc is not None:
+                logger.warning("Metric task %s failed: %s", task_name, exc)
+
+        # Only add callback if task creation succeeded (not None from mocking)
+        if task is not None:
+            task.add_done_callback(_on_metric_task_done)
+
+    def _track_step(self):
+        logger.info("_track_step called")
+        self._emit_metric("llama_stack_agent_steps_total", 1, "1", metric_type=MetricType.COUNTER)
+
+    def _track_workflow(self, status: str, duration: float):
+        logger.info(f"_track_workflow called: status={status}, duration={duration:.2f}s")
+        self._emit_metric("llama_stack_agent_workflows_total", 1, "1", {"status": status}, MetricType.COUNTER)
+        self._emit_metric(
+            "llama_stack_agent_workflow_duration_seconds", duration, "s", metric_type=MetricType.HISTOGRAM
+        )
+
+    def _track_tool(self, tool_name: str):
+        logger.info(f"_track_tool called: {tool_name}")
+        normalized_name = "rag" if tool_name == "knowledge_search" else tool_name
+        self._emit_metric("llama_stack_agent_tool_calls_total", 1, "1", {"tool": normalized_name}, MetricType.COUNTER)
+
     async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
         messages = []
         if self.agent_config.instructions != "":
@@ -201,6 +276,9 @@ class ChatAgent(ShieldRunnerMixin):
             if self.agent_config.name:
                 span.set_attribute("agent_name", self.agent_config.name)

+            # Set workflow start time for resume operations
+            self._workflow_start_time = time.time()
+
             await self._initialize_tools()
             async for chunk in self._run_turn(request):
                 yield chunk
@@ -212,6 +290,9 @@ class ChatAgent(ShieldRunnerMixin):
     ) -> AsyncGenerator:
         assert request.stream is True, "Non-streaming not supported"

+        # Track workflow start time for metrics
+        self._workflow_start_time = time.time()
+
         is_resume = isinstance(request, AgentTurnResumeRequest)
         session_info = await self.storage.get_session_info(request.session_id)
         if session_info is None:
@@ -313,6 +394,10 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             else:
+                # Track workflow completion when turn is actually complete
+                workflow_duration = time.time() - (self._workflow_start_time or time.time())
+                self._track_workflow("completed", workflow_duration)
+
                 chunk = AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseTurnCompletePayload(
@@ -726,6 +811,10 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )

+                # Track step execution metric
+                self._track_step()
+                self._track_tool(tool_call.tool_name)
+
                 # Add the result message to input_messages for the next iteration
                 input_messages.append(result_message)

@@ -900,6 +989,7 @@ class ChatAgent(ShieldRunnerMixin):
                 },
             )
             logger.debug(f"tool call {tool_name_str} completed with result: {result}")
+
             return result
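
Note that `_emit_metric` above never awaits the telemetry call: it schedules it as a named task and attaches a done callback so failures are logged instead of silently dropped. A minimal standalone sketch of that fire-and-forget pattern (the `log_event` coroutine here is a stand-in, not the real Telemetry API):

    import asyncio
    import logging

    logger = logging.getLogger(__name__)


    async def log_event(metric: dict) -> None:
        """Stand-in for an async telemetry sink such as Telemetry.log_event."""
        ...


    def emit_fire_and_forget(metric: dict) -> None:
        # Schedule without awaiting so the hot path never blocks on telemetry.
        task = asyncio.create_task(log_event(metric), name=f"metric-{metric.get('name')}")

        def _on_done(t: asyncio.Task) -> None:
            try:
                exc = t.exception()
            except asyncio.CancelledError:
                return  # cancellation is not an error worth logging
            if exc is not None:
                logger.warning("metric task %s failed: %s", t.get_name(), exc)

        # Without this callback, an exception inside log_event would only surface
        # as an "exception was never retrieved" warning at garbage collection.
        task.add_done_callback(_on_done)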


@@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.apis.safety import Safety
+from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.datatypes import AccessRule
@@ -64,6 +65,7 @@ class MetaReferenceAgentsImpl(Agents):
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
         policy: list[AccessRule],
+        telemetry_api: Telemetry | None = None,
     ):
         self.config = config
         self.inference_api = inference_api
@@ -71,6 +73,7 @@ class MetaReferenceAgentsImpl(Agents):
         self.safety_api = safety_api
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
+        self.telemetry_api = telemetry_api

         self.in_memory_store = InmemoryKVStoreImpl()
         self.openai_responses_impl: OpenAIResponsesImpl | None = None
@@ -130,6 +133,7 @@ class MetaReferenceAgentsImpl(Agents):
             vector_io_api=self.vector_io_api,
             tool_runtime_api=self.tool_runtime_api,
             tool_groups_api=self.tool_groups_api,
+            telemetry_api=self.telemetry_api,
             persistence_store=(
                 self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
             ),


@@ -12,7 +12,9 @@ from opentelemetry import metrics, trace
 from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
 from opentelemetry.sdk.metrics import MeterProvider
+from opentelemetry.sdk.metrics._internal.aggregation import ExplicitBucketHistogramAggregation
 from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
+from opentelemetry.sdk.metrics.view import View
 from opentelemetry.sdk.resources import Resource
 from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import BatchSpanProcessor
@@ -24,6 +26,7 @@ from llama_stack.apis.telemetry import (
     MetricEvent,
     MetricLabelMatcher,
     MetricQueryType,
+    MetricType,
     QueryCondition,
     QueryMetricsResponse,
     QuerySpanTreeResponse,
@@ -56,6 +59,7 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
     "counters": {},
     "gauges": {},
     "up_down_counters": {},
+    "histograms": {},
 }
 _global_lock = threading.Lock()
 _TRACER_PROVIDER = None
@@ -108,7 +112,17 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
         if TelemetrySink.OTEL_METRIC in self.config.sinks:
             metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
-            metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
+
+            # decent default buckets for agent workflow timings
+            hist_buckets = [0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 25.0, 50.0, 100.0]
+            views = [
+                View(
+                    instrument_type=metrics.Histogram,
+                    aggregation=ExplicitBucketHistogramAggregation(boundaries=hist_buckets),
+                )
+            ]
+
+            metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader], views=views)
             metrics.set_meter_provider(metric_provider)

         if TelemetrySink.SQLITE in self.config.sinks:
@@ -138,8 +152,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             self._log_metric(event)
         elif isinstance(event, StructuredLogEvent):
             self._log_structured(event, ttl_seconds)
-        else:
-            raise ValueError(f"Unknown event type: {event}")

     async def query_metrics(
         self,
@@ -209,7 +221,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             _GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
                 name=name,
                 unit=unit,
-                description=f"Counter for {name}",
+                description=name.replace("_", " "),
             )
         return _GLOBAL_STORAGE["counters"][name]
@@ -219,7 +231,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
                 name=name,
                 unit=unit,
-                description=f"Gauge for {name}",
+                description=name.replace("_", " "),
             )
         return _GLOBAL_STORAGE["gauges"][name]
@@ -258,12 +270,19 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
         # Log to OpenTelemetry meter if available
         if self.meter is None:
             return

-        if isinstance(event.value, int):
-            counter = self._get_or_create_counter(event.metric, event.unit)
-            counter.add(event.value, attributes=event.attributes)
-        elif isinstance(event.value, float):
+        if event.metric_type == MetricType.HISTOGRAM:
+            histogram = self._get_or_create_histogram(
+                event.metric,
+                event.unit,
+            )
+            histogram.record(event.value, attributes=event.attributes)
+        elif event.metric_type == MetricType.UP_DOWN_COUNTER:
             up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
             up_down_counter.add(event.value, attributes=event.attributes)
+        else:
+            counter = self._get_or_create_counter(event.metric, event.unit)
+            counter.add(event.value, attributes=event.attributes)

     def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
         assert self.meter is not None
@@ -271,10 +290,20 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             _GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
                 name=name,
                 unit=unit,
-                description=f"UpDownCounter for {name}",
+                description=name.replace("_", " "),
             )
         return _GLOBAL_STORAGE["up_down_counters"][name]

+    def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
+        assert self.meter is not None
+        if name not in _GLOBAL_STORAGE["histograms"]:
+            _GLOBAL_STORAGE["histograms"][name] = self.meter.create_histogram(
+                name=name,
+                unit=unit,
+                description=name.replace("_", " "),
+            )
+        return _GLOBAL_STORAGE["histograms"][name]
+
     def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
         with self._lock:
             span_id = int(event.span_id, 16)
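
For context on the View added above: the OpenTelemetry SDK applies the view's aggregation to every matching instrument at export time, so `create_histogram` itself never sees the bucket boundaries (which is why the tests later in this commit assert it is called without them). A self-contained sketch of that mechanism, using the console exporter instead of the adapter's OTLP exporter; instrument names and values here are illustrative:

    from opentelemetry import metrics
    from opentelemetry.sdk.metrics import MeterProvider
    from opentelemetry.sdk.metrics._internal.aggregation import ExplicitBucketHistogramAggregation
    from opentelemetry.sdk.metrics.export import ConsoleMetricExporter, PeriodicExportingMetricReader
    from opentelemetry.sdk.metrics.view import View

    # Route every Histogram instrument through explicit buckets, as the
    # OTEL_METRIC sink above does.
    reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
    provider = MeterProvider(
        metric_readers=[reader],
        views=[
            View(
                instrument_type=metrics.Histogram,
                aggregation=ExplicitBucketHistogramAggregation(boundaries=[0.1, 0.5, 1.0, 2.5, 5.0, 10.0]),
            )
        ],
    )

    meter = provider.get_meter("demo")
    hist = meter.create_histogram("workflow_duration_seconds", unit="s")  # no buckets here
    hist.record(2.2)  # aggregated into the (1.0, 2.5] bucket by the view
    provider.shutdown()  # flushes the reader so the bucket counts are exported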


@@ -35,6 +35,7 @@ def available_providers() -> list[ProviderSpec]:
             Api.vector_dbs,
             Api.tool_runtime,
             Api.tool_groups,
+            Api.telemetry,
         ],
         description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
     ),


@@ -25,8 +25,8 @@ classifiers = [
 ]
 dependencies = [
     "aiohttp",
     "fastapi>=0.115.0,<1.0",  # server
     "fire",  # for MCP in LLS client
     "httpx",
     "huggingface-hub>=0.34.0,<1.0",
     "jinja2>=3.1.6",
@@ -43,12 +43,12 @@ dependencies = [
     "tiktoken",
     "pillow",
     "h11>=0.16.0",
     "python-multipart>=0.0.20",  # For fastapi Form
     "uvicorn>=0.34.0",  # server
     "opentelemetry-sdk>=1.30.0",  # server
     "opentelemetry-exporter-otlp-proto-http>=1.30.0",  # server
     "aiosqlite>=0.21.0",  # server - for metadata store
     "asyncpg",  # for metadata store
 ]

 [project.optional-dependencies]


@@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncGenerator, Callable
from pathlib import Path
from typing import Any
from unittest.mock import Mock, patch

import pytest

from llama_stack.apis.inference import ToolDefinition
from llama_stack.apis.tools import ToolInvocationResult
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
from llama_stack.providers.inline.telemetry.meta_reference.config import (
    TelemetryConfig,
    TelemetrySink,
)
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
    TelemetryAdapter,
)
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl
from llama_stack.providers.utils.telemetry import tracing as telemetry_tracing


@pytest.fixture
def make_agent_fixture():
    def _make(telemetry, kvstore) -> ChatAgent:
        agent = ChatAgent(
            agent_id="test-agent",
            agent_config=Mock(),
            inference_api=Mock(),
            safety_api=Mock(),
            tool_runtime_api=Mock(),
            tool_groups_api=Mock(),
            vector_io_api=Mock(),
            telemetry_api=telemetry,
            persistence_store=kvstore,
            created_at="2025-01-01T00:00:00Z",
            policy=[],
        )
        agent.agent_config.client_tools = []
        agent.agent_config.max_infer_iters = 5
        agent.input_shields = []
        agent.output_shields = []
        agent.tool_defs = [
            ToolDefinition(tool_name="web_search", description="", parameters={}),
            ToolDefinition(tool_name="knowledge_search", description="", parameters={}),
        ]
        agent.tool_name_to_args = {}

        # Stub tool runtime invoke_tool
        async def _mock_invoke_tool(
            *args: Any,
            tool_name: str | None = None,
            kwargs: dict | None = None,
            **extra: Any,
        ):
            return ToolInvocationResult(content="Tool execution result")

        agent.tool_runtime_api.invoke_tool = _mock_invoke_tool
        return agent

    return _make


def _chat_stream(tool_name: str | None, content: str = ""):
    from llama_stack.apis.common.content_types import (
        TextDelta,
        ToolCallDelta,
        ToolCallParseStatus,
    )
    from llama_stack.apis.inference import (
        ChatCompletionResponseEvent,
        ChatCompletionResponseEventType,
        ChatCompletionResponseStreamChunk,
        StopReason,
    )
    from llama_stack.models.llama.datatypes import ToolCall

    async def gen():
        # Start
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.start,
                delta=TextDelta(text=""),
            )
        )
        # Content
        if content:
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.progress,
                    delta=TextDelta(text=content),
                )
            )
        # Tool call if specified
        if tool_name:
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.progress,
                    delta=ToolCallDelta(
                        tool_call=ToolCall(call_id="call_0", tool_name=tool_name, arguments={}),
                        parse_status=ToolCallParseStatus.succeeded,
                    ),
                )
            )
        # Complete
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.complete,
                delta=TextDelta(text=""),
                stop_reason=StopReason.end_of_turn,
            )
        )

    return gen()


@pytest.fixture
async def telemetry(tmp_path: Path) -> AsyncGenerator[TelemetryAdapter, None]:
    db_path = tmp_path / "trace_store.db"
    cfg = TelemetryConfig(
        sinks=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
        sqlite_db_path=str(db_path),
    )
    telemetry = TelemetryAdapter(cfg, deps={})
    telemetry_tracing.setup_logger(telemetry)
    try:
        yield telemetry
    finally:
        await telemetry.shutdown()


@pytest.fixture
async def kvstore(tmp_path: Path) -> SqliteKVStoreImpl:
    kv_path = tmp_path / "agent_kvstore.db"
    kv = SqliteKVStoreImpl(SqliteKVStoreConfig(db_path=str(kv_path)))
    await kv.initialize()
    return kv


@pytest.fixture
def span_patch():
    with (
        patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span") as mock_span,
        patch(
            "llama_stack.providers.utils.telemetry.tracing.generate_span_id",
            return_value="0000000000000abc",
        ),
    ):
        mock_span.return_value = Mock(get_span_context=Mock(return_value=Mock(trace_id=0x123, span_id=0xABC)))
        yield


@pytest.fixture
def make_completion_fn() -> Callable[[str | None, str], Callable]:
    def _factory(tool_name: str | None = None, content: str = "") -> Callable:
        async def chat_completion(*args: Any, **kwargs: Any):
            return _chat_stream(tool_name, content)

        return chat_completion

    return _factory


@@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from typing import Any

from llama_stack.providers.utils.telemetry import tracing as telemetry_tracing


class TestAgentMetricsIntegration:
    async def test_agent_metrics_end_to_end(
        self: Any,
        telemetry: Any,
        kvstore: Any,
        make_agent_fixture: Any,
        span_patch: Any,
        make_completion_fn: Any,
    ) -> None:
        from llama_stack.apis.inference import (
            SamplingParams,
            UserMessage,
        )

        agent: Any = make_agent_fixture(telemetry, kvstore)
        session_id = await agent.create_session("s")
        sampling_params = SamplingParams(max_tokens=64)

        # single trace: plain, knowledge_search, web_search
        await telemetry_tracing.start_trace("agent_metrics")

        agent.inference_api.chat_completion = make_completion_fn(None, "Hello! I can help you with that.")
        async for _ in agent.run(
            session_id,
            "t1",
            [UserMessage(content="Hello")],
            sampling_params,
            stream=True,
        ):
            pass

        agent.inference_api.chat_completion = make_completion_fn("knowledge_search", "")
        async for _ in agent.run(
            session_id,
            "t2",
            [UserMessage(content="Please search knowledge")],
            sampling_params,
            stream=True,
        ):
            pass

        agent.inference_api.chat_completion = make_completion_fn("web_search", "")
        async for _ in agent.run(
            session_id,
            "t3",
            [UserMessage(content="Please search web")],
            sampling_params,
            stream=True,
        ):
            pass

        await telemetry_tracing.end_trace()

        # Poll briefly to avoid flake with async persistence
        tool_labels: set[str] = set()
        for _ in range(10):
            resp = await telemetry.query_metrics("llama_stack_agent_tool_calls_total", start_time=0, end_time=None)
            tool_labels.clear()
            for series in getattr(resp, "data", []) or []:
                for lbl in getattr(series, "labels", []) or []:
                    name = getattr(lbl, "name", None) or getattr(lbl, "key", None)
                    value = getattr(lbl, "value", None)
                    if name == "tool" and value:
                        tool_labels.add(value)
            # Look for both web_search AND some form of knowledge search
            if ("web_search" in tool_labels) and ("rag" in tool_labels or "knowledge_search" in tool_labels):
                break
            await asyncio.sleep(0.1)

        # More descriptive assertion
        assert bool(tool_labels & {"web_search", "rag", "knowledge_search"}), (
            f"Expected tool calls not found. Got: {tool_labels}"
        )


@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -0,0 +1,212 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from unittest.mock import AsyncMock, Mock

import pytest
from opentelemetry.trace import SpanContext, TraceFlags

from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent


class FakeSpan:
    def __init__(self, trace_id: int = 123, span_id: int = 456):
        self._context = SpanContext(
            trace_id=trace_id,
            span_id=span_id,
            is_remote=False,
            trace_flags=TraceFlags(0x01),
        )

    def get_span_context(self):
        return self._context


@pytest.fixture
def agent_with_telemetry():
    """Create a real ChatAgent with telemetry API"""
    telemetry_api = AsyncMock()

    agent = ChatAgent(
        agent_id="test-agent",
        agent_config=Mock(),
        inference_api=Mock(),
        safety_api=Mock(),
        tool_runtime_api=Mock(),
        tool_groups_api=Mock(),
        vector_io_api=Mock(),
        telemetry_api=telemetry_api,
        persistence_store=Mock(),
        created_at="2025-01-01T00:00:00Z",
        policy=[],
    )
    return agent


@pytest.fixture
def agent_without_telemetry():
    """Create a real ChatAgent without telemetry API"""
    agent = ChatAgent(
        agent_id="test-agent",
        agent_config=Mock(),
        inference_api=Mock(),
        safety_api=Mock(),
        tool_runtime_api=Mock(),
        tool_groups_api=Mock(),
        vector_io_api=Mock(),
        telemetry_api=None,
        persistence_store=Mock(),
        created_at="2025-01-01T00:00:00Z",
        policy=[],
    )
    return agent


class TestAgentMetrics:
    def test_step_execution_metrics(self, agent_with_telemetry, monkeypatch):
        """Test that step execution metrics are emitted correctly"""
        fake_span = FakeSpan()
        monkeypatch.setattr(
            "llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span", lambda: fake_span
        )

        # Capture the metric instead of actually creating async task
        captured_metrics = []

        async def capture_metric(metric):
            captured_metrics.append(metric)

        monkeypatch.setattr(agent_with_telemetry.telemetry_api, "log_event", capture_metric)

        def mock_create_task(coro, *, name=None):
            return asyncio.run(coro)

        monkeypatch.setattr(
            "llama_stack.providers.inline.agents.meta_reference.agent_instance.asyncio.create_task", mock_create_task
        )

        agent_with_telemetry._track_step()

        assert len(captured_metrics) == 1
        metric = captured_metrics[0]
        assert metric.metric == "llama_stack_agent_steps_total"
        assert metric.value == 1
        assert metric.unit == "1"
        assert metric.attributes["agent_id"] == "test-agent"

    def test_workflow_completion_metrics(self, agent_with_telemetry, monkeypatch):
        """Test that workflow completion metrics are emitted correctly"""
        fake_span = FakeSpan()
        monkeypatch.setattr(
            "llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span", lambda: fake_span
        )

        captured_metrics = []

        async def capture_metric(metric):
            captured_metrics.append(metric)

        monkeypatch.setattr(agent_with_telemetry.telemetry_api, "log_event", capture_metric)

        def mock_create_task(coro, *, name=None):
            return asyncio.run(coro)

        monkeypatch.setattr(
            "llama_stack.providers.inline.agents.meta_reference.agent_instance.asyncio.create_task", mock_create_task
        )

        agent_with_telemetry._track_workflow("completed", 2.5)

        assert len(captured_metrics) == 2

        # Check workflow count metric
        count_metric = captured_metrics[0]
        assert count_metric.metric == "llama_stack_agent_workflows_total"
        assert count_metric.value == 1
        assert count_metric.attributes["status"] == "completed"

        # Check duration metric
        duration_metric = captured_metrics[1]
        assert duration_metric.metric == "llama_stack_agent_workflow_duration_seconds"
        assert duration_metric.value == 2.5
        assert duration_metric.unit == "s"

    def test_tool_usage_metrics(self, agent_with_telemetry, monkeypatch):
        """Test that tool usage metrics are emitted correctly"""
        fake_span = FakeSpan()
        monkeypatch.setattr(
            "llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span", lambda: fake_span
        )

        captured_metrics = []

        async def capture_metric(metric):
            captured_metrics.append(metric)

        monkeypatch.setattr(agent_with_telemetry.telemetry_api, "log_event", capture_metric)

        def mock_create_task(coro, *, name=None):
            return asyncio.run(coro)

        monkeypatch.setattr(
            "llama_stack.providers.inline.agents.meta_reference.agent_instance.asyncio.create_task", mock_create_task
        )

        agent_with_telemetry._track_tool("web_search")

        assert len(captured_metrics) == 1
        metric = captured_metrics[0]
        assert metric.metric == "llama_stack_agent_tool_calls_total"
        assert metric.attributes["tool"] == "web_search"

    def test_knowledge_search_tool_mapping(self, agent_with_telemetry, monkeypatch):
        """Test that knowledge_search tool is mapped to rag"""
        fake_span = FakeSpan()
        monkeypatch.setattr(
            "llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span", lambda: fake_span
        )

        captured_metrics = []

        async def capture_metric(metric):
            captured_metrics.append(metric)

        monkeypatch.setattr(agent_with_telemetry.telemetry_api, "log_event", capture_metric)

        def mock_create_task(coro, *, name=None):
            return asyncio.run(coro)

        monkeypatch.setattr(
            "llama_stack.providers.inline.agents.meta_reference.agent_instance.asyncio.create_task", mock_create_task
        )

        agent_with_telemetry._track_tool("knowledge_search")

        assert len(captured_metrics) == 1
        metric = captured_metrics[0]
        assert metric.attributes["tool"] == "rag"

    def test_no_telemetry_api(self, agent_without_telemetry):
        """Test that methods work gracefully when telemetry_api is None"""
        # These should not crash
        agent_without_telemetry._track_step()
        agent_without_telemetry._track_workflow("failed", 1.0)
        agent_without_telemetry._track_tool("web_search")

    def test_no_active_span(self, agent_with_telemetry, monkeypatch):
        """Test that methods work gracefully when no span is active"""
        monkeypatch.setattr(
            "llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span", lambda: None
        )

        # These should not crash and should not call telemetry
        agent_with_telemetry._track_step()
        agent_with_telemetry._track_workflow("failed", 1.0)
        agent_with_telemetry._track_tool("web_search")

        # Telemetry should not have been called
        agent_with_telemetry.telemetry_api.log_event.assert_not_called()


@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -0,0 +1,244 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import Mock

import pytest

from llama_stack.apis.telemetry import MetricEvent, MetricType
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import TelemetryAdapter


class TestAgentMetricsHistogram:
    """Tests for agent histogram metrics"""

    @pytest.fixture
    def config(self):
        return TelemetryConfig(service_name="test-service", sinks=[])

    @pytest.fixture
    def adapter(self, config):
        adapter = TelemetryAdapter(config, {})
        adapter.meter = Mock()  # skip otel setup
        return adapter

    def test_histogram_creation(self, adapter):
        mock_hist = Mock()
        adapter.meter.create_histogram.return_value = mock_hist

        from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE

        _GLOBAL_STORAGE["histograms"] = {}

        result = adapter._get_or_create_histogram("test_histogram", "s")

        assert result == mock_hist
        adapter.meter.create_histogram.assert_called_once_with(
            name="test_histogram",
            unit="s",
            description="test histogram",
        )
        assert _GLOBAL_STORAGE["histograms"]["test_histogram"] == mock_hist

    def test_histogram_reuse(self, adapter):
        mock_hist = Mock()

        from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE

        _GLOBAL_STORAGE["histograms"] = {"existing_histogram": mock_hist}

        result = adapter._get_or_create_histogram("existing_histogram", "ms")

        assert result == mock_hist
        adapter.meter.create_histogram.assert_not_called()

    def test_workflow_duration_histogram(self, adapter):
        mock_hist = Mock()
        adapter.meter.create_histogram.return_value = mock_hist

        from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE

        _GLOBAL_STORAGE["histograms"] = {}

        event = MetricEvent(
            trace_id="123",
            span_id="456",
            metric="llama_stack_agent_workflow_duration_seconds",
            value=15.7,
            timestamp=1234567890.0,
            unit="s",
            attributes={"agent_id": "test-agent"},
            metric_type=MetricType.HISTOGRAM,
        )

        adapter._log_metric(event)

        adapter.meter.create_histogram.assert_called_once_with(
            name="llama_stack_agent_workflow_duration_seconds",
            unit="s",
            description="llama stack agent workflow duration seconds",
        )
        mock_hist.record.assert_called_once_with(15.7, attributes={"agent_id": "test-agent"})

    def test_duration_buckets_configured_via_views(self, adapter):
        mock_hist = Mock()
        adapter.meter.create_histogram.return_value = mock_hist

        from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE

        _GLOBAL_STORAGE["histograms"] = {}

        event = MetricEvent(
            trace_id="123",
            span_id="456",
            metric="custom_duration_seconds",
            value=5.2,
            timestamp=1234567890.0,
            unit="s",
            attributes={},
            metric_type=MetricType.HISTOGRAM,
        )

        adapter._log_metric(event)

        # buckets configured via otel views, not passed to create_histogram
        mock_hist.record.assert_called_once_with(5.2, attributes={})

    def test_non_duration_uses_counter(self, adapter):
        mock_counter = Mock()
        adapter.meter.create_counter.return_value = mock_counter

        from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE

        _GLOBAL_STORAGE["counters"] = {}

        event = MetricEvent(
            trace_id="123",
            span_id="456",
            metric="llama_stack_agent_workflows_total",
            value=1,
            timestamp=1234567890.0,
            unit="1",
            attributes={"agent_id": "test-agent", "status": "completed"},
        )

        adapter._log_metric(event)

        adapter.meter.create_counter.assert_called_once()
        adapter.meter.create_histogram.assert_not_called()
        mock_counter.add.assert_called_once_with(1, attributes={"agent_id": "test-agent", "status": "completed"})

    def test_no_meter_doesnt_crash(self, adapter):
        adapter.meter = None

        event = MetricEvent(
            trace_id="123",
            span_id="456",
            metric="test_duration_seconds",
            value=1.0,
            timestamp=1234567890.0,
            unit="s",
            attributes={},
        )

        adapter._log_metric(event)  # shouldn't crash

    def test_histogram_vs_counter_by_type(self, adapter):
        mock_hist = Mock()
        mock_counter = Mock()
        adapter.meter.create_histogram.return_value = mock_hist
        adapter.meter.create_counter.return_value = mock_counter

        from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE

        _GLOBAL_STORAGE["histograms"] = {}
        _GLOBAL_STORAGE["counters"] = {}

        # histogram metric
        hist_event = MetricEvent(
            trace_id="123",
            span_id="456",
            metric="workflow_duration_seconds",
            value=1.0,
            timestamp=1234567890.0,
            unit="s",
            attributes={},
            metric_type=MetricType.HISTOGRAM,
        )
        adapter._log_metric(hist_event)
        mock_hist.record.assert_called()

        # counter metric (default type)
        counter_event = MetricEvent(
            trace_id="123",
            span_id="456",
            metric="workflow_total",
            value=1,
            timestamp=1234567890.0,
            unit="1",
            attributes={},
        )
        adapter._log_metric(counter_event)
        mock_counter.add.assert_called()

    def test_storage_separation(self, adapter):
        mock_hist = Mock()
        mock_counter = Mock()
        adapter.meter.create_histogram.return_value = mock_hist
        adapter.meter.create_counter.return_value = mock_counter

        from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE

        _GLOBAL_STORAGE["histograms"] = {}
        _GLOBAL_STORAGE["counters"] = {}

        # create both types
        hist_event = MetricEvent(
            trace_id="123",
            span_id="456",
            metric="test_duration_seconds",
            value=1.0,
            timestamp=1234567890.0,
            unit="s",
            attributes={},
            metric_type=MetricType.HISTOGRAM,
        )
        counter_event = MetricEvent(
            trace_id="123",
            span_id="456",
            metric="test_counter",
            value=1,
            timestamp=1234567890.0,
            unit="1",
            attributes={},
        )

        adapter._log_metric(hist_event)
        adapter._log_metric(counter_event)

        # check they're stored separately
        assert "test_duration_seconds" in _GLOBAL_STORAGE["histograms"]
        assert "test_counter" in _GLOBAL_STORAGE["counters"]
        assert "test_duration_seconds" not in _GLOBAL_STORAGE["counters"]
        assert "test_counter" not in _GLOBAL_STORAGE["histograms"]

    def test_histogram_uses_views_for_buckets(self, adapter):
        mock_hist = Mock()
        adapter.meter.create_histogram.return_value = mock_hist

        from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE

        _GLOBAL_STORAGE["histograms"] = {}

        result = adapter._get_or_create_histogram("test_histogram", "s")

        # buckets come from otel views, not create_histogram params
        adapter.meter.create_histogram.assert_called_once_with(
            name="test_histogram",
            unit="s",
            description="test histogram",
        )
        assert result == mock_hist