fix: actually propagate inference metrics

currently metrics are only propagated in `chat_completion` and `completion`

since most providers use the openai_.. routes as the default in llama-stack-client inference chat-completion, metrics are currently not working as expected.

in order to get them working the following had to be done:

1. get the completion as usual
2. use new openai_ versions of the metric gathering functions which use .usage from the OpenAI.. response types to gather the metrics which are already populated.
3. define a `stream_generator` which counts the tokens and computes the metrics
4. use a NEW span and log_metrics because the span of the request ends before this processing is complete, leading to no logging unless a custom span is used
5. add metrics to response

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-07-11 20:50:48 -04:00
parent c252dfa3ef
commit d52722b0d1
3 changed files with 335 additions and 217 deletions

View file

@ -7,6 +7,7 @@
import asyncio import asyncio
import time import time
from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import AsyncGenerator, AsyncIterator
from datetime import UTC, datetime
from typing import Annotated, Any from typing import Annotated, Any
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
@ -25,14 +26,21 @@ from llama_stack.apis.inference import (
ChatCompletionResponseEventType, ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk, ChatCompletionResponseStreamChunk,
CompletionMessage, CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse, EmbeddingsResponse,
EmbeddingTaskType, EmbeddingTaskType,
Inference, Inference,
ListOpenAIChatCompletionResponse, ListOpenAIChatCompletionResponse,
LogProbConfig, LogProbConfig,
Message, Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIChatCompletionChunk, OpenAIChatCompletionChunk,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIChoiceLogprobs,
OpenAICompletion, OpenAICompletion,
OpenAICompletionWithInputMessages, OpenAICompletionWithInputMessages,
OpenAIEmbeddingsResponse, OpenAIEmbeddingsResponse,
@ -55,7 +63,6 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
from llama_stack.providers.utils.telemetry.tracing import get_current_span from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core")
@ -119,6 +126,7 @@ class InferenceRouter(Inference):
if span is None: if span is None:
logger.warning("No span found for token usage metrics") logger.warning("No span found for token usage metrics")
return [] return []
metrics = [ metrics = [
("prompt_tokens", prompt_tokens), ("prompt_tokens", prompt_tokens),
("completion_tokens", completion_tokens), ("completion_tokens", completion_tokens),
@ -132,7 +140,7 @@ class InferenceRouter(Inference):
span_id=span.span_id, span_id=span.span_id,
metric=metric_name, metric=metric_name,
value=value, value=value,
timestamp=time.time(), timestamp=datetime.now(UTC),
unit="tokens", unit="tokens",
attributes={ attributes={
"model_id": model.model_id, "model_id": model.model_id,
@ -234,49 +242,26 @@ class InferenceRouter(Inference):
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
if stream: if stream:
response_stream = await provider.chat_completion(**params)
async def stream_generator(): return self.stream_tokens_and_compute_metrics(
completion_text = "" response=response_stream,
async for chunk in await provider.chat_completion(**params): prompt_tokens=prompt_tokens,
if chunk.event.event_type == ChatCompletionResponseEventType.progress: model=model,
if chunk.event.delta.type == "text": tool_prompt_format=tool_config.tool_prompt_format,
completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
completion_tokens = await self._count_tokens(
[
CompletionMessage(
content=completion_text,
stop_reason=StopReason.end_of_turn,
)
],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else:
response = await provider.chat_completion(**params)
completion_tokens = await self._count_tokens(
[response.completion_message],
tool_config.tool_prompt_format,
) )
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage( response = await provider.chat_completion(**params)
prompt_tokens or 0, metrics = await self.count_tokens_and_compute_metrics(
completion_tokens or 0, response=response,
total_tokens, prompt_tokens=prompt_tokens,
model, model=model,
) tool_prompt_format=tool_config.tool_prompt_format,
response.metrics = metrics if response.metrics is None else response.metrics + metrics )
return response # these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
)
return response
async def batch_chat_completion( async def batch_chat_completion(
self, self,
@ -332,39 +317,20 @@ class InferenceRouter(Inference):
) )
prompt_tokens = await self._count_tokens(content) prompt_tokens = await self._count_tokens(content)
response = await provider.completion(**params)
if stream: if stream:
return self.stream_tokens_and_compute_metrics(
async def stream_generator(): response=response,
completion_text = "" prompt_tokens=prompt_tokens,
async for chunk in await provider.completion(**params): model=model,
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
completion_tokens = await self._count_tokens(completion_text)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else:
response = await provider.completion(**params)
completion_tokens = await self._count_tokens(response.content)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
) )
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response metrics = await self.count_tokens_and_compute_metrics(
response=response, prompt_tokens=prompt_tokens, model=model
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def batch_completion( async def batch_completion(
self, self,
@ -457,9 +423,29 @@ class InferenceRouter(Inference):
prompt_logprobs=prompt_logprobs, prompt_logprobs=prompt_logprobs,
suffix=suffix, suffix=suffix,
) )
provider = self.routing_table.get_provider_impl(model_obj.identifier)
if stream:
return await provider.openai_completion(**params)
# TODO: Metrics do NOT work with openai_completion stream=True due to the fact
# that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
# response_stream = await provider.openai_completion(**params)
provider = await self.routing_table.get_provider_impl(model_obj.identifier) response = await provider.openai_completion(**params)
return await provider.openai_completion(**params) if self.telemetry:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
model=model_obj,
)
for metric in metrics:
await self.telemetry.log_event(metric)
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
)
return response
async def openai_chat_completion( async def openai_chat_completion(
self, self,
@ -537,18 +523,38 @@ class InferenceRouter(Inference):
top_p=top_p, top_p=top_p,
user=user, user=user,
) )
provider = self.routing_table.get_provider_impl(model_obj.identifier)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if stream: if stream:
response_stream = await provider.openai_chat_completion(**params) response_stream = await provider.openai_chat_completion(**params)
if self.store:
return stream_and_store_openai_completion(response_stream, model, self.store, messages) # For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
return response_stream # We need to add metrics to each chunk and store the final completion
else: return self.stream_tokens_and_compute_metrics_openai_chat(
response = await self._nonstream_openai_chat_completion(provider, params) response=response_stream,
if self.store: model=model_obj,
await self.store.store_chat_completion(response, messages) messages=messages,
return response )
response = await self._nonstream_openai_chat_completion(provider, params)
# Store the response with the ID that will be returned to the client
if self.store:
await self.store.store_chat_completion(response, messages)
if self.telemetry:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
model=model_obj,
)
for metric in metrics:
await self.telemetry.log_event(metric)
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
)
return response
async def openai_embeddings( async def openai_embeddings(
self, self,
@ -625,3 +631,244 @@ class InferenceRouter(Inference):
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}" status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
) )
return health_statuses return health_statuses
async def stream_tokens_and_compute_metrics(
self,
response,
prompt_tokens,
model,
tool_prompt_format: ToolPromptFormat | None = None,
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
completion_text = ""
async for chunk in response:
complete = False
if hasattr(chunk, "event"): # only ChatCompletions have .event
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
if chunk.event.delta.type == "text":
completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
complete = True
completion_tokens = await self._count_tokens(
[
CompletionMessage(
content=completion_text,
stop_reason=StopReason.end_of_turn,
)
],
tool_prompt_format=tool_prompt_format,
)
else:
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
complete = True
completion_tokens = await self._count_tokens(completion_text)
# if we are done receiving tokens
if complete:
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
# Create a separate span for streaming completion metrics
if self.telemetry:
# Log metrics in the new span context
completion_metrics = self._construct_metrics(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
model=model,
)
for metric in completion_metrics:
if metric.metric in [
"completion_tokens",
"total_tokens",
]: # Only log completion and total tokens
await self.telemetry.log_event(metric)
# Return metrics in response
async_metrics = [
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
]
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
else:
# Fallback if no telemetry
completion_metrics = self._construct_metrics(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
async_metrics = [
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
]
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
yield chunk
async def count_tokens_and_compute_metrics(
self,
response: ChatCompletionResponse | CompletionResponse,
prompt_tokens,
model,
tool_prompt_format: ToolPromptFormat | None = None,
):
if isinstance(response, ChatCompletionResponse):
content = [response.completion_message]
else:
content = response.content
completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
# Create a separate span for completion metrics
if self.telemetry:
# Log metrics in the new span context
completion_metrics = self._construct_metrics(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
model=model,
)
for metric in completion_metrics:
if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens
await self.telemetry.log_event(metric)
# Return metrics in response
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
# Fallback if no telemetry
metrics = self._construct_metrics(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
async def stream_tokens_and_compute_metrics_openai_chat(
self,
response: AsyncIterator[OpenAIChatCompletionChunk],
model: Model,
messages: list[OpenAIMessageParam] | None = None,
) -> AsyncIterator[OpenAIChatCompletionChunk]:
"""Stream OpenAI chat completion chunks, compute metrics, and store the final completion."""
id = None
created = None
choices_data: dict[int, dict[str, Any]] = {}
try:
async for chunk in response:
# Skip None chunks
if chunk is None:
continue
# Capture ID and created timestamp from first chunk
if id is None and chunk.id:
id = chunk.id
if created is None and chunk.created:
created = chunk.created
# Accumulate choice data for final assembly
if chunk.choices:
for choice_delta in chunk.choices:
idx = choice_delta.index
if idx not in choices_data:
choices_data[idx] = {
"content_parts": [],
"tool_calls_builder": {},
"finish_reason": None,
"logprobs_content_parts": [],
}
current_choice_data = choices_data[idx]
if choice_delta.delta:
delta = choice_delta.delta
if delta.content:
current_choice_data["content_parts"].append(delta.content)
if delta.tool_calls:
for tool_call_delta in delta.tool_calls:
tc_idx = tool_call_delta.index
if tc_idx not in current_choice_data["tool_calls_builder"]:
current_choice_data["tool_calls_builder"][tc_idx] = {
"id": None,
"type": "function",
"function_name_parts": [],
"function_arguments_parts": [],
}
builder = current_choice_data["tool_calls_builder"][tc_idx]
if tool_call_delta.id:
builder["id"] = tool_call_delta.id
if tool_call_delta.type:
builder["type"] = tool_call_delta.type
if tool_call_delta.function:
if tool_call_delta.function.name:
builder["function_name_parts"].append(tool_call_delta.function.name)
if tool_call_delta.function.arguments:
builder["function_arguments_parts"].append(
tool_call_delta.function.arguments
)
if choice_delta.finish_reason:
current_choice_data["finish_reason"] = choice_delta.finish_reason
if choice_delta.logprobs and choice_delta.logprobs.content:
current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
# Compute metrics on final chunk
if chunk.choices and chunk.choices[0].finish_reason:
completion_text = ""
for choice_data in choices_data.values():
completion_text += "".join(choice_data["content_parts"])
# Add metrics to the chunk
if self.telemetry and chunk.usage:
metrics = self._construct_metrics(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
model=model,
)
for metric in metrics:
await self.telemetry.log_event(metric)
yield chunk
finally:
# Store the final assembled completion
if id and self.store and messages:
assembled_choices: list[OpenAIChoice] = []
for choice_idx, choice_data in choices_data.items():
content_str = "".join(choice_data["content_parts"])
assembled_tool_calls: list[OpenAIChatCompletionToolCall] = []
if choice_data["tool_calls_builder"]:
for tc_build_data in choice_data["tool_calls_builder"].values():
if tc_build_data["id"]:
func_name = "".join(tc_build_data["function_name_parts"])
func_args = "".join(tc_build_data["function_arguments_parts"])
assembled_tool_calls.append(
OpenAIChatCompletionToolCall(
id=tc_build_data["id"],
type=tc_build_data["type"],
function=OpenAIChatCompletionToolCallFunction(
name=func_name, arguments=func_args
),
)
)
message = OpenAIAssistantMessageParam(
role="assistant",
content=content_str if content_str else None,
tool_calls=assembled_tool_calls if assembled_tool_calls else None,
)
logprobs_content = choice_data["logprobs_content_parts"]
final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
assembled_choices.append(
OpenAIChoice(
finish_reason=choice_data["finish_reason"],
index=choice_idx,
message=message,
logprobs=final_logprobs,
)
)
final_response = OpenAIChatCompletion(
id=id,
choices=assembled_choices,
created=created or int(time.time()),
model=model.identifier,
object="chat.completion",
)
await self.store.store_chat_completion(final_response, messages)

View file

@ -1,129 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from datetime import UTC, datetime
from typing import Any
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIChoiceLogprobs,
OpenAIMessageParam,
)
from llama_stack.providers.utils.inference.inference_store import InferenceStore
async def stream_and_store_openai_completion(
provider_stream: AsyncIterator[OpenAIChatCompletionChunk],
model: str,
store: InferenceStore,
input_messages: list[OpenAIMessageParam],
) -> AsyncIterator[OpenAIChatCompletionChunk]:
"""
Wraps a provider's stream, yields chunks, and stores the full completion at the end.
"""
id = None
created = None
choices_data: dict[int, dict[str, Any]] = {}
try:
async for chunk in provider_stream:
if id is None and chunk.id:
id = chunk.id
if created is None and chunk.created:
created = chunk.created
if chunk.choices:
for choice_delta in chunk.choices:
idx = choice_delta.index
if idx not in choices_data:
choices_data[idx] = {
"content_parts": [],
"tool_calls_builder": {},
"finish_reason": None,
"logprobs_content_parts": [],
}
current_choice_data = choices_data[idx]
if choice_delta.delta:
delta = choice_delta.delta
if delta.content:
current_choice_data["content_parts"].append(delta.content)
if delta.tool_calls:
for tool_call_delta in delta.tool_calls:
tc_idx = tool_call_delta.index
if tc_idx not in current_choice_data["tool_calls_builder"]:
# Initialize with correct structure for _ToolCallBuilderData
current_choice_data["tool_calls_builder"][tc_idx] = {
"id": None,
"type": "function",
"function_name_parts": [],
"function_arguments_parts": [],
}
builder = current_choice_data["tool_calls_builder"][tc_idx]
if tool_call_delta.id:
builder["id"] = tool_call_delta.id
if tool_call_delta.type:
builder["type"] = tool_call_delta.type
if tool_call_delta.function:
if tool_call_delta.function.name:
builder["function_name_parts"].append(tool_call_delta.function.name)
if tool_call_delta.function.arguments:
builder["function_arguments_parts"].append(tool_call_delta.function.arguments)
if choice_delta.finish_reason:
current_choice_data["finish_reason"] = choice_delta.finish_reason
if choice_delta.logprobs and choice_delta.logprobs.content:
# Ensure that we are extending with the correct type
current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
yield chunk
finally:
if id:
assembled_choices: list[OpenAIChoice] = []
for choice_idx, choice_data in choices_data.items():
content_str = "".join(choice_data["content_parts"])
assembled_tool_calls: list[OpenAIChatCompletionToolCall] = []
if choice_data["tool_calls_builder"]:
for tc_build_data in choice_data["tool_calls_builder"].values():
if tc_build_data["id"]:
func_name = "".join(tc_build_data["function_name_parts"])
func_args = "".join(tc_build_data["function_arguments_parts"])
assembled_tool_calls.append(
OpenAIChatCompletionToolCall(
id=tc_build_data["id"],
type=tc_build_data["type"], # No or "function" needed, already set
function=OpenAIChatCompletionToolCallFunction(name=func_name, arguments=func_args),
)
)
message = OpenAIAssistantMessageParam(
role="assistant",
content=content_str if content_str else None,
tool_calls=assembled_tool_calls if assembled_tool_calls else None,
)
logprobs_content = choice_data["logprobs_content_parts"]
final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
assembled_choices.append(
OpenAIChoice(
finish_reason=choice_data["finish_reason"],
index=choice_idx,
message=message,
logprobs=final_logprobs,
)
)
final_response = OpenAIChatCompletion(
id=id,
choices=assembled_choices,
created=created or int(datetime.now(UTC).timestamp()),
model=model,
object="chat.completion",
)
await store.store_chat_completion(final_response, input_messages)

View file

@ -81,7 +81,7 @@ BACKGROUND_LOGGER = None
class BackgroundLogger: class BackgroundLogger:
def __init__(self, api: Telemetry, capacity: int = 1000): def __init__(self, api: Telemetry, capacity: int = 100000):
self.api = api self.api = api
self.log_queue = queue.Queue(maxsize=capacity) self.log_queue = queue.Queue(maxsize=capacity)
self.worker_thread = threading.Thread(target=self._process_logs, daemon=True) self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)