diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 9ddb070d7..e02e611f4 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -13616,6 +13616,10 @@
                 "unit": {
                     "type": "string",
                     "description": "The unit of measurement for the metric value"
+                },
+                "metric_type": {
+                    "$ref": "#/components/schemas/MetricType",
+                    "description": "The type of metric (optional, inferred if not provided for backwards compatibility)"
                 }
             },
             "additionalProperties": false,
@@ -13631,6 +13635,17 @@
             "title": "MetricEvent",
             "description": "A metric event containing a measured value."
         },
+        "MetricType": {
+            "type": "string",
+            "enum": [
+                "counter",
+                "up_down_counter",
+                "histogram",
+                "gauge"
+            ],
+            "title": "MetricType",
+            "description": "The type of metric being recorded."
+        },
         "SpanEndPayload": {
             "type": "object",
             "properties": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 94dc5c0f9..650faadbe 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -10122,6 +10122,10 @@ components:
           type: string
           description: >-
             The unit of measurement for the metric value
+        metric_type:
+          $ref: '#/components/schemas/MetricType'
+          description: >-
+            The type of metric (optional, inferred if not provided for backwards compatibility)
       additionalProperties: false
       required:
         - trace_id
@@ -10134,6 +10138,15 @@ components:
       title: MetricEvent
       description: >-
         A metric event containing a measured value.
+    MetricType:
+      type: string
+      enum:
+        - counter
+        - up_down_counter
+        - histogram
+        - gauge
+      title: MetricType
+      description: The type of metric being recorded.
     SpanEndPayload:
       type: object
       properties:
diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py
index 8d1b5d697..6d0bde6ef 100644
--- a/llama_stack/apis/telemetry/telemetry.py
+++ b/llama_stack/apis/telemetry/telemetry.py
@@ -90,6 +90,21 @@ class EventType(Enum):
     METRIC = "metric"
 
 
+@json_schema_type
+class MetricType(Enum):
+    """The type of metric being recorded.
+    :cvar COUNTER: A counter metric that only increases (e.g., requests_total)
+    :cvar UP_DOWN_COUNTER: A counter that can increase or decrease (e.g., active_connections)
+    :cvar HISTOGRAM: A histogram metric for measuring distributions (e.g., request_duration_seconds)
+    :cvar GAUGE: A gauge metric for point-in-time values (e.g., cpu_usage_percent)
+    """
+
+    COUNTER = "counter"
+    UP_DOWN_COUNTER = "up_down_counter"
+    HISTOGRAM = "histogram"
+    GAUGE = "gauge"
+
+
 @json_schema_type
 class LogSeverity(Enum):
     """The severity level of a log message.
@@ -143,12 +158,14 @@ class MetricEvent(EventCommon):
     :param metric: The name of the metric being measured
     :param value: The numeric value of the metric measurement
     :param unit: The unit of measurement for the metric value
+    :param metric_type: The type of metric (optional, inferred if not provided for backwards compatibility)
     """
 
     type: Literal[EventType.METRIC] = EventType.METRIC
     metric: str  # this would be an enum
     value: int | float
     unit: str
+    metric_type: MetricType | None = None
 
 
 @json_schema_type
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index fde38515b..d732283c7 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -4,17 +4,20 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import copy
 import json
 import re
 import secrets
 import string
+import time
 import uuid
 import warnings
 from collections.abc import AsyncGenerator
 from datetime import UTC, datetime
 
 import httpx
+from opentelemetry.trace import get_current_span
 
 from llama_stack.apis.agents import (
     AgentConfig,
@@ -60,6 +63,7 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.apis.safety import Safety
+from llama_stack.apis.telemetry import MetricEvent, MetricType, Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.datatypes import AccessRule
@@ -97,6 +101,7 @@ class ChatAgent(ShieldRunnerMixin):
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
         vector_io_api: VectorIO,
+        telemetry_api: Telemetry | None,
         persistence_store: KVStore,
         created_at: str,
         policy: list[AccessRule],
@@ -106,6 +111,7 @@ class ChatAgent(ShieldRunnerMixin):
         self.inference_api = inference_api
         self.safety_api = safety_api
         self.vector_io_api = vector_io_api
+        self.telemetry_api = telemetry_api
         self.storage = AgentPersistence(agent_id, persistence_store, policy)
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
@@ -118,6 +124,9 @@ class ChatAgent(ShieldRunnerMixin):
             output_shields=agent_config.output_shields,
         )
 
+        # Initialize workflow start time to None
+        self._workflow_start_time: float | None = None
+
     def turn_to_messages(self, turn: Turn) -> list[Message]:
         messages = []
 
@@ -167,6 +176,72 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_session(self, name: str) -> str:
         return await self.storage.create_session(name)
 
+    def _emit_metric(
+        self,
+        metric_name: str,
+        value: int | float,
+        unit: str,
+        attributes: dict[str, str] | None = None,
+        metric_type: MetricType | None = None,
+    ) -> None:
+        """Emit a single metric event."""
+        logger.debug(f"_emit_metric called: {metric_name} = {value} {unit}")
+
+        if not self.telemetry_api:
+            logger.warning(f"No telemetry_api available for metric {metric_name}")
+            return
+
+        span = get_current_span()
+        if span is None or not span.get_span_context().is_valid:
+            logger.warning(f"No current span available for metric {metric_name}")
+            return
+
+        context = span.get_span_context()
+        metric = MetricEvent(
+            trace_id=format(context.trace_id, "x"),
+            span_id=format(context.span_id, "x"),
+            metric=metric_name,
+            value=value,
+            timestamp=time.time(),
+            unit=unit,
+            attributes={"agent_id": self.agent_id, **(attributes or {})},
+            metric_type=metric_type,
+        )
+
+        # Create the task with a name for better debugging and capture any async errors
+        task_name = f"metric-{metric_name}-{self.agent_id}"
+        logger.debug(f"Creating telemetry task: {task_name}")
+        task = asyncio.create_task(self.telemetry_api.log_event(metric), name=task_name)
+
+        def _on_metric_task_done(t: asyncio.Task) -> None:
+            try:
+                exc = t.exception()
+            except asyncio.CancelledError:
+                logger.debug("Metric task %s was cancelled", task_name)
+                return
+            if exc is not None:
+                logger.warning("Metric task %s failed: %s", task_name, exc)
+
+        # Only add the callback if task creation succeeded (tests may stub create_task to return None)
+        if task is not None:
+            task.add_done_callback(_on_metric_task_done)
+
+    def _track_step(self):
+        logger.debug("_track_step called")
+        self._emit_metric("llama_stack_agent_steps_total", 1, "1", metric_type=MetricType.COUNTER)
+
+    def _track_workflow(self, status: str, duration: float):
+        logger.debug(f"_track_workflow called: status={status}, duration={duration:.2f}s")
+        self._emit_metric("llama_stack_agent_workflows_total", 1, "1", {"status": status}, MetricType.COUNTER)
+        self._emit_metric(
+            "llama_stack_agent_workflow_duration_seconds", duration, "s", metric_type=MetricType.HISTOGRAM
+        )
+
+    def _track_tool(self, tool_name: str):
+        logger.debug(f"_track_tool called: {tool_name}")
+        normalized_name = "rag" if tool_name == "knowledge_search" else tool_name
+        self._emit_metric("llama_stack_agent_tool_calls_total", 1, "1", {"tool": normalized_name}, MetricType.COUNTER)
+
     async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
         messages = []
         if self.agent_config.instructions != "":
@@ -201,6 +276,9 @@ class ChatAgent(ShieldRunnerMixin):
             if self.agent_config.name:
                 span.set_attribute("agent_name", self.agent_config.name)
 
+            # Set workflow start time for resume operations
+            self._workflow_start_time = time.time()
+
             await self._initialize_tools()
             async for chunk in self._run_turn(request):
                 yield chunk
@@ -212,6 +290,9 @@ class ChatAgent(ShieldRunnerMixin):
     ) -> AsyncGenerator:
         assert request.stream is True, "Non-streaming not supported"
 
+        # Track workflow start time for metrics
+        self._workflow_start_time = time.time()
+
         is_resume = isinstance(request, AgentTurnResumeRequest)
         session_info = await self.storage.get_session_info(request.session_id)
         if session_info is None:
@@ -313,6 +394,10 @@ class ChatAgent(ShieldRunnerMixin):
                 )
             )
         else:
+            # Track workflow completion when the turn is actually complete
+            workflow_duration = time.time() - (self._workflow_start_time or time.time())
+            self._track_workflow("completed", workflow_duration)
+
             chunk = AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseTurnCompletePayload(
@@ -726,6 +811,10 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
 
+                # Track step execution and tool usage metrics
+                self._track_step()
+                self._track_tool(tool_call.tool_name)
+
                 # Add the result message to input_messages for the next iteration
                 input_messages.append(result_message)
 
@@ -900,6 +989,7 @@ class ChatAgent(ShieldRunnerMixin):
             },
         )
         logger.debug(f"tool call {tool_name_str} completed with result: {result}")
+
         return result
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index 8bdde86b0..11243311d 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.apis.safety import Safety
+from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.datatypes import AccessRule
@@ -64,6 +65,7 @@ class MetaReferenceAgentsImpl(Agents):
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
         policy: list[AccessRule],
+        telemetry_api: Telemetry | None = None,
     ):
         self.config = config
         self.inference_api = inference_api
@@ -71,6 +73,7 @@ class MetaReferenceAgentsImpl(Agents):
         self.safety_api = safety_api
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
+        self.telemetry_api = telemetry_api
         self.in_memory_store = InmemoryKVStoreImpl()
         self.openai_responses_impl: OpenAIResponsesImpl | None = None
@@ -130,6 +133,7 @@ class MetaReferenceAgentsImpl(Agents):
             vector_io_api=self.vector_io_api,
             tool_runtime_api=self.tool_runtime_api,
             tool_groups_api=self.tool_groups_api,
+            telemetry_api=self.telemetry_api,
             persistence_store=(
                 self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
             ),
diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
index 9224c3792..9ae52f307 100644
--- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
+++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
@@ -12,7 +12,8 @@ from opentelemetry import metrics, trace
 from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
 from opentelemetry.sdk.metrics import MeterProvider
 from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
+from opentelemetry.sdk.metrics.view import ExplicitBucketHistogramAggregation, View
 from opentelemetry.sdk.resources import Resource
 from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import BatchSpanProcessor
@@ -24,6 +26,7 @@ from llama_stack.apis.telemetry import (
     MetricEvent,
     MetricLabelMatcher,
     MetricQueryType,
+    MetricType,
     QueryCondition,
     QueryMetricsResponse,
     QuerySpanTreeResponse,
@@ -56,6 +59,7 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
     "counters": {},
     "gauges": {},
     "up_down_counters": {},
+    "histograms": {},
 }
 _global_lock = threading.Lock()
 _TRACER_PROVIDER = None
@@ -108,7 +112,17 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
 
         if TelemetrySink.OTEL_METRIC in self.config.sinks:
             metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
-            metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
+
+            # Decent default buckets for agent workflow timings
+            hist_buckets = [0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 25.0, 50.0, 100.0]
+            views = [
+                View(
+                    instrument_type=metrics.Histogram,
+                    aggregation=ExplicitBucketHistogramAggregation(boundaries=hist_buckets),
+                )
+            ]
+
+            metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader], views=views)
             metrics.set_meter_provider(metric_provider)
 
         if TelemetrySink.SQLITE in self.config.sinks:
@@ -138,8 +152,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             self._log_metric(event)
         elif isinstance(event, StructuredLogEvent):
             self._log_structured(event, ttl_seconds)
-        else:
-            raise ValueError(f"Unknown event type: {event}")
 
     async def query_metrics(
         self,
@@ -209,7 +221,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             _GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
                 name=name,
                 unit=unit,
-                description=f"Counter for {name}",
+                description=name.replace("_", " "),
             )
         return _GLOBAL_STORAGE["counters"][name]
@@ -219,7 +231,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
                 name=name,
                 unit=unit,
-                description=f"Gauge for {name}",
+                description=name.replace("_", " "),
             )
         return _GLOBAL_STORAGE["gauges"][name]
 
@@ -258,12 +270,19 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
         # Log to OpenTelemetry meter if available
         if self.meter is None:
             return
-        if isinstance(event.value, int):
-            counter = self._get_or_create_counter(event.metric, event.unit)
-            counter.add(event.value, attributes=event.attributes)
-        elif isinstance(event.value, float):
+
+        if event.metric_type == MetricType.HISTOGRAM:
+            histogram = self._get_or_create_histogram(
+                event.metric,
+                event.unit,
+            )
+            histogram.record(event.value, attributes=event.attributes)
+        elif event.metric_type == MetricType.UP_DOWN_COUNTER:
             up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
             up_down_counter.add(event.value, attributes=event.attributes)
+        else:
+            counter = self._get_or_create_counter(event.metric, event.unit)
+            counter.add(event.value, attributes=event.attributes)
 
     def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
         assert self.meter is not None
@@ -271,10 +290,20 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             _GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
                 name=name,
                 unit=unit,
-                description=f"UpDownCounter for {name}",
+                description=name.replace("_", " "),
             )
         return _GLOBAL_STORAGE["up_down_counters"][name]
 
+    def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
+        assert self.meter is not None
+        if name not in _GLOBAL_STORAGE["histograms"]:
+            _GLOBAL_STORAGE["histograms"][name] = self.meter.create_histogram(
+                name=name,
+                unit=unit,
+                description=name.replace("_", " "),
+            )
+        return _GLOBAL_STORAGE["histograms"][name]
+
     def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
         with self._lock:
             span_id = int(event.span_id, 16)
diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py
index 57110d129..9108f0dc5 100644
--- a/llama_stack/providers/registry/agents.py
+++ b/llama_stack/providers/registry/agents.py
@@ -35,6 +35,7 @@ def available_providers() -> list[ProviderSpec]:
                 Api.vector_dbs,
                 Api.tool_runtime,
                 Api.tool_groups,
+                Api.telemetry,
             ],
             description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
         ),
diff --git a/pyproject.toml b/pyproject.toml
index 86a32f978..95084b24e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,8 +25,8 @@ classifiers = [
 ]
 dependencies = [
     "aiohttp",
-    "fastapi>=0.115.0,<1.0", # server
-    "fire", # for MCP in LLS client
+    "fastapi>=0.115.0,<1.0",                          # server
+    "fire",                                           # for MCP in LLS client
     "httpx",
    "huggingface-hub>=0.34.0,<1.0",
     "jinja2>=3.1.6",
@@ -43,12 +43,12 @@ dependencies = [
     "tiktoken",
     "pillow",
     "h11>=0.16.0",
-    "python-multipart>=0.0.20", # For fastapi Form
-    "uvicorn>=0.34.0", # server
-    "opentelemetry-sdk>=1.30.0", # server
+    "python-multipart>=0.0.20",                       # For fastapi Form
+    "uvicorn>=0.34.0",                                # server
+    "opentelemetry-sdk>=1.30.0",                      # server
     "opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
-    "aiosqlite>=0.21.0", # server - for metadata store
-    "asyncpg", # for metadata store
+    "aiosqlite>=0.21.0",                              # server - for metadata store
+    "asyncpg",                                        # for metadata store
 ]
 
 [project.optional-dependencies]
diff --git a/tests/integration/agents/conftest.py b/tests/integration/agents/conftest.py
new file mode 100644
index 000000000..df5f3d875
--- /dev/null
+++ b/tests/integration/agents/conftest.py
@@ -0,0 +1,170 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from collections.abc import AsyncGenerator, Callable
+from pathlib import Path
+from typing import Any
+from unittest.mock import Mock, patch
+
+import pytest
+
+from llama_stack.apis.inference import ToolDefinition
+from llama_stack.apis.tools import ToolInvocationResult
+from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
+from llama_stack.providers.inline.telemetry.meta_reference.config import (
+    TelemetryConfig,
+    TelemetrySink,
+)
+from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
+    TelemetryAdapter,
+)
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
+from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl
+from llama_stack.providers.utils.telemetry import tracing as telemetry_tracing
+
+
+@pytest.fixture
+def make_agent_fixture():
+    def _make(telemetry, kvstore) -> ChatAgent:
+        agent = ChatAgent(
+            agent_id="test-agent",
+            agent_config=Mock(),
+            inference_api=Mock(),
+            safety_api=Mock(),
+            tool_runtime_api=Mock(),
+            tool_groups_api=Mock(),
+            vector_io_api=Mock(),
+            telemetry_api=telemetry,
+            persistence_store=kvstore,
+            created_at="2025-01-01T00:00:00Z",
+            policy=[],
+        )
+        agent.agent_config.client_tools = []
+        agent.agent_config.max_infer_iters = 5
+        agent.input_shields = []
+        agent.output_shields = []
+        agent.tool_defs = [
+            ToolDefinition(tool_name="web_search", description="", parameters={}),
+            ToolDefinition(tool_name="knowledge_search", description="", parameters={}),
+        ]
+        agent.tool_name_to_args = {}
+
+        # Stub the tool runtime's invoke_tool
+        async def _mock_invoke_tool(
+            *args: Any,
+            tool_name: str | None = None,
+            kwargs: dict | None = None,
+            **extra: Any,
+        ):
+            return ToolInvocationResult(content="Tool execution result")
+
+        agent.tool_runtime_api.invoke_tool = _mock_invoke_tool
+        return agent
+
+    return _make
+
+
+def _chat_stream(tool_name: str | None, content: str = ""):
+    from llama_stack.apis.common.content_types import (
+        TextDelta,
+        ToolCallDelta,
+        ToolCallParseStatus,
+    )
+    from llama_stack.apis.inference import (
+        ChatCompletionResponseEvent,
+        ChatCompletionResponseEventType,
+        ChatCompletionResponseStreamChunk,
+        StopReason,
+    )
+    from llama_stack.models.llama.datatypes import ToolCall
+
+    async def gen():
+        # Start
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.start,
+                delta=TextDelta(text=""),
+            )
+        )
+
+        # Content
+        if content:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=TextDelta(text=content),
+                )
+            )
+
+        # Tool call if specified
+        if tool_name:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        tool_call=ToolCall(call_id="call_0", tool_name=tool_name, arguments={}),
+                        parse_status=ToolCallParseStatus.succeeded,
+                    ),
+                )
+            )
+
+        # Complete
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.complete,
+                delta=TextDelta(text=""),
+                stop_reason=StopReason.end_of_turn,
+            )
+        )
+
+    return gen()
+
+
+@pytest.fixture
+async def telemetry(tmp_path: Path) -> AsyncGenerator[TelemetryAdapter, None]:
+    db_path = tmp_path / "trace_store.db"
+    cfg = TelemetryConfig(
+        sinks=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
+        sqlite_db_path=str(db_path),
+    )
+    telemetry = TelemetryAdapter(cfg, deps={})
+    telemetry_tracing.setup_logger(telemetry)
+    try:
+        yield telemetry
+    finally:
+        await telemetry.shutdown()
+
+
+@pytest.fixture
+async def kvstore(tmp_path: Path) -> SqliteKVStoreImpl:
+    kv_path = tmp_path / "agent_kvstore.db"
+    kv = SqliteKVStoreImpl(SqliteKVStoreConfig(db_path=str(kv_path)))
+    await kv.initialize()
+    return kv
+
+
+@pytest.fixture
+def span_patch():
+    with (
+        patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span") as mock_span,
+        patch(
+            "llama_stack.providers.utils.telemetry.tracing.generate_span_id",
+            return_value="0000000000000abc",
+        ),
+    ):
+        mock_span.return_value = Mock(get_span_context=Mock(return_value=Mock(trace_id=0x123, span_id=0xABC)))
+        yield
+
+
+@pytest.fixture
+def make_completion_fn() -> Callable[[str | None, str], Callable]:
+    def _factory(tool_name: str | None = None, content: str = "") -> Callable:
+        async def chat_completion(*args: Any, **kwargs: Any):
+            return _chat_stream(tool_name, content)
+
+        return chat_completion
+
+    return _factory
diff --git a/tests/integration/agents/test_agent_metrics_integration.py b/tests/integration/agents/test_agent_metrics_integration.py
new file mode 100644
index 000000000..994e05901
--- /dev/null
+++ b/tests/integration/agents/test_agent_metrics_integration.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+from typing import Any
+
+from llama_stack.providers.utils.telemetry import tracing as telemetry_tracing
+
+
+class TestAgentMetricsIntegration:
+    async def test_agent_metrics_end_to_end(
+        self,
+        telemetry: Any,
+        kvstore: Any,
+        make_agent_fixture: Any,
+        span_patch: Any,
+        make_completion_fn: Any,
+    ) -> None:
+        from llama_stack.apis.inference import (
+            SamplingParams,
+            UserMessage,
+        )
+
+        agent: Any = make_agent_fixture(telemetry, kvstore)
+
+        session_id = await agent.create_session("s")
+        sampling_params = SamplingParams(max_tokens=64)
+
+        # Single trace covering three turns: plain, knowledge_search, web_search
+        await telemetry_tracing.start_trace("agent_metrics")
+        agent.inference_api.chat_completion = make_completion_fn(None, "Hello! I can help you with that.")
+        async for _ in agent.run(
+            session_id,
+            "t1",
+            [UserMessage(content="Hello")],
+            sampling_params,
+            stream=True,
+        ):
+            pass
+        agent.inference_api.chat_completion = make_completion_fn("knowledge_search", "")
+        async for _ in agent.run(
+            session_id,
+            "t2",
+            [UserMessage(content="Please search knowledge")],
+            sampling_params,
+            stream=True,
+        ):
+            pass
+        agent.inference_api.chat_completion = make_completion_fn("web_search", "")
+        async for _ in agent.run(
+            session_id,
+            "t3",
+            [UserMessage(content="Please search web")],
+            sampling_params,
+            stream=True,
+        ):
+            pass
+        await telemetry_tracing.end_trace()
+
+        # Poll briefly to avoid flakes from async persistence
+        tool_labels: set[str] = set()
+        for _ in range(10):
+            resp = await telemetry.query_metrics("llama_stack_agent_tool_calls_total", start_time=0, end_time=None)
+            tool_labels.clear()
+            for series in getattr(resp, "data", []) or []:
+                for lbl in getattr(series, "labels", []) or []:
+                    name = getattr(lbl, "name", None) or getattr(lbl, "key", None)
+                    value = getattr(lbl, "value", None)
+                    if name == "tool" and value:
+                        tool_labels.add(value)
+
+            # Look for both web_search and some form of knowledge search
+            if ("web_search" in tool_labels) and ("rag" in tool_labels or "knowledge_search" in tool_labels):
+                break
+            await asyncio.sleep(0.1)
+
+        assert tool_labels & {"web_search", "rag", "knowledge_search"}, (
+            f"Expected tool calls not found. Got: {tool_labels}"
+        )
diff --git a/tests/unit/providers/agents/__init__.py b/tests/unit/providers/agents/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/tests/unit/providers/agents/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/tests/unit/providers/agents/test_agent_metrics.py b/tests/unit/providers/agents/test_agent_metrics.py
new file mode 100644
index 000000000..acaec837d
--- /dev/null
+++ b/tests/unit/providers/agents/test_agent_metrics.py
@@ -0,0 +1,212 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+from unittest.mock import AsyncMock, Mock
+
+import pytest
+from opentelemetry.trace import SpanContext, TraceFlags
+
+from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
+
+
+class FakeSpan:
+    def __init__(self, trace_id: int = 123, span_id: int = 456):
+        self._context = SpanContext(
+            trace_id=trace_id,
+            span_id=span_id,
+            is_remote=False,
+            trace_flags=TraceFlags(0x01),
+        )
+
+    def get_span_context(self):
+        return self._context
+
+
+@pytest.fixture
+def agent_with_telemetry():
+    """Create a real ChatAgent with a telemetry API"""
+    telemetry_api = AsyncMock()
+
+    agent = ChatAgent(
+        agent_id="test-agent",
+        agent_config=Mock(),
+        inference_api=Mock(),
+        safety_api=Mock(),
+        tool_runtime_api=Mock(),
+        tool_groups_api=Mock(),
+        vector_io_api=Mock(),
+        telemetry_api=telemetry_api,
+        persistence_store=Mock(),
+        created_at="2025-01-01T00:00:00Z",
+        policy=[],
+    )
+    return agent
+
+
+@pytest.fixture
+def agent_without_telemetry():
+    """Create a real ChatAgent without a telemetry API"""
+    agent = ChatAgent(
+        agent_id="test-agent",
+        agent_config=Mock(),
+        inference_api=Mock(),
+        safety_api=Mock(),
+        tool_runtime_api=Mock(),
+        tool_groups_api=Mock(),
+        vector_io_api=Mock(),
+        telemetry_api=None,
+        persistence_store=Mock(),
+        created_at="2025-01-01T00:00:00Z",
+        policy=[],
+    )
+    return agent
+
+
+class TestAgentMetrics:
+    def test_step_execution_metrics(self, agent_with_telemetry, monkeypatch):
+        """Test that step execution metrics are emitted correctly"""
+        fake_span = FakeSpan()
+        monkeypatch.setattr(
+            "llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span", lambda: fake_span
+        )
+
+        # Capture the metric instead of actually creating an async task
+        captured_metrics = []
+
+        async def capture_metric(metric):
+            captured_metrics.append(metric)
+
+        monkeypatch.setattr(agent_with_telemetry.telemetry_api, "log_event", capture_metric)
+
+        def mock_create_task(coro, *, name=None):
+            return asyncio.run(coro)
+
+        monkeypatch.setattr(
+            "llama_stack.providers.inline.agents.meta_reference.agent_instance.asyncio.create_task", mock_create_task
+        )
+
+        agent_with_telemetry._track_step()
+
+        assert len(captured_metrics) == 1
+        metric = captured_metrics[0]
+        assert metric.metric == "llama_stack_agent_steps_total"
+        assert metric.value == 1
+        assert metric.unit == "1"
+        assert metric.attributes["agent_id"] == "test-agent"
+
+    def test_workflow_completion_metrics(self, agent_with_telemetry, monkeypatch):
+        """Test that workflow completion metrics are emitted correctly"""
+        fake_span = FakeSpan()
+        monkeypatch.setattr(
+            "llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span", lambda: fake_span
+        )
+
+        captured_metrics = []
+
+        async def capture_metric(metric):
+            captured_metrics.append(metric)
+
+        monkeypatch.setattr(agent_with_telemetry.telemetry_api, "log_event", capture_metric)
+
+        def mock_create_task(coro, *, name=None):
+            return asyncio.run(coro)
+
+        monkeypatch.setattr(
+            "llama_stack.providers.inline.agents.meta_reference.agent_instance.asyncio.create_task", mock_create_task
+        )
+
+        agent_with_telemetry._track_workflow("completed", 2.5)
+
+        assert len(captured_metrics) == 2
+
+        # Check the workflow count metric
+        count_metric = captured_metrics[0]
+        assert count_metric.metric == "llama_stack_agent_workflows_total"
+        assert count_metric.value == 1
+        assert count_metric.attributes["status"] == "completed"
+
+        # Check the duration metric
+        duration_metric = captured_metrics[1]
+        assert duration_metric.metric == "llama_stack_agent_workflow_duration_seconds"
== "llama_stack_agent_workflow_duration_seconds" + assert duration_metric.value == 2.5 + assert duration_metric.unit == "s" + + def test_tool_usage_metrics(self, agent_with_telemetry, monkeypatch): + """Test that tool usage metrics are emitted correctly""" + fake_span = FakeSpan() + monkeypatch.setattr( + "llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span", lambda: fake_span + ) + + captured_metrics = [] + + async def capture_metric(metric): + captured_metrics.append(metric) + + monkeypatch.setattr(agent_with_telemetry.telemetry_api, "log_event", capture_metric) + + def mock_create_task(coro, *, name=None): + return asyncio.run(coro) + + monkeypatch.setattr( + "llama_stack.providers.inline.agents.meta_reference.agent_instance.asyncio.create_task", mock_create_task + ) + + agent_with_telemetry._track_tool("web_search") + + assert len(captured_metrics) == 1 + metric = captured_metrics[0] + assert metric.metric == "llama_stack_agent_tool_calls_total" + assert metric.attributes["tool"] == "web_search" + + def test_knowledge_search_tool_mapping(self, agent_with_telemetry, monkeypatch): + """Test that knowledge_search tool is mapped to rag""" + fake_span = FakeSpan() + monkeypatch.setattr( + "llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span", lambda: fake_span + ) + + captured_metrics = [] + + async def capture_metric(metric): + captured_metrics.append(metric) + + monkeypatch.setattr(agent_with_telemetry.telemetry_api, "log_event", capture_metric) + + def mock_create_task(coro, *, name=None): + return asyncio.run(coro) + + monkeypatch.setattr( + "llama_stack.providers.inline.agents.meta_reference.agent_instance.asyncio.create_task", mock_create_task + ) + + agent_with_telemetry._track_tool("knowledge_search") + + assert len(captured_metrics) == 1 + metric = captured_metrics[0] + assert metric.attributes["tool"] == "rag" + + def test_no_telemetry_api(self, agent_without_telemetry): + """Test that methods work gracefully when telemetry_api is None""" + # These should not crash + agent_without_telemetry._track_step() + agent_without_telemetry._track_workflow("failed", 1.0) + agent_without_telemetry._track_tool("web_search") + + def test_no_active_span(self, agent_with_telemetry, monkeypatch): + """Test that methods work gracefully when no span is active""" + monkeypatch.setattr( + "llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span", lambda: None + ) + + # These should not crash and should not call telemetry + agent_with_telemetry._track_step() + agent_with_telemetry._track_workflow("failed", 1.0) + agent_with_telemetry._track_tool("web_search") + + # Telemetry should not have been called + agent_with_telemetry.telemetry_api.log_event.assert_not_called() diff --git a/tests/unit/providers/telemetry/__init__.py b/tests/unit/providers/telemetry/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/unit/providers/telemetry/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/unit/providers/telemetry/test_agent_metrics_histogram.py b/tests/unit/providers/telemetry/test_agent_metrics_histogram.py new file mode 100644 index 000000000..d0ee196f0 --- /dev/null +++ b/tests/unit/providers/telemetry/test_agent_metrics_histogram.py @@ -0,0 +1,244 @@ +# Copyright (c) Meta Platforms, Inc. 
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from unittest.mock import Mock
+
+import pytest
+
+from llama_stack.apis.telemetry import MetricEvent, MetricType
+from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
+from llama_stack.providers.inline.telemetry.meta_reference.telemetry import TelemetryAdapter
+
+
+class TestAgentMetricsHistogram:
+    """Tests for agent histogram metrics"""
+
+    @pytest.fixture
+    def config(self):
+        return TelemetryConfig(service_name="test-service", sinks=[])
+
+    @pytest.fixture
+    def adapter(self, config):
+        adapter = TelemetryAdapter(config, {})
+        adapter.meter = Mock()  # skip otel setup
+        return adapter
+
+    def test_histogram_creation(self, adapter):
+        mock_hist = Mock()
+        adapter.meter.create_histogram.return_value = mock_hist
+
+        from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
+
+        _GLOBAL_STORAGE["histograms"] = {}
+
+        result = adapter._get_or_create_histogram("test_histogram", "s")
+
+        assert result == mock_hist
+        adapter.meter.create_histogram.assert_called_once_with(
+            name="test_histogram",
+            unit="s",
+            description="test histogram",
+        )
+        assert _GLOBAL_STORAGE["histograms"]["test_histogram"] == mock_hist
+
+    def test_histogram_reuse(self, adapter):
+        mock_hist = Mock()
+        from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
+
+        _GLOBAL_STORAGE["histograms"] = {"existing_histogram": mock_hist}
+
+        result = adapter._get_or_create_histogram("existing_histogram", "ms")
+
+        assert result == mock_hist
+        adapter.meter.create_histogram.assert_not_called()
+
+    def test_workflow_duration_histogram(self, adapter):
+        mock_hist = Mock()
+        adapter.meter.create_histogram.return_value = mock_hist
+
+        from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
+
+        _GLOBAL_STORAGE["histograms"] = {}
+
+        event = MetricEvent(
+            trace_id="123",
+            span_id="456",
+            metric="llama_stack_agent_workflow_duration_seconds",
+            value=15.7,
+            timestamp=1234567890.0,
+            unit="s",
+            attributes={"agent_id": "test-agent"},
+            metric_type=MetricType.HISTOGRAM,
+        )
+
+        adapter._log_metric(event)
+
+        adapter.meter.create_histogram.assert_called_once_with(
+            name="llama_stack_agent_workflow_duration_seconds",
+            unit="s",
+            description="llama stack agent workflow duration seconds",
+        )
+        mock_hist.record.assert_called_once_with(15.7, attributes={"agent_id": "test-agent"})
+
+    def test_duration_buckets_configured_via_views(self, adapter):
+        mock_hist = Mock()
+        adapter.meter.create_histogram.return_value = mock_hist
+
+        from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
+
+        _GLOBAL_STORAGE["histograms"] = {}
+
+        event = MetricEvent(
+            trace_id="123",
+            span_id="456",
+            metric="custom_duration_seconds",
+            value=5.2,
+            timestamp=1234567890.0,
+            unit="s",
+            attributes={},
+            metric_type=MetricType.HISTOGRAM,
+        )
+
+        adapter._log_metric(event)
+
+        # buckets are configured via otel views, not passed to create_histogram
+        mock_hist.record.assert_called_once_with(5.2, attributes={})
+
+    def test_non_duration_uses_counter(self, adapter):
+        mock_counter = Mock()
+        adapter.meter.create_counter.return_value = mock_counter
+
+        from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE
+
+        _GLOBAL_STORAGE["counters"] = {}
+
+        event = MetricEvent(
+            trace_id="123",
span_id="456", + metric="llama_stack_agent_workflows_total", + value=1, + timestamp=1234567890.0, + unit="1", + attributes={"agent_id": "test-agent", "status": "completed"}, + ) + + adapter._log_metric(event) + + adapter.meter.create_counter.assert_called_once() + adapter.meter.create_histogram.assert_not_called() + mock_counter.add.assert_called_once_with(1, attributes={"agent_id": "test-agent", "status": "completed"}) + + def test_no_meter_doesnt_crash(self, adapter): + adapter.meter = None + + event = MetricEvent( + trace_id="123", + span_id="456", + metric="test_duration_seconds", + value=1.0, + timestamp=1234567890.0, + unit="s", + attributes={}, + ) + + adapter._log_metric(event) # shouldn't crash + + def test_histogram_vs_counter_by_type(self, adapter): + mock_hist = Mock() + mock_counter = Mock() + adapter.meter.create_histogram.return_value = mock_hist + adapter.meter.create_counter.return_value = mock_counter + + from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE + + _GLOBAL_STORAGE["histograms"] = {} + _GLOBAL_STORAGE["counters"] = {} + + # histogram metric + hist_event = MetricEvent( + trace_id="123", + span_id="456", + metric="workflow_duration_seconds", + value=1.0, + timestamp=1234567890.0, + unit="s", + attributes={}, + metric_type=MetricType.HISTOGRAM, + ) + adapter._log_metric(hist_event) + mock_hist.record.assert_called() + + # counter metric (default type) + counter_event = MetricEvent( + trace_id="123", + span_id="456", + metric="workflow_total", + value=1, + timestamp=1234567890.0, + unit="1", + attributes={}, + ) + adapter._log_metric(counter_event) + mock_counter.add.assert_called() + + def test_storage_separation(self, adapter): + mock_hist = Mock() + mock_counter = Mock() + adapter.meter.create_histogram.return_value = mock_hist + adapter.meter.create_counter.return_value = mock_counter + + from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE + + _GLOBAL_STORAGE["histograms"] = {} + _GLOBAL_STORAGE["counters"] = {} + + # create both types + hist_event = MetricEvent( + trace_id="123", + span_id="456", + metric="test_duration_seconds", + value=1.0, + timestamp=1234567890.0, + unit="s", + attributes={}, + metric_type=MetricType.HISTOGRAM, + ) + counter_event = MetricEvent( + trace_id="123", + span_id="456", + metric="test_counter", + value=1, + timestamp=1234567890.0, + unit="1", + attributes={}, + ) + + adapter._log_metric(hist_event) + adapter._log_metric(counter_event) + + # check they're stored separately + assert "test_duration_seconds" in _GLOBAL_STORAGE["histograms"] + assert "test_counter" in _GLOBAL_STORAGE["counters"] + assert "test_duration_seconds" not in _GLOBAL_STORAGE["counters"] + assert "test_counter" not in _GLOBAL_STORAGE["histograms"] + + def test_histogram_uses_views_for_buckets(self, adapter): + mock_hist = Mock() + adapter.meter.create_histogram.return_value = mock_hist + + from llama_stack.providers.inline.telemetry.meta_reference.telemetry import _GLOBAL_STORAGE + + _GLOBAL_STORAGE["histograms"] = {} + + result = adapter._get_or_create_histogram("test_histogram", "s") + + # buckets come from otel views, not create_histogram params + adapter.meter.create_histogram.assert_called_once_with( + name="test_histogram", + unit="s", + description="test histogram", + ) + assert result == mock_hist