fix(metrics): capture token metrics using auto instrumentation

This commit is contained in:
Emilio Garcia 2025-11-10 15:53:53 -05:00
parent 5ea1be69fe
commit 153e21bc21
6 changed files with 6 additions and 118 deletions

View file

@ -392,8 +392,6 @@ async def instantiate_provider(
args = [config, deps]
if "policy" in inspect.signature(getattr(module, method)).parameters:
args.append(policy)
if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
args.append(run_config.telemetry.enabled)
fn = getattr(module, method)
impl = await fn(*args)

View file

@ -85,7 +85,6 @@ async def get_auto_router_impl(
)
await inference_store.initialize()
api_to_dep_impl["store"] = inference_store
api_to_dep_impl["telemetry_enabled"] = run_config.telemetry.enabled
elif api == Api.vector_io:
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores

View file

@ -7,7 +7,6 @@
import asyncio
import time
from collections.abc import AsyncIterator
from datetime import UTC, datetime
from typing import Annotated, Any
from fastapi import Body
@ -15,11 +14,7 @@ from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatC
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import TypeAdapter
from llama_stack.core.telemetry.telemetry import MetricEvent
from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack_api import (
HealthResponse,
@ -49,6 +44,12 @@ from llama_stack_api import (
RerankResponse,
RoutingTable,
)
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import TypeAdapter
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.inference_store import InferenceStore
logger = get_logger(name=__name__, category="core::routers")
@ -60,15 +61,10 @@ class InferenceRouter(Inference):
self,
routing_table: RoutingTable,
store: InferenceStore | None = None,
telemetry_enabled: bool = False,
) -> None:
logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table
self.telemetry_enabled = telemetry_enabled
self.store = store
if self.telemetry_enabled:
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
async def initialize(self) -> None:
logger.debug("InferenceRouter.initialize")
@ -94,54 +90,6 @@ class InferenceRouter(Inference):
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
def _construct_metrics(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
fully_qualified_model_id: str,
provider_id: str,
) -> list[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics.
Args:
prompt_tokens: Number of tokens in the prompt
completion_tokens: Number of tokens in the completion
total_tokens: Total number of tokens used
fully_qualified_model_id:
provider_id: The provider identifier
Returns:
List of MetricEvent objects with token usage metrics
"""
span = get_current_span()
if span is None:
logger.warning("No span found for token usage metrics")
return []
metrics = [
("prompt_tokens", prompt_tokens),
("completion_tokens", completion_tokens),
("total_tokens", total_tokens),
]
metric_events = []
for metric_name, value in metrics:
metric_events.append(
MetricEvent(
trace_id=span.trace_id,
span_id=span.span_id,
metric=metric_name,
value=value,
timestamp=datetime.now(UTC),
unit="tokens",
attributes={
"model_id": fully_qualified_model_id,
"provider_id": provider_id,
},
)
)
return metric_events
async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
model = await self.routing_table.get_object_by_identifier("model", model_id)
if model:
@ -186,26 +134,9 @@ class InferenceRouter(Inference):
if params.stream:
return await provider.openai_completion(params)
# TODO: Metrics do NOT work with openai_completion stream=True due to the fact
# that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
response = await provider.openai_completion(params)
response.model = request_model_id
if self.telemetry_enabled and response.usage is not None:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
fully_qualified_model_id=request_model_id,
provider_id=provider.__provider_id__,
)
for metric in metrics:
enqueue_event(metric)
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
)
return response
async def openai_chat_completion(
@ -254,20 +185,6 @@ class InferenceRouter(Inference):
if self.store:
asyncio.create_task(self.store.store_chat_completion(response, params.messages))
if self.telemetry_enabled and response.usage is not None:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
fully_qualified_model_id=request_model_id,
provider_id=provider.__provider_id__,
)
for metric in metrics:
enqueue_event(metric)
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
)
return response
async def openai_embeddings(
@ -411,18 +328,6 @@ class InferenceRouter(Inference):
for choice_data in choices_data.values():
completion_text += "".join(choice_data["content_parts"])
# Add metrics to the chunk
if self.telemetry_enabled and hasattr(chunk, "usage") and chunk.usage:
metrics = self._construct_metrics(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
fully_qualified_model_id=fully_qualified_model_id,
provider_id=provider_id,
)
for metric in metrics:
enqueue_event(metric)
yield chunk
finally:
# Store the final assembled completion

View file

@ -490,16 +490,6 @@ class Telemetry:
)
return cast(metrics.Counter, _GLOBAL_STORAGE["counters"][name])
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["gauges"]:
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
name=name,
unit=unit,
description=f"Gauge for {name}",
)
return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["histograms"]:

View file

@ -15,7 +15,6 @@ async def get_provider_impl(
config: MetaReferenceAgentsImplConfig,
deps: dict[Api, Any],
policy: list[AccessRule],
telemetry_enabled: bool = False,
):
from .agents import MetaReferenceAgentsImpl
@ -28,7 +27,6 @@ async def get_provider_impl(
deps[Api.tool_groups],
deps[Api.conversations],
policy,
telemetry_enabled,
)
await impl.initialize()
return impl

View file

@ -46,7 +46,6 @@ class MetaReferenceAgentsImpl(Agents):
tool_groups_api: ToolGroups,
conversations_api: Conversations,
policy: list[AccessRule],
telemetry_enabled: bool = False,
):
self.config = config
self.inference_api = inference_api
@ -55,7 +54,6 @@ class MetaReferenceAgentsImpl(Agents):
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.conversations_api = conversations_api
self.telemetry_enabled = telemetry_enabled
self.in_memory_store = InmemoryKVStoreImpl()
self.openai_responses_impl: OpenAIResponsesImpl | None = None