mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
Merge 9f14b67f5c
into 61582f327c
This commit is contained in:
commit
03d44ac2e2
12 changed files with 655 additions and 11 deletions
15
docs/_static/llama-stack-spec.html
vendored
15
docs/_static/llama-stack-spec.html
vendored
|
@ -13092,6 +13092,10 @@
|
|||
"unit": {
|
||||
"type": "string",
|
||||
"description": "The unit of measurement for the metric value"
|
||||
},
|
||||
"metric_type": {
|
||||
"$ref": "#/components/schemas/MetricType",
|
||||
"description": "The type of metric (optional, inferred if not provided for backwards compatibility)"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
@ -13107,6 +13111,17 @@
|
|||
"title": "MetricEvent",
|
||||
"description": "A metric event containing a measured value."
|
||||
},
|
||||
"MetricType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"counter",
|
||||
"up_down_counter",
|
||||
"histogram",
|
||||
"gauge"
|
||||
],
|
||||
"title": "MetricType",
|
||||
"description": "The type of metric being recorded."
|
||||
},
|
||||
"SpanEndPayload": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
13
docs/_static/llama-stack-spec.yaml
vendored
13
docs/_static/llama-stack-spec.yaml
vendored
|
@ -9736,6 +9736,10 @@ components:
|
|||
type: string
|
||||
description: >-
|
||||
The unit of measurement for the metric value
|
||||
metric_type:
|
||||
$ref: '#/components/schemas/MetricType'
|
||||
description: >-
|
||||
The type of metric (optional, inferred if not provided for backwards compatibility)
|
||||
additionalProperties: false
|
||||
required:
|
||||
- trace_id
|
||||
|
@ -9748,6 +9752,15 @@ components:
|
|||
title: MetricEvent
|
||||
description: >-
|
||||
A metric event containing a measured value.
|
||||
MetricType:
|
||||
type: string
|
||||
enum:
|
||||
- counter
|
||||
- up_down_counter
|
||||
- histogram
|
||||
- gauge
|
||||
title: MetricType
|
||||
description: The type of metric being recorded.
|
||||
SpanEndPayload:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
@ -90,6 +90,21 @@ class EventType(Enum):
|
|||
METRIC = "metric"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricType(Enum):
|
||||
"""The type of metric being recorded.
|
||||
:cvar COUNTER: A counter metric that only increases (e.g., requests_total)
|
||||
:cvar UP_DOWN_COUNTER: A counter that can increase or decrease (e.g., active_connections)
|
||||
:cvar HISTOGRAM: A histogram metric for measuring distributions (e.g., request_duration_seconds)
|
||||
:cvar GAUGE: A gauge metric for point-in-time values (e.g., cpu_usage_percent)
|
||||
"""
|
||||
|
||||
COUNTER = "counter"
|
||||
UP_DOWN_COUNTER = "up_down_counter"
|
||||
HISTOGRAM = "histogram"
|
||||
GAUGE = "gauge"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LogSeverity(Enum):
|
||||
"""The severity level of a log message.
|
||||
|
@ -143,12 +158,14 @@ class MetricEvent(EventCommon):
|
|||
:param metric: The name of the metric being measured
|
||||
:param value: The numeric value of the metric measurement
|
||||
:param unit: The unit of measurement for the metric value
|
||||
:param metric_type: The type of metric (optional, inferred if not provided for backwards compatibility)
|
||||
"""
|
||||
|
||||
type: Literal[EventType.METRIC] = EventType.METRIC
|
||||
metric: str # this would be an enum
|
||||
value: int | float
|
||||
unit: str
|
||||
metric_type: MetricType | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -4,17 +4,20 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
from opentelemetry.trace import get_current_span
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
|
@ -60,6 +63,7 @@ from llama_stack.apis.inference import (
|
|||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.telemetry import MetricEvent, Telemetry
|
||||
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
|
@ -97,6 +101,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_runtime_api: ToolRuntime,
|
||||
tool_groups_api: ToolGroups,
|
||||
vector_io_api: VectorIO,
|
||||
telemetry_api: Telemetry | None,
|
||||
persistence_store: KVStore,
|
||||
created_at: str,
|
||||
policy: list[AccessRule],
|
||||
|
@ -106,6 +111,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self.inference_api = inference_api
|
||||
self.safety_api = safety_api
|
||||
self.vector_io_api = vector_io_api
|
||||
self.telemetry_api = telemetry_api
|
||||
self.storage = AgentPersistence(agent_id, persistence_store, policy)
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
|
@ -167,6 +173,56 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def create_session(self, name: str) -> str:
|
||||
return await self.storage.create_session(name)
|
||||
|
||||
def _emit_metric(
|
||||
self, metric_name: str, value: int | float, unit: str, attributes: dict[str, str] | None = None
|
||||
) -> None:
|
||||
"""Emit a single metric event"""
|
||||
if not self.telemetry_api:
|
||||
return
|
||||
|
||||
span = get_current_span()
|
||||
if not span:
|
||||
return
|
||||
|
||||
context = span.get_span_context()
|
||||
metric = MetricEvent(
|
||||
trace_id=format(context.trace_id, "x"),
|
||||
span_id=format(context.span_id, "x"),
|
||||
metric=metric_name,
|
||||
value=value,
|
||||
timestamp=time.time(),
|
||||
unit=unit,
|
||||
attributes={"agent_id": self.agent_id, **(attributes or {})},
|
||||
)
|
||||
|
||||
# Create task with name for better debugging and capture any async errors
|
||||
task_name = f"metric-{metric_name}-{self.agent_id}"
|
||||
task = asyncio.create_task(self.telemetry_api.log_event(metric), name=task_name)
|
||||
|
||||
def _on_metric_task_done(t: asyncio.Task) -> None:
|
||||
try:
|
||||
exc = t.exception()
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Metric task %s was cancelled", task_name)
|
||||
return
|
||||
if exc is not None:
|
||||
logger.warning("Metric task %s failed: %s", task_name, exc)
|
||||
|
||||
# Only add callback if task creation succeeded (not None from mocking)
|
||||
if task is not None:
|
||||
task.add_done_callback(_on_metric_task_done)
|
||||
|
||||
def _track_step(self):
|
||||
self._emit_metric("llama_stack_agent_steps_total", 1, "1")
|
||||
|
||||
def _track_workflow(self, status: str, duration: float):
|
||||
self._emit_metric("llama_stack_agent_workflows_total", 1, "1", {"status": status})
|
||||
self._emit_metric("llama_stack_agent_workflow_duration_seconds", duration, "s")
|
||||
|
||||
def _track_tool(self, tool_name: str):
|
||||
normalized_name = "rag" if tool_name == "knowledge_search" else tool_name
|
||||
self._emit_metric("llama_stack_agent_tool_calls_total", 1, "1", {"tool": normalized_name})
|
||||
|
||||
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
|
||||
messages = []
|
||||
if self.agent_config.instructions != "":
|
||||
|
@ -726,6 +782,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
|
||||
# Track step execution metric
|
||||
self._track_step()
|
||||
|
||||
# Add the result message to input_messages for the next iteration
|
||||
input_messages.append(result_message)
|
||||
|
||||
|
@ -900,6 +959,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
},
|
||||
)
|
||||
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
@ -39,6 +39,7 @@ from llama_stack.apis.inference import (
|
|||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
|
@ -64,6 +65,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tool_runtime_api: ToolRuntime,
|
||||
tool_groups_api: ToolGroups,
|
||||
policy: list[AccessRule],
|
||||
telemetry_api: Telemetry | None = None,
|
||||
):
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
|
@ -71,6 +73,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self.safety_api = safety_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.telemetry_api = telemetry_api
|
||||
|
||||
self.in_memory_store = InmemoryKVStoreImpl()
|
||||
self.openai_responses_impl: OpenAIResponsesImpl | None = None
|
||||
|
@ -130,6 +133,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
vector_io_api=self.vector_io_api,
|
||||
tool_runtime_api=self.tool_runtime_api,
|
||||
tool_groups_api=self.tool_groups_api,
|
||||
telemetry_api=self.telemetry_api,
|
||||
persistence_store=(
|
||||
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
||||
),
|
||||
|
|
|
@ -26,6 +26,7 @@ from llama_stack.apis.telemetry import (
|
|||
MetricEvent,
|
||||
MetricLabelMatcher,
|
||||
MetricQueryType,
|
||||
MetricType,
|
||||
QueryCondition,
|
||||
QueryMetricsResponse,
|
||||
QuerySpanTreeResponse,
|
||||
|
@ -57,6 +58,7 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
|
|||
"counters": {},
|
||||
"gauges": {},
|
||||
"up_down_counters": {},
|
||||
"histograms": {},
|
||||
}
|
||||
_global_lock = threading.Lock()
|
||||
_TRACER_PROVIDER = None
|
||||
|
@ -227,12 +229,20 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
# Log to OpenTelemetry meter if available
|
||||
if self.meter is None:
|
||||
return
|
||||
if isinstance(event.value, int):
|
||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||
counter.add(event.value, attributes=event.attributes)
|
||||
elif isinstance(event.value, float):
|
||||
|
||||
if event.metric_type == MetricType.HISTOGRAM:
|
||||
histogram = self._get_or_create_histogram(
|
||||
event.metric,
|
||||
event.unit,
|
||||
[0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 25.0, 50.0, 100.0],
|
||||
)
|
||||
histogram.record(event.value, attributes=event.attributes)
|
||||
elif event.metric_type == MetricType.UP_DOWN_COUNTER:
|
||||
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
|
||||
up_down_counter.add(event.value, attributes=event.attributes)
|
||||
else:
|
||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||
counter.add(event.value, attributes=event.attributes)
|
||||
|
||||
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
|
||||
assert self.meter is not None
|
||||
|
@ -244,6 +254,16 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
)
|
||||
return _GLOBAL_STORAGE["up_down_counters"][name]
|
||||
|
||||
def _get_or_create_histogram(self, name: str, unit: str, buckets: list[float] | None = None) -> metrics.Histogram:
|
||||
assert self.meter is not None
|
||||
if name not in _GLOBAL_STORAGE["histograms"]:
|
||||
_GLOBAL_STORAGE["histograms"][name] = self.meter.create_histogram(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Histogram for {name}",
|
||||
)
|
||||
return _GLOBAL_STORAGE["histograms"][name]
|
||||
|
||||
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
span_id = int(event.span_id, 16)
|
||||
|
|
|
@ -35,6 +35,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
Api.vector_dbs,
|
||||
Api.tool_runtime,
|
||||
Api.tool_groups,
|
||||
Api.telemetry,
|
||||
],
|
||||
description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
|
||||
),
|
||||
|
|
|
@ -25,8 +25,8 @@ classifiers = [
|
|||
]
|
||||
dependencies = [
|
||||
"aiohttp",
|
||||
"fastapi>=0.115.0,<1.0", # server
|
||||
"fire", # for MCP in LLS client
|
||||
"fastapi>=0.115.0,<1.0", # server
|
||||
"fire", # for MCP in LLS client
|
||||
"httpx",
|
||||
"huggingface-hub>=0.34.0,<1.0",
|
||||
"jinja2>=3.1.6",
|
||||
|
@ -44,12 +44,12 @@ dependencies = [
|
|||
"tiktoken",
|
||||
"pillow",
|
||||
"h11>=0.16.0",
|
||||
"python-multipart>=0.0.20", # For fastapi Form
|
||||
"uvicorn>=0.34.0", # server
|
||||
"opentelemetry-sdk>=1.30.0", # server
|
||||
"python-multipart>=0.0.20", # For fastapi Form
|
||||
"uvicorn>=0.34.0", # server
|
||||
"opentelemetry-sdk>=1.30.0", # server
|
||||
"opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
|
||||
"aiosqlite>=0.21.0", # server - for metadata store
|
||||
"asyncpg", # for metadata store
|
||||
"aiosqlite>=0.21.0", # server - for metadata store
|
||||
"asyncpg", # for metadata store
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
5
tests/unit/providers/agents/__init__.py
Normal file
5
tests/unit/providers/agents/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
212
tests/unit/providers/agents/test_agent_metrics.py
Normal file
212
tests/unit/providers/agents/test_agent_metrics.py
Normal file
|
@ -0,0 +1,212 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from opentelemetry.trace import SpanContext, TraceFlags
|
||||
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
|
||||
|
||||
|
||||
class FakeSpan:
|
||||
def __init__(self, trace_id: int = 123, span_id: int = 456):
|
||||
self._context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(0x01),
|
||||
)
|
||||
|
||||
def get_span_context(self):
|
||||
return self._context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_with_telemetry():
|
||||
"""Create a real ChatAgent with telemetry API"""
|
||||
telemetry_api = AsyncMock()
|
||||
|
||||
agent = ChatAgent(
|
||||
agent_id="test-agent",
|
||||
agent_config=Mock(),
|
||||
inference_api=Mock(),
|
||||
safety_api=Mock(),
|
||||
tool_runtime_api=Mock(),
|
||||
tool_groups_api=Mock(),
|
||||
vector_io_api=Mock(),
|
||||
telemetry_api=telemetry_api,
|
||||
persistence_store=Mock(),
|
||||
created_at="2025-01-01T00:00:00Z",
|
||||
policy=[],
|
||||
)
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_without_telemetry():
|
||||
"""Create a real ChatAgent without telemetry API"""
|
||||
agent = ChatAgent(
|
||||
agent_id="test-agent",
|
||||
agent_config=Mock(),
|
||||
inference_api=Mock(),
|
||||
safety_api=Mock(),
|
||||
tool_runtime_api=Mock(),
|
||||
tool_groups_api=Mock(),
|
||||
vector_io_api=Mock(),
|
||||
telemetry_api=None,
|
||||
persistence_store=Mock(),
|
||||
created_at="2025-01-01T00:00:00Z",
|
||||
policy=[],
|
||||
)
|
||||
return agent
|
||||
|
||||
|
||||
class TestAgentMetrics:
|
||||
def test_step_execution_metrics(self, agent_with_telemetry, monkeypatch):
|
||||
"""Test that step execution metrics are emitted correctly"""
|
||||
fake_span = FakeSpan()
|
||||
monkeypatch.setattr(
|
||||
"llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span", lambda: fake_span
|
||||
)
|
||||
|
||||
# Capture the metric instead of actually creating async task
|
||||
captured_metrics = []
|
||||
|
||||
async def capture_metric(metric):
|
||||
captured_metrics.append(metric)
|
||||
|
||||
monkeypatch.setattr(agent_with_telemetry.telemetry_api, "log_event", capture_metric)
|
||||
|
||||
def mock_create_task(coro, *, name=None):
|
||||
return asyncio.run(coro)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"llama_stack.providers.inline.agents.meta_reference.agent_instance.asyncio.create_task", mock_create_task
|
||||
)
|
||||
|
||||
agent_with_telemetry._track_step()
|
||||
|
||||
assert len(captured_metrics) == 1
|
||||
metric = captured_metrics[0]
|
||||
assert metric.metric == "llama_stack_agent_steps_total"
|
||||
assert metric.value == 1
|
||||
assert metric.unit == "1"
|
||||
assert metric.attributes["agent_id"] == "test-agent"
|
||||
|
||||
def test_workflow_completion_metrics(self, agent_with_telemetry, monkeypatch):
|
||||
"""Test that workflow completion metrics are emitted correctly"""
|
||||
fake_span = FakeSpan()
|
||||
monkeypatch.setattr(
|
||||
"llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span", lambda: fake_span
|
||||
)
|
||||
|
||||
captured_metrics = []
|
||||
|
||||
async def capture_metric(metric):
|
||||
captured_metrics.append(metric)
|
||||
|
||||
monkeypatch.setattr(agent_with_telemetry.telemetry_api, "log_event", capture_metric)
|
||||
|
||||
def mock_create_task(coro, *, name=None):
|
||||
return asyncio.run(coro)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"llama_stack.providers.inline.agents.meta_reference.agent_instance.asyncio.create_task", mock_create_task
|
||||
)
|
||||
|
||||
agent_with_telemetry._track_workflow("completed", 2.5)
|
||||
|
||||
assert len(captured_metrics) == 2
|
||||
|
||||
# Check workflow count metric
|
||||
count_metric = captured_metrics[0]
|
||||
assert count_metric.metric == "llama_stack_agent_workflows_total"
|
||||
assert count_metric.value == 1
|
||||
assert count_metric.attributes["status"] == "completed"
|
||||
|
||||
# Check duration metric
|
||||
duration_metric = captured_metrics[1]
|
||||
assert duration_metric.metric == "llama_stack_agent_workflow_duration_seconds"
|
||||
assert duration_metric.value == 2.5
|
||||
assert duration_metric.unit == "s"
|
||||
|
||||
def test_tool_usage_metrics(self, agent_with_telemetry, monkeypatch):
|
||||
"""Test that tool usage metrics are emitted correctly"""
|
||||
fake_span = FakeSpan()
|
||||
monkeypatch.setattr(
|
||||
"llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span", lambda: fake_span
|
||||
)
|
||||
|
||||
captured_metrics = []
|
||||
|
||||
async def capture_metric(metric):
|
||||
captured_metrics.append(metric)
|
||||
|
||||
monkeypatch.setattr(agent_with_telemetry.telemetry_api, "log_event", capture_metric)
|
||||
|
||||
def mock_create_task(coro, *, name=None):
|
||||
return asyncio.run(coro)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"llama_stack.providers.inline.agents.meta_reference.agent_instance.asyncio.create_task", mock_create_task
|
||||
)
|
||||
|
||||
agent_with_telemetry._track_tool("web_search")
|
||||
|
||||
assert len(captured_metrics) == 1
|
||||
metric = captured_metrics[0]
|
||||
assert metric.metric == "llama_stack_agent_tool_calls_total"
|
||||
assert metric.attributes["tool"] == "web_search"
|
||||
|
||||
def test_knowledge_search_tool_mapping(self, agent_with_telemetry, monkeypatch):
|
||||
"""Test that knowledge_search tool is mapped to rag"""
|
||||
fake_span = FakeSpan()
|
||||
monkeypatch.setattr(
|
||||
"llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span", lambda: fake_span
|
||||
)
|
||||
|
||||
captured_metrics = []
|
||||
|
||||
async def capture_metric(metric):
|
||||
captured_metrics.append(metric)
|
||||
|
||||
monkeypatch.setattr(agent_with_telemetry.telemetry_api, "log_event", capture_metric)
|
||||
|
||||
def mock_create_task(coro, *, name=None):
|
||||
return asyncio.run(coro)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"llama_stack.providers.inline.agents.meta_reference.agent_instance.asyncio.create_task", mock_create_task
|
||||
)
|
||||
|
||||
agent_with_telemetry._track_tool("knowledge_search")
|
||||
|
||||
assert len(captured_metrics) == 1
|
||||
metric = captured_metrics[0]
|
||||
assert metric.attributes["tool"] == "rag"
|
||||
|
||||
def test_no_telemetry_api(self, agent_without_telemetry):
|
||||
"""Test that methods work gracefully when telemetry_api is None"""
|
||||
# These should not crash
|
||||
agent_without_telemetry._track_step()
|
||||
agent_without_telemetry._track_workflow("failed", 1.0)
|
||||
agent_without_telemetry._track_tool("web_search")
|
||||
|
||||
def test_no_active_span(self, agent_with_telemetry, monkeypatch):
|
||||
"""Test that methods work gracefully when no span is active"""
|
||||
monkeypatch.setattr(
|
||||
"llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span", lambda: None
|
||||
)
|
||||
|
||||
# These should not crash and should not call telemetry
|
||||
agent_with_telemetry._track_step()
|
||||
agent_with_telemetry._track_workflow("failed", 1.0)
|
||||
agent_with_telemetry._track_tool("web_search")
|
||||
|
||||
# Telemetry should not have been called
|
||||
agent_with_telemetry.telemetry_api.log_event.assert_not_called()
|
5
tests/unit/providers/telemetry/__init__.py
Normal file
5
tests/unit/providers/telemetry/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
292
tests/unit/providers/telemetry/test_agent_metrics_histogram.py
Normal file
292
tests/unit/providers/telemetry/test_agent_metrics_histogram.py
Normal file
|
@ -0,0 +1,292 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.telemetry import MetricEvent, MetricType
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import TelemetryAdapter
|
||||
|
||||
|
||||
class TestAgentMetricsHistogram:
|
||||
"""Unit tests for histogram support in telemetry adapter for agent metrics"""
|
||||
|
||||
@pytest.fixture
|
||||
def telemetry_config(self):
|
||||
"""Basic telemetry config for testing"""
|
||||
return TelemetryConfig(
|
||||
service_name="test-service",
|
||||
sinks=[],
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def telemetry_adapter(self, telemetry_config):
|
||||
"""TelemetryAdapter with mocked meter"""
|
||||
adapter = TelemetryAdapter(telemetry_config, {})
|
||||
# Mock the meter to avoid OpenTelemetry setup
|
||||
adapter.meter = Mock()
|
||||
return adapter
|
||||
|
||||
def test_get_or_create_histogram_new(self, telemetry_adapter):
|
||||
"""Test creating a new histogram"""
|
||||
mock_histogram = Mock()
|
||||
telemetry_adapter.meter.create_histogram.return_value = mock_histogram
|
||||
|
||||
# Clear global storage to ensure clean state
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
|
||||
|
||||
_GLOBAL_STORAGE["histograms"] = {}
|
||||
|
||||
result = telemetry_adapter._get_or_create_histogram("test_histogram", "s", [0.1, 0.5, 1.0, 5.0, 10.0])
|
||||
|
||||
assert result == mock_histogram
|
||||
telemetry_adapter.meter.create_histogram.assert_called_once_with(
|
||||
name="test_histogram",
|
||||
unit="s",
|
||||
description="Histogram for test_histogram",
|
||||
)
|
||||
assert _GLOBAL_STORAGE["histograms"]["test_histogram"] == mock_histogram
|
||||
|
||||
def test_get_or_create_histogram_existing(self, telemetry_adapter):
|
||||
"""Test retrieving an existing histogram"""
|
||||
mock_histogram = Mock()
|
||||
|
||||
# Pre-populate global storage
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
|
||||
|
||||
_GLOBAL_STORAGE["histograms"] = {"existing_histogram": mock_histogram}
|
||||
|
||||
result = telemetry_adapter._get_or_create_histogram("existing_histogram", "ms")
|
||||
|
||||
assert result == mock_histogram
|
||||
# Should not create a new histogram
|
||||
telemetry_adapter.meter.create_histogram.assert_not_called()
|
||||
|
||||
def test_log_metric_duration_histogram(self, telemetry_adapter):
|
||||
"""Test logging duration metrics creates histogram"""
|
||||
mock_histogram = Mock()
|
||||
telemetry_adapter.meter.create_histogram.return_value = mock_histogram
|
||||
|
||||
# Clear global storage
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
|
||||
|
||||
_GLOBAL_STORAGE["histograms"] = {}
|
||||
|
||||
metric_event = MetricEvent(
|
||||
trace_id="123",
|
||||
span_id="456",
|
||||
metric="llama_stack_agent_workflow_duration_seconds",
|
||||
value=15.7,
|
||||
timestamp=1234567890.0,
|
||||
unit="s",
|
||||
attributes={"agent_id": "test-agent"},
|
||||
metric_type=MetricType.HISTOGRAM,
|
||||
)
|
||||
|
||||
telemetry_adapter._log_metric(metric_event)
|
||||
|
||||
# Verify histogram was created and recorded
|
||||
telemetry_adapter.meter.create_histogram.assert_called_once_with(
|
||||
name="llama_stack_agent_workflow_duration_seconds",
|
||||
unit="s",
|
||||
description="Histogram for llama_stack_agent_workflow_duration_seconds",
|
||||
)
|
||||
mock_histogram.record.assert_called_once_with(15.7, attributes={"agent_id": "test-agent"})
|
||||
|
||||
def test_log_metric_duration_histogram_default_buckets(self, telemetry_adapter):
|
||||
"""Test that duration metrics use default buckets"""
|
||||
mock_histogram = Mock()
|
||||
telemetry_adapter.meter.create_histogram.return_value = mock_histogram
|
||||
|
||||
# Clear global storage
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
|
||||
|
||||
_GLOBAL_STORAGE["histograms"] = {}
|
||||
|
||||
metric_event = MetricEvent(
|
||||
trace_id="123",
|
||||
span_id="456",
|
||||
metric="custom_duration_seconds",
|
||||
value=5.2,
|
||||
timestamp=1234567890.0,
|
||||
unit="s",
|
||||
attributes={},
|
||||
metric_type=MetricType.HISTOGRAM,
|
||||
)
|
||||
|
||||
telemetry_adapter._log_metric(metric_event)
|
||||
|
||||
# Verify histogram was created (buckets are not passed to create_histogram in OpenTelemetry)
|
||||
mock_histogram.record.assert_called_once_with(5.2, attributes={})
|
||||
|
||||
def test_log_metric_non_duration_counter(self, telemetry_adapter):
|
||||
"""Test that non-duration metrics still use counters"""
|
||||
mock_counter = Mock()
|
||||
telemetry_adapter.meter.create_counter.return_value = mock_counter
|
||||
|
||||
# Clear global storage
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
|
||||
|
||||
_GLOBAL_STORAGE["counters"] = {}
|
||||
|
||||
metric_event = MetricEvent(
|
||||
trace_id="123",
|
||||
span_id="456",
|
||||
metric="llama_stack_agent_workflows_total",
|
||||
value=1,
|
||||
timestamp=1234567890.0,
|
||||
unit="1",
|
||||
attributes={"agent_id": "test-agent", "status": "completed"},
|
||||
)
|
||||
|
||||
telemetry_adapter._log_metric(metric_event)
|
||||
|
||||
# Verify counter was used, not histogram
|
||||
telemetry_adapter.meter.create_counter.assert_called_once()
|
||||
telemetry_adapter.meter.create_histogram.assert_not_called()
|
||||
mock_counter.add.assert_called_once_with(1, attributes={"agent_id": "test-agent", "status": "completed"})
|
||||
|
||||
def test_log_metric_no_meter(self, telemetry_adapter):
|
||||
"""Test metric logging when meter is None"""
|
||||
telemetry_adapter.meter = None
|
||||
|
||||
metric_event = MetricEvent(
|
||||
trace_id="123",
|
||||
span_id="456",
|
||||
metric="test_duration_seconds",
|
||||
value=1.0,
|
||||
timestamp=1234567890.0,
|
||||
unit="s",
|
||||
attributes={},
|
||||
)
|
||||
|
||||
# Should not raise exception
|
||||
telemetry_adapter._log_metric(metric_event)
|
||||
|
||||
def test_histogram_name_detection_patterns(self, telemetry_adapter):
|
||||
"""Test various duration metric name patterns"""
|
||||
mock_histogram = Mock()
|
||||
telemetry_adapter.meter.create_histogram.return_value = mock_histogram
|
||||
|
||||
# Clear global storage
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
|
||||
|
||||
_GLOBAL_STORAGE["histograms"] = {}
|
||||
|
||||
duration_metrics = [
|
||||
"workflow_duration_seconds",
|
||||
"request_duration_seconds",
|
||||
"processing_duration_seconds",
|
||||
"llama_stack_agent_workflow_duration_seconds",
|
||||
]
|
||||
|
||||
for metric_name in duration_metrics:
|
||||
_GLOBAL_STORAGE["histograms"] = {} # Reset for each test
|
||||
|
||||
metric_event = MetricEvent(
|
||||
trace_id="123",
|
||||
span_id="456",
|
||||
metric=metric_name,
|
||||
value=1.0,
|
||||
timestamp=1234567890.0,
|
||||
unit="s",
|
||||
attributes={},
|
||||
metric_type=MetricType.HISTOGRAM,
|
||||
)
|
||||
|
||||
telemetry_adapter._log_metric(metric_event)
|
||||
mock_histogram.record.assert_called()
|
||||
|
||||
# Reset call count for negative test
|
||||
mock_histogram.record.reset_mock()
|
||||
telemetry_adapter.meter.create_histogram.reset_mock()
|
||||
|
||||
# Test non-duration metric
|
||||
non_duration_metric = MetricEvent(
|
||||
trace_id="123",
|
||||
span_id="456",
|
||||
metric="workflow_total", # No "_duration_seconds" suffix
|
||||
value=1,
|
||||
timestamp=1234567890.0,
|
||||
unit="1",
|
||||
attributes={},
|
||||
)
|
||||
|
||||
telemetry_adapter._log_metric(non_duration_metric)
|
||||
|
||||
# Should not create histogram for non-duration metric
|
||||
telemetry_adapter.meter.create_histogram.assert_not_called()
|
||||
mock_histogram.record.assert_not_called()
|
||||
|
||||
def test_histogram_global_storage_isolation(self, telemetry_adapter):
|
||||
"""Test that histogram storage doesn't interfere with counters"""
|
||||
mock_histogram = Mock()
|
||||
mock_counter = Mock()
|
||||
|
||||
telemetry_adapter.meter.create_histogram.return_value = mock_histogram
|
||||
telemetry_adapter.meter.create_counter.return_value = mock_counter
|
||||
|
||||
# Clear global storage
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
|
||||
|
||||
_GLOBAL_STORAGE["histograms"] = {}
|
||||
_GLOBAL_STORAGE["counters"] = {}
|
||||
|
||||
# Create histogram
|
||||
duration_metric = MetricEvent(
|
||||
trace_id="123",
|
||||
span_id="456",
|
||||
metric="test_duration_seconds",
|
||||
value=1.0,
|
||||
timestamp=1234567890.0,
|
||||
unit="s",
|
||||
attributes={},
|
||||
metric_type=MetricType.HISTOGRAM,
|
||||
)
|
||||
telemetry_adapter._log_metric(duration_metric)
|
||||
|
||||
# Create counter
|
||||
counter_metric = MetricEvent(
|
||||
trace_id="123",
|
||||
span_id="456",
|
||||
metric="test_counter",
|
||||
value=1,
|
||||
timestamp=1234567890.0,
|
||||
unit="1",
|
||||
attributes={},
|
||||
)
|
||||
telemetry_adapter._log_metric(counter_metric)
|
||||
|
||||
# Verify both were created and stored separately
|
||||
assert "test_duration_seconds" in _GLOBAL_STORAGE["histograms"]
|
||||
assert "test_counter" in _GLOBAL_STORAGE["counters"]
|
||||
assert "test_duration_seconds" not in _GLOBAL_STORAGE["counters"]
|
||||
assert "test_counter" not in _GLOBAL_STORAGE["histograms"]
|
||||
|
||||
def test_histogram_buckets_parameter_ignored(self, telemetry_adapter):
|
||||
"""Test that buckets parameter doesn't affect histogram creation (OpenTelemetry handles buckets internally)"""
|
||||
mock_histogram = Mock()
|
||||
telemetry_adapter.meter.create_histogram.return_value = mock_histogram
|
||||
|
||||
# Clear global storage
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
|
||||
|
||||
_GLOBAL_STORAGE["histograms"] = {}
|
||||
|
||||
# Call with buckets parameter
|
||||
result = telemetry_adapter._get_or_create_histogram(
|
||||
"test_histogram", "s", buckets=[0.1, 0.5, 1.0, 5.0, 10.0, 25.0, 50.0, 100.0]
|
||||
)
|
||||
|
||||
# Buckets are not passed to OpenTelemetry create_histogram
|
||||
telemetry_adapter.meter.create_histogram.assert_called_once_with(
|
||||
name="test_histogram",
|
||||
unit="s",
|
||||
description="Histogram for test_histogram",
|
||||
)
|
||||
assert result == mock_histogram
|
Loading…
Add table
Add a link
Reference in a new issue