rework(telemetry): remove legacy telemetry api

Emilio Garcia 2025-10-07 17:25:54 -04:00
parent 0c843ec87f
commit 8e29d0eb79
13 changed files with 397 additions and 1645 deletions


@@ -7,7 +7,6 @@
import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from datetime import UTC, datetime
from typing import Annotated, Any
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
@@ -45,33 +44,53 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.apis.telemetry import MetricInResponse
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span
from opentelemetry import trace, metrics
logger = get_logger(name=__name__, category="core::routers")
class InferenceRouterTelemetry:
"""Telemetry for InferenceRouter"""
def __init__(self):
meter = metrics.get_meter(__name__)
self.prompt_tokens = meter.create_counter(
"prompt_tokens",
unit="tokens",
description="Number of tokens in the prompt",
)
self.completion_tokens = meter.create_counter(
"completion_tokens",
unit="tokens",
description="Number of tokens in the completion",
)
self.total_tokens = meter.create_counter(
"total_tokens",
unit="tokens",
description="Total number of tokens used",
)
class InferenceRouter(Inference):
"""Routes to an provider based on the model"""
def __init__(
self,
routing_table: RoutingTable,
telemetry: Telemetry | None = None,
store: InferenceStore | None = None,
) -> None:
logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table
self.telemetry = telemetry
self.store = store
if self.telemetry:
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
self.telemetry = InferenceRouterTelemetry()
async def initialize(self) -> None:
logger.debug("InferenceRouter.initialize")
@@ -97,64 +116,25 @@ class InferenceRouter(Inference):
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
def _construct_metrics(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> list[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics.
Args:
prompt_tokens: Number of tokens in the prompt
completion_tokens: Number of tokens in the completion
total_tokens: Total number of tokens used
model: Model object containing model_id and provider_id
Returns:
List of MetricEvent objects with token usage metrics
"""
span = get_current_span()
if span is None:
logger.warning("No span found for token usage metrics")
return []
metrics = [
("prompt_tokens", prompt_tokens),
("completion_tokens", completion_tokens),
("total_tokens", total_tokens),
]
metric_events = []
for metric_name, value in metrics:
metric_events.append(
MetricEvent(
trace_id=span.trace_id,
span_id=span.span_id,
metric=metric_name,
value=value,
timestamp=datetime.now(UTC),
unit="tokens",
attributes={
"model_id": model.model_id,
"provider_id": model.provider_id,
},
)
)
return metric_events
async def _compute_and_log_token_usage(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> list[MetricInResponse]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
def _record_metrics(
self, prompt_tokens: int | None, completion_tokens: int | None, total_tokens: int | None, model: Model
):
if self.telemetry:
for metric in metrics:
enqueue_event(metric)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
span = trace.get_current_span()
span_context = span.get_span_context()
attributes = {
"model_id": model.model_id,
"provider_id": model.provider_id,
"trace_id": span_context.trace_id,
"span_id": span_context.span_id,
}
if prompt_tokens:
self.telemetry.prompt_tokens.add(prompt_tokens, attributes)
if completion_tokens:
self.telemetry.completion_tokens.add(completion_tokens, attributes)
if total_tokens:
self.telemetry.total_tokens.add(total_tokens, attributes)
async def _count_tokens(
self,
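_record_metrics pulls the ambient span from trace.get_current_span() and attaches its ids, together with the model and provider ids, as metric attributes. A small standalone sketch of that pattern, assuming a span has been started elsewhere (when no span is active, get_current_span() returns a non-recording span whose context carries all-zero ids); the helper name is illustrative, not part of this diff:

from opentelemetry import trace

def token_metric_attributes(model_id: str, provider_id: str) -> dict:
    # Read the ambient span context; ids come back as plain ints.
    ctx = trace.get_current_span().get_span_context()
    return {
        "model_id": model_id,
        "provider_id": provider_id,
        # The router stores the raw ints; hex-formatting them is an
        # optional readability tweak.
        "trace_id": format(ctx.trace_id, "032x"),
        "span_id": format(ctx.span_id, "016x"),
    }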
@@ -237,19 +217,13 @@ class InferenceRouter(Inference):
response = await provider.openai_completion(**params)
if self.telemetry:
metrics = self._construct_metrics(
self._record_metrics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
model=model_obj,
)
for metric in metrics:
enqueue_event(metric)
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
)
return response
async def openai_chat_completion(
@@ -343,18 +317,13 @@ class InferenceRouter(Inference):
asyncio.create_task(self.store.store_chat_completion(response, messages))
if self.telemetry:
metrics = self._construct_metrics(
self._record_metrics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
model=model_obj,
)
for metric in metrics:
enqueue_event(metric)
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
)
return response
async def openai_embeddings(
@@ -466,77 +435,15 @@ class InferenceRouter(Inference):
# Create a separate span for streaming completion metrics
if self.telemetry:
# Log metrics in the new span context
completion_metrics = self._construct_metrics(
prompt_tokens=prompt_tokens,
# Only log completion and total tokens
completion_metrics = self._record_metrics(
prompt_tokens=None,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
model=model,
)
for metric in completion_metrics:
if metric.metric in [
"completion_tokens",
"total_tokens",
]: # Only log completion and total tokens
enqueue_event(metric)
# Return metrics in response
async_metrics = [
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
]
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
else:
# Fallback if no telemetry
completion_metrics = self._construct_metrics(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
async_metrics = [
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
]
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
yield chunk
async def count_tokens_and_compute_metrics(
self,
response: ChatCompletionResponse | CompletionResponse,
prompt_tokens,
model,
tool_prompt_format: ToolPromptFormat | None = None,
):
if isinstance(response, ChatCompletionResponse):
content = [response.completion_message]
else:
content = response.content
completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
# Create a separate span for completion metrics
if self.telemetry:
# Log metrics in the new span context
completion_metrics = self._construct_metrics(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
model=model,
)
for metric in completion_metrics:
if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens
enqueue_event(metric)
# Return metrics in response
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
# Fallback if no telemetry
metrics = self._construct_metrics(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
async def stream_tokens_and_compute_metrics_openai_chat(
self,
response: AsyncIterator[OpenAIChatCompletionChunk],
@@ -612,14 +519,12 @@ class InferenceRouter(Inference):
# Add metrics to the chunk
if self.telemetry and chunk.usage:
metrics = self._construct_metrics(
metrics = self._record_metrics(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
model=model,
)
for metric in metrics:
enqueue_event(metric)
yield chunk
finally:
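One way to check that the new counters fire end to end is the SDK's in-memory reader. A minimal test-style sketch, assuming opentelemetry-sdk; all names other than the counter name are illustrative:

from opentelemetry import metrics
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import InMemoryMetricReader

reader = InMemoryMetricReader()
metrics.set_meter_provider(MeterProvider(metric_readers=[reader]))

meter = metrics.get_meter("router-test")
completion_tokens = meter.create_counter("completion_tokens", unit="tokens")
completion_tokens.add(57, {"model_id": "example-model", "provider_id": "example-provider"})

# Collect whatever has been recorded so far and locate the data point.
data = reader.get_metrics_data()
for resource_metrics in data.resource_metrics:
    for scope_metrics in resource_metrics.scope_metrics:
        for metric in scope_metrics.metrics:
            if metric.name == "completion_tokens":
                point = next(iter(metric.data.data_points))
                assert point.value == 57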