fix(metrics): capture token metrics using auto instrumentation

Emilio Garcia 2025-11-10 15:53:53 -05:00
parent aa2a7dae07
commit f5455a64d1
6 changed files with 6 additions and 118 deletions
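The diff below deletes the hand-rolled MetricEvent pipeline in favor of library-level auto instrumentation. As a hedged sketch of that pattern (the diff does not show which instrumentor llama-stack wires up; the opentelemetry-instrumentation-openai-v2 contrib package is assumed here purely for illustration):

    # Hedged sketch, not code added by this commit: enable OpenTelemetry
    # auto instrumentation so the patched openai client records token usage.
    # Assumes the opentelemetry-instrumentation-openai-v2 contrib package.
    from opentelemetry.instrumentation.openai_v2 import OpenAIInstrumentor

    # After this call, each openai client request emits spans and token-usage
    # metrics (GenAI semantic conventions) with no manual MetricEvent code.
    OpenAIInstrumentor().instrument()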


@@ -392,8 +392,6 @@ async def instantiate_provider(
     args = [config, deps]
     if "policy" in inspect.signature(getattr(module, method)).parameters:
         args.append(policy)
-    if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
-        args.append(run_config.telemetry.enabled)
     fn = getattr(module, method)
     impl = await fn(*args)


@@ -85,7 +85,6 @@ async def get_auto_router_impl(
         )
         await inference_store.initialize()
         api_to_dep_impl["store"] = inference_store
-        api_to_dep_impl["telemetry_enabled"] = run_config.telemetry.enabled
     elif api == Api.vector_io:
         api_to_dep_impl["vector_stores_config"] = run_config.vector_stores


@@ -7,7 +7,6 @@
 import asyncio
 import time
 from collections.abc import AsyncIterator
-from datetime import UTC, datetime
 from typing import Annotated, Any

 from fastapi import Body
@@ -15,11 +14,7 @@ from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
 from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
 from pydantic import TypeAdapter

-from llama_stack.core.telemetry.telemetry import MetricEvent
-from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
 from llama_stack.log import get_logger
-from llama_stack.models.llama.llama3.chat_format import ChatFormat
-from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
 from llama_stack_api import (
     HealthResponse,
@@ -49,6 +44,12 @@ from llama_stack_api import (
     RerankResponse,
     RoutingTable,
 )
+
+from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
+from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
+from pydantic import TypeAdapter
+from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.inference_store import InferenceStore

 logger = get_logger(name=__name__, category="core::routers")
@@ -60,15 +61,10 @@ class InferenceRouter(Inference):
         self,
         routing_table: RoutingTable,
         store: InferenceStore | None = None,
-        telemetry_enabled: bool = False,
     ) -> None:
         logger.debug("Initializing InferenceRouter")
         self.routing_table = routing_table
-        self.telemetry_enabled = telemetry_enabled
         self.store = store
-        if self.telemetry_enabled:
-            self.tokenizer = Tokenizer.get_instance()
-            self.formatter = ChatFormat(self.tokenizer)

     async def initialize(self) -> None:
         logger.debug("InferenceRouter.initialize")
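With the router's own Tokenizer/ChatFormat gone, token counts come from the provider's response rather than being recomputed locally. A hedged illustration with a plain openai client (assumed for illustration; not code from this commit):

    # Hedged illustration: OpenAI-style responses already carry token usage,
    # so the router needs no tokenizer of its own to produce these numbers.
    from openai import OpenAI

    client = OpenAI()  # assumes OPENAI_API_KEY in the environment
    resp = client.chat.completions.create(
        model="gpt-4o-mini",  # illustrative model name
        messages=[{"role": "user", "content": "hello"}],
    )
    print(resp.usage.prompt_tokens, resp.usage.completion_tokens, resp.usage.total_tokens)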
@@ -94,54 +90,6 @@ class InferenceRouter(Inference):
             )
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

-    def _construct_metrics(
-        self,
-        prompt_tokens: int,
-        completion_tokens: int,
-        total_tokens: int,
-        fully_qualified_model_id: str,
-        provider_id: str,
-    ) -> list[MetricEvent]:
-        """Constructs a list of MetricEvent objects containing token usage metrics.
-
-        Args:
-            prompt_tokens: Number of tokens in the prompt
-            completion_tokens: Number of tokens in the completion
-            total_tokens: Total number of tokens used
-            fully_qualified_model_id: The fully qualified model identifier
-            provider_id: The provider identifier
-
-        Returns:
-            List of MetricEvent objects with token usage metrics
-        """
-        span = get_current_span()
-        if span is None:
-            logger.warning("No span found for token usage metrics")
-            return []
-
-        metrics = [
-            ("prompt_tokens", prompt_tokens),
-            ("completion_tokens", completion_tokens),
-            ("total_tokens", total_tokens),
-        ]
-        metric_events = []
-        for metric_name, value in metrics:
-            metric_events.append(
-                MetricEvent(
-                    trace_id=span.trace_id,
-                    span_id=span.span_id,
-                    metric=metric_name,
-                    value=value,
-                    timestamp=datetime.now(UTC),
-                    unit="tokens",
-                    attributes={
-                        "model_id": fully_qualified_model_id,
-                        "provider_id": provider_id,
-                    },
-                )
-            )
-        return metric_events
-
     async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
         model = await self.routing_table.get_object_by_identifier("model", model_id)
         if model:
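For contrast with the deleted helper above, roughly the same measurement expressed through the plain OpenTelemetry metrics API that auto instrumentation now drives. A minimal sketch; metric and attribute names are illustrative:

    # Hedged sketch of the replaced behavior: one measurement per token metric,
    # attributed to model and provider instead of a hand-built MetricEvent.
    from opentelemetry import metrics

    meter = metrics.get_meter(__name__)
    prompt_tokens = meter.create_counter("prompt_tokens", unit="tokens")
    prompt_tokens.add(
        42,  # illustrative value; normally response.usage.prompt_tokens
        attributes={"model_id": "example/model", "provider_id": "example-provider"},
    )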
@@ -186,26 +134,9 @@
         if params.stream:
             return await provider.openai_completion(params)

-        # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
-        # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
         response = await provider.openai_completion(params)
         response.model = request_model_id
-        if self.telemetry_enabled and response.usage is not None:
-            metrics = self._construct_metrics(
-                prompt_tokens=response.usage.prompt_tokens,
-                completion_tokens=response.usage.completion_tokens,
-                total_tokens=response.usage.total_tokens,
-                fully_qualified_model_id=request_model_id,
-                provider_id=provider.__provider_id__,
-            )
-            for metric in metrics:
-                enqueue_event(metric)
-            # these metrics will show up in the client response.
-            response.metrics = (
-                metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
-            )
         return response

     async def openai_chat_completion(
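The removed TODO above notes that streamed responses could not be intercepted for metrics because the router did not wrap the provider's AsyncIterator. A hedged sketch of what such a wrapper would look like (illustrative only; with auto instrumentation the patched client observes usage itself):

    # Hedged sketch: wrap the provider stream in our own async generator so
    # the final usage-bearing chunk can be observed before reaching the caller.
    from collections.abc import AsyncIterator
    from typing import Any

    async def watch_usage(stream: AsyncIterator[Any]) -> AsyncIterator[Any]:
        async for chunk in stream:
            usage = getattr(chunk, "usage", None)
            if usage is not None:
                # Observation point; a real implementation would record metrics here.
                print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)
            yield chunk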
@@ -254,20 +185,6 @@
         if self.store:
             asyncio.create_task(self.store.store_chat_completion(response, params.messages))

-        if self.telemetry_enabled and response.usage is not None:
-            metrics = self._construct_metrics(
-                prompt_tokens=response.usage.prompt_tokens,
-                completion_tokens=response.usage.completion_tokens,
-                total_tokens=response.usage.total_tokens,
-                fully_qualified_model_id=request_model_id,
-                provider_id=provider.__provider_id__,
-            )
-            for metric in metrics:
-                enqueue_event(metric)
-            # these metrics will show up in the client response.
-            response.metrics = (
-                metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
-            )
         return response
     async def openai_embeddings(

@@ -411,18 +328,6 @@
                 for choice_data in choices_data.values():
                     completion_text += "".join(choice_data["content_parts"])

-                # Add metrics to the chunk
-                if self.telemetry_enabled and hasattr(chunk, "usage") and chunk.usage:
-                    metrics = self._construct_metrics(
-                        prompt_tokens=chunk.usage.prompt_tokens,
-                        completion_tokens=chunk.usage.completion_tokens,
-                        total_tokens=chunk.usage.total_tokens,
-                        fully_qualified_model_id=fully_qualified_model_id,
-                        provider_id=provider_id,
-                    )
-                    for metric in metrics:
-                        enqueue_event(metric)
                 yield chunk
         finally:
             # Store the final assembled completion


@@ -490,16 +490,6 @@ class Telemetry:
             )
         return cast(metrics.Counter, _GLOBAL_STORAGE["counters"][name])

-    def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
-        assert self.meter is not None
-        if name not in _GLOBAL_STORAGE["gauges"]:
-            _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
-                name=name,
-                unit=unit,
-                description=f"Gauge for {name}",
-            )
-        return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
-
     def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
         assert self.meter is not None
         if name not in _GLOBAL_STORAGE["histograms"]:
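The gauge variant was unused; the counter and histogram helpers survive. A self-contained sketch of the same get-or-create caching pattern outside the Telemetry class, assuming only the OpenTelemetry metrics API:

    # Hedged sketch of the surviving pattern: instruments are created once,
    # cached by name, and reused for every measurement.
    from opentelemetry import metrics

    _STORAGE: dict[str, metrics.Histogram] = {}
    meter = metrics.get_meter(__name__)

    def get_or_create_histogram(name: str, unit: str) -> metrics.Histogram:
        if name not in _STORAGE:
            _STORAGE[name] = meter.create_histogram(
                name=name, unit=unit, description=f"Histogram for {name}"
            )
        return _STORAGE[name]

    get_or_create_histogram("total_tokens", "tokens").record(170)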


@@ -15,7 +15,6 @@ async def get_provider_impl(
     config: MetaReferenceAgentsImplConfig,
     deps: dict[Api, Any],
     policy: list[AccessRule],
-    telemetry_enabled: bool = False,
 ):
     from .agents import MetaReferenceAgentsImpl

@@ -29,7 +28,6 @@
         deps[Api.conversations],
         deps[Api.prompts],
         deps[Api.files],
-        telemetry_enabled,
         policy,
     )
     await impl.initialize()


@@ -50,7 +50,6 @@ class MetaReferenceAgentsImpl(Agents):
         prompts_api: Prompts,
         files_api: Files,
         policy: list[AccessRule],
-        telemetry_enabled: bool = False,
     ):
         self.config = config
         self.inference_api = inference_api
@@ -59,7 +58,6 @@ class MetaReferenceAgentsImpl(Agents):
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
         self.conversations_api = conversations_api
-        self.telemetry_enabled = telemetry_enabled
         self.prompts_api = prompts_api
         self.files_api = files_api
         self.in_memory_store = InmemoryKVStoreImpl()