improve agent metrics integration test and cleanup fixtures

- simplified test to use telemetry.query_metrics for verification
- test now validates actual queryable metrics data
- verified by query metrics functionality added in #3074
This commit is contained in:
skamenan7 2025-09-19 10:47:16 -04:00
parent 69b692af91
commit 8f0413e743
5 changed files with 406 additions and 208 deletions

View file

@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator, Callable
from pathlib import Path
from typing import Any
from unittest.mock import Mock, patch
import pytest
from llama_stack.apis.inference import ToolDefinition
from llama_stack.apis.tools import ToolInvocationResult
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
from llama_stack.providers.inline.telemetry.meta_reference.config import (
TelemetryConfig,
TelemetrySink,
)
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
TelemetryAdapter,
)
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl
from llama_stack.providers.utils.telemetry import tracing as telemetry_tracing
@pytest.fixture
def make_agent_fixture():
def _make(telemetry, kvstore) -> ChatAgent:
agent = ChatAgent(
agent_id="test-agent",
agent_config=Mock(),
inference_api=Mock(),
safety_api=Mock(),
tool_runtime_api=Mock(),
tool_groups_api=Mock(),
vector_io_api=Mock(),
telemetry_api=telemetry,
persistence_store=kvstore,
created_at="2025-01-01T00:00:00Z",
policy=[],
)
agent.agent_config.client_tools = []
agent.agent_config.max_infer_iters = 5
agent.input_shields = []
agent.output_shields = []
agent.tool_defs = [
ToolDefinition(tool_name="web_search", description="", parameters={}),
ToolDefinition(tool_name="knowledge_search", description="", parameters={}),
]
agent.tool_name_to_args = {}
# Stub tool runtime invoke_tool
async def _mock_invoke_tool(
*args: Any,
tool_name: str | None = None,
kwargs: dict | None = None,
**extra: Any,
):
return ToolInvocationResult(content="Tool execution result")
agent.tool_runtime_api.invoke_tool = _mock_invoke_tool
return agent
return _make
def _chat_stream(tool_name: str | None, content: str = ""):
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
StopReason,
)
from llama_stack.models.llama.datatypes import ToolCall
async def gen():
# Start
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
# Content
if content:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=TextDelta(text=content),
)
)
# Tool call if specified
if tool_name:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=ToolCall(call_id="call_0", tool_name=tool_name, arguments={}),
parse_status=ToolCallParseStatus.succeeded,
),
)
)
# Complete
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
stop_reason=StopReason.end_of_turn,
)
)
return gen()
@pytest.fixture
async def telemetry(tmp_path: Path) -> AsyncGenerator[TelemetryAdapter, None]:
db_path = tmp_path / "trace_store.db"
cfg = TelemetryConfig(
sinks=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
sqlite_db_path=str(db_path),
)
telemetry = TelemetryAdapter(cfg, deps={})
telemetry_tracing.setup_logger(telemetry)
try:
yield telemetry
finally:
await telemetry.shutdown()
@pytest.fixture
async def kvstore(tmp_path: Path) -> SqliteKVStoreImpl:
kv_path = tmp_path / "agent_kvstore.db"
kv = SqliteKVStoreImpl(SqliteKVStoreConfig(db_path=str(kv_path)))
await kv.initialize()
return kv
@pytest.fixture
def span_patch():
with (
patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span") as mock_span,
patch(
"llama_stack.providers.utils.telemetry.tracing.generate_span_id",
return_value="0000000000000abc",
),
):
mock_span.return_value = Mock(get_span_context=Mock(return_value=Mock(trace_id=0x123, span_id=0xABC)))
yield
@pytest.fixture
def make_completion_fn() -> Callable[[str | None, str], Callable]:
def _factory(tool_name: str | None = None, content: str = "") -> Callable:
async def chat_completion(*args: Any, **kwargs: Any):
return _chat_stream(tool_name, content)
return chat_completion
return _factory

View file

@ -5,55 +5,79 @@
# the root directory of this source tree.
import asyncio
from unittest.mock import AsyncMock, Mock, patch
from typing import Any
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
from llama_stack.providers.utils.telemetry import tracing as telemetry_tracing
class TestAgentMetricsIntegration:
"""Smoke test for agent metrics integration"""
async def test_agent_metrics_methods_exist_and_work(self):
"""Test that metrics methods exist and can be called without errors"""
# Create a minimal agent instance with mocked dependencies
telemetry_api = AsyncMock()
telemetry_api.logged_events = []
async def mock_log_event(event):
telemetry_api.logged_events.append(event)
telemetry_api.log_event = mock_log_event
agent = ChatAgent(
agent_id="test-agent",
agent_config=Mock(),
inference_api=Mock(),
safety_api=Mock(),
tool_runtime_api=Mock(),
tool_groups_api=Mock(),
vector_io_api=Mock(),
telemetry_api=telemetry_api,
persistence_store=Mock(),
created_at="2025-01-01T00:00:00Z",
policy=[],
async def test_agent_metrics_end_to_end(
self: Any,
telemetry: Any,
kvstore: Any,
make_agent_fixture: Any,
span_patch: Any,
make_completion_fn: Any,
) -> None:
from llama_stack.apis.inference import (
SamplingParams,
UserMessage,
)
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span") as mock_span:
mock_span.return_value = Mock(get_span_context=Mock(return_value=Mock(trace_id=123, span_id=456)))
agent: Any = make_agent_fixture(telemetry, kvstore)
# Test all metrics methods work
agent._track_step()
agent._track_workflow("completed", 2.5)
agent._track_tool("web_search")
session_id = await agent.create_session("s")
sampling_params = SamplingParams(max_tokens=64)
# Wait for async operations
await asyncio.sleep(0.01)
# single trace: plain, knowledge_search, web_search
await telemetry_tracing.start_trace("agent_metrics")
agent.inference_api.chat_completion = make_completion_fn(None, "Hello! I can help you with that.")
async for _ in agent.run(
session_id,
"t1",
[UserMessage(content="Hello")],
sampling_params,
stream=True,
):
pass
agent.inference_api.chat_completion = make_completion_fn("knowledge_search", "")
async for _ in agent.run(
session_id,
"t2",
[UserMessage(content="Please search knowledge")],
sampling_params,
stream=True,
):
pass
agent.inference_api.chat_completion = make_completion_fn("web_search", "")
async for _ in agent.run(
session_id,
"t3",
[UserMessage(content="Please search web")],
sampling_params,
stream=True,
):
pass
await telemetry_tracing.end_trace()
# Basic verification that telemetry was called
assert len(telemetry_api.logged_events) >= 3
# Poll briefly to avoid flake with async persistence
tool_labels: set[str] = set()
for _ in range(10):
resp = await telemetry.query_metrics("llama_stack_agent_tool_calls_total", start_time=0, end_time=None)
tool_labels.clear()
for series in getattr(resp, "data", []) or []:
for lbl in getattr(series, "labels", []) or []:
name = getattr(lbl, "name", None) or getattr(lbl, "key", None)
value = getattr(lbl, "value", None)
if name == "tool" and value:
tool_labels.add(value)
# Verify we can call the methods without exceptions
agent._track_tool("knowledge_search") # Test tool mapping
await asyncio.sleep(0.01)
# Look for both web_search AND some form of knowledge search
if ("web_search" in tool_labels) and ("rag" in tool_labels or "knowledge_search" in tool_labels):
break
await asyncio.sleep(0.1)
assert len(telemetry_api.logged_events) >= 4
# More descriptive assertion
assert bool(tool_labels & {"web_search", "rag", "knowledge_search"}), (
f"Expected tool calls not found. Got: {tool_labels}"
)

View file

@ -14,70 +14,56 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import Tele
class TestAgentMetricsHistogram:
"""Unit tests for histogram support in telemetry adapter for agent metrics"""
"""Tests for agent histogram metrics"""
@pytest.fixture
def telemetry_config(self):
"""Basic telemetry config for testing"""
return TelemetryConfig(
service_name="test-service",
sinks=[],
)
def config(self):
return TelemetryConfig(service_name="test-service", sinks=[])
@pytest.fixture
def telemetry_adapter(self, telemetry_config):
"""TelemetryAdapter with mocked meter"""
adapter = TelemetryAdapter(telemetry_config, {})
# Mock the meter to avoid OpenTelemetry setup
adapter.meter = Mock()
def adapter(self, config):
adapter = TelemetryAdapter(config, {})
adapter.meter = Mock() # skip otel setup
return adapter
def test_get_or_create_histogram_new(self, telemetry_adapter):
"""Test creating a new histogram"""
mock_histogram = Mock()
telemetry_adapter.meter.create_histogram.return_value = mock_histogram
def test_histogram_creation(self, adapter):
mock_hist = Mock()
adapter.meter.create_histogram.return_value = mock_hist
# Clear global storage to ensure clean state
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
_GLOBAL_STORAGE["histograms"] = {}
result = telemetry_adapter._get_or_create_histogram("test_histogram", "s", [0.1, 0.5, 1.0, 5.0, 10.0])
result = adapter._get_or_create_histogram("test_histogram", "s")
assert result == mock_histogram
telemetry_adapter.meter.create_histogram.assert_called_once_with(
assert result == mock_hist
adapter.meter.create_histogram.assert_called_once_with(
name="test_histogram",
unit="s",
description="Histogram for test_histogram",
description="test histogram",
)
assert _GLOBAL_STORAGE["histograms"]["test_histogram"] == mock_histogram
assert _GLOBAL_STORAGE["histograms"]["test_histogram"] == mock_hist
def test_get_or_create_histogram_existing(self, telemetry_adapter):
"""Test retrieving an existing histogram"""
mock_histogram = Mock()
# Pre-populate global storage
def test_histogram_reuse(self, adapter):
mock_hist = Mock()
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
_GLOBAL_STORAGE["histograms"] = {"existing_histogram": mock_histogram}
_GLOBAL_STORAGE["histograms"] = {"existing_histogram": mock_hist}
result = telemetry_adapter._get_or_create_histogram("existing_histogram", "ms")
result = adapter._get_or_create_histogram("existing_histogram", "ms")
assert result == mock_histogram
# Should not create a new histogram
telemetry_adapter.meter.create_histogram.assert_not_called()
assert result == mock_hist
adapter.meter.create_histogram.assert_not_called()
def test_log_metric_duration_histogram(self, telemetry_adapter):
"""Test logging duration metrics creates histogram"""
mock_histogram = Mock()
telemetry_adapter.meter.create_histogram.return_value = mock_histogram
def test_workflow_duration_histogram(self, adapter):
mock_hist = Mock()
adapter.meter.create_histogram.return_value = mock_hist
# Clear global storage
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
_GLOBAL_STORAGE["histograms"] = {}
metric_event = MetricEvent(
event = MetricEvent(
trace_id="123",
span_id="456",
metric="llama_stack_agent_workflow_duration_seconds",
@ -88,27 +74,24 @@ class TestAgentMetricsHistogram:
metric_type=MetricType.HISTOGRAM,
)
telemetry_adapter._log_metric(metric_event)
adapter._log_metric(event)
# Verify histogram was created and recorded
telemetry_adapter.meter.create_histogram.assert_called_once_with(
adapter.meter.create_histogram.assert_called_once_with(
name="llama_stack_agent_workflow_duration_seconds",
unit="s",
description="Histogram for llama_stack_agent_workflow_duration_seconds",
description="llama stack agent workflow duration seconds",
)
mock_histogram.record.assert_called_once_with(15.7, attributes={"agent_id": "test-agent"})
mock_hist.record.assert_called_once_with(15.7, attributes={"agent_id": "test-agent"})
def test_log_metric_duration_histogram_default_buckets(self, telemetry_adapter):
"""Test that duration metrics use default buckets"""
mock_histogram = Mock()
telemetry_adapter.meter.create_histogram.return_value = mock_histogram
def test_duration_buckets_configured_via_views(self, adapter):
mock_hist = Mock()
adapter.meter.create_histogram.return_value = mock_hist
# Clear global storage
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
_GLOBAL_STORAGE["histograms"] = {}
metric_event = MetricEvent(
event = MetricEvent(
trace_id="123",
span_id="456",
metric="custom_duration_seconds",
@ -119,22 +102,20 @@ class TestAgentMetricsHistogram:
metric_type=MetricType.HISTOGRAM,
)
telemetry_adapter._log_metric(metric_event)
adapter._log_metric(event)
# Verify histogram was created (buckets are not passed to create_histogram in OpenTelemetry)
mock_histogram.record.assert_called_once_with(5.2, attributes={})
# buckets configured via otel views, not passed to create_histogram
mock_hist.record.assert_called_once_with(5.2, attributes={})
def test_log_metric_non_duration_counter(self, telemetry_adapter):
"""Test that non-duration metrics still use counters"""
def test_non_duration_uses_counter(self, adapter):
mock_counter = Mock()
telemetry_adapter.meter.create_counter.return_value = mock_counter
adapter.meter.create_counter.return_value = mock_counter
# Clear global storage
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
_GLOBAL_STORAGE["counters"] = {}
metric_event = MetricEvent(
event = MetricEvent(
trace_id="123",
span_id="456",
metric="llama_stack_agent_workflows_total",
@ -144,18 +125,16 @@ class TestAgentMetricsHistogram:
attributes={"agent_id": "test-agent", "status": "completed"},
)
telemetry_adapter._log_metric(metric_event)
adapter._log_metric(event)
# Verify counter was used, not histogram
telemetry_adapter.meter.create_counter.assert_called_once()
telemetry_adapter.meter.create_histogram.assert_not_called()
adapter.meter.create_counter.assert_called_once()
adapter.meter.create_histogram.assert_not_called()
mock_counter.add.assert_called_once_with(1, attributes={"agent_id": "test-agent", "status": "completed"})
def test_log_metric_no_meter(self, telemetry_adapter):
"""Test metric logging when meter is None"""
telemetry_adapter.meter = None
def test_no_meter_doesnt_crash(self, adapter):
adapter.meter = None
metric_event = MetricEvent(
event = MetricEvent(
trace_id="123",
span_id="456",
metric="test_duration_seconds",
@ -165,80 +144,59 @@ class TestAgentMetricsHistogram:
attributes={},
)
# Should not raise exception
telemetry_adapter._log_metric(metric_event)
adapter._log_metric(event) # shouldn't crash
def test_histogram_name_detection_patterns(self, telemetry_adapter):
"""Test various duration metric name patterns"""
mock_histogram = Mock()
telemetry_adapter.meter.create_histogram.return_value = mock_histogram
# Clear global storage
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
_GLOBAL_STORAGE["histograms"] = {}
duration_metrics = [
"workflow_duration_seconds",
"request_duration_seconds",
"processing_duration_seconds",
"llama_stack_agent_workflow_duration_seconds",
]
for metric_name in duration_metrics:
_GLOBAL_STORAGE["histograms"] = {} # Reset for each test
metric_event = MetricEvent(
trace_id="123",
span_id="456",
metric=metric_name,
value=1.0,
timestamp=1234567890.0,
unit="s",
attributes={},
metric_type=MetricType.HISTOGRAM,
)
telemetry_adapter._log_metric(metric_event)
mock_histogram.record.assert_called()
# Reset call count for negative test
mock_histogram.record.reset_mock()
telemetry_adapter.meter.create_histogram.reset_mock()
# Test non-duration metric
non_duration_metric = MetricEvent(
trace_id="123",
span_id="456",
metric="workflow_total", # No "_duration_seconds" suffix
value=1,
timestamp=1234567890.0,
unit="1",
attributes={},
)
telemetry_adapter._log_metric(non_duration_metric)
# Should not create histogram for non-duration metric
telemetry_adapter.meter.create_histogram.assert_not_called()
mock_histogram.record.assert_not_called()
def test_histogram_global_storage_isolation(self, telemetry_adapter):
"""Test that histogram storage doesn't interfere with counters"""
mock_histogram = Mock()
def test_histogram_vs_counter_by_type(self, adapter):
mock_hist = Mock()
mock_counter = Mock()
adapter.meter.create_histogram.return_value = mock_hist
adapter.meter.create_counter.return_value = mock_counter
telemetry_adapter.meter.create_histogram.return_value = mock_histogram
telemetry_adapter.meter.create_counter.return_value = mock_counter
# Clear global storage
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
_GLOBAL_STORAGE["histograms"] = {}
_GLOBAL_STORAGE["counters"] = {}
# Create histogram
duration_metric = MetricEvent(
# histogram metric
hist_event = MetricEvent(
trace_id="123",
span_id="456",
metric="workflow_duration_seconds",
value=1.0,
timestamp=1234567890.0,
unit="s",
attributes={},
metric_type=MetricType.HISTOGRAM,
)
adapter._log_metric(hist_event)
mock_hist.record.assert_called()
# counter metric (default type)
counter_event = MetricEvent(
trace_id="123",
span_id="456",
metric="workflow_total",
value=1,
timestamp=1234567890.0,
unit="1",
attributes={},
)
adapter._log_metric(counter_event)
mock_counter.add.assert_called()
def test_storage_separation(self, adapter):
mock_hist = Mock()
mock_counter = Mock()
adapter.meter.create_histogram.return_value = mock_hist
adapter.meter.create_counter.return_value = mock_counter
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
_GLOBAL_STORAGE["histograms"] = {}
_GLOBAL_STORAGE["counters"] = {}
# create both types
hist_event = MetricEvent(
trace_id="123",
span_id="456",
metric="test_duration_seconds",
@ -248,10 +206,7 @@ class TestAgentMetricsHistogram:
attributes={},
metric_type=MetricType.HISTOGRAM,
)
telemetry_adapter._log_metric(duration_metric)
# Create counter
counter_metric = MetricEvent(
counter_event = MetricEvent(
trace_id="123",
span_id="456",
metric="test_counter",
@ -260,33 +215,30 @@ class TestAgentMetricsHistogram:
unit="1",
attributes={},
)
telemetry_adapter._log_metric(counter_metric)
# Verify both were created and stored separately
adapter._log_metric(hist_event)
adapter._log_metric(counter_event)
# check they're stored separately
assert "test_duration_seconds" in _GLOBAL_STORAGE["histograms"]
assert "test_counter" in _GLOBAL_STORAGE["counters"]
assert "test_duration_seconds" not in _GLOBAL_STORAGE["counters"]
assert "test_counter" not in _GLOBAL_STORAGE["histograms"]
def test_histogram_buckets_parameter_ignored(self, telemetry_adapter):
"""Test that buckets parameter doesn't affect histogram creation (OpenTelemetry handles buckets internally)"""
mock_histogram = Mock()
telemetry_adapter.meter.create_histogram.return_value = mock_histogram
def test_histogram_uses_views_for_buckets(self, adapter):
mock_hist = Mock()
adapter.meter.create_histogram.return_value = mock_hist
# Clear global storage
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
_GLOBAL_STORAGE["histograms"] = {}
# Call with buckets parameter
result = telemetry_adapter._get_or_create_histogram(
"test_histogram", "s", buckets=[0.1, 0.5, 1.0, 5.0, 10.0, 25.0, 50.0, 100.0]
)
result = adapter._get_or_create_histogram("test_histogram", "s")
# Buckets are not passed to OpenTelemetry create_histogram
telemetry_adapter.meter.create_histogram.assert_called_once_with(
# buckets come from otel views, not create_histogram params
adapter.meter.create_histogram.assert_called_once_with(
name="test_histogram",
unit="s",
description="Histogram for test_histogram",
description="test histogram",
)
assert result == mock_histogram
assert result == mock_hist