From d52722b0d1ed4258ae06ca8b56d30a6663ca44da Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Fri, 11 Jul 2025 20:50:48 -0400 Subject: [PATCH] fix: actually propagate inference metrics Currently, metrics are only propagated in `chat_completion` and `completion`. Since most providers use the openai_.. routes as the default in `llama-stack-client inference chat-completion`, metrics are currently not working as expected. To get them working, the following had to be done: 1. get the completion as usual 2. use new openai_ versions of the metric-gathering functions, which read `.usage` from the OpenAI.. response types, where the metrics are already populated 3. define a `stream_generator` which counts the tokens and computes the metrics 4. use a NEW span and log_metrics, because the span of the request ends before this processing is complete, leading to no logging unless a custom span is used 5. add metrics to the response Signed-off-by: Charlie Doern --- llama_stack/core/routers/inference.py | 421 ++++++++++++++---- .../providers/utils/inference/stream_utils.py | 129 ------ .../providers/utils/telemetry/tracing.py | 2 +- 3 files changed, 335 insertions(+), 217 deletions(-) delete mode 100644 llama_stack/providers/utils/inference/stream_utils.py diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 6152acd57..55cee6203 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -7,6 +7,7 @@ import asyncio import time from collections.abc import AsyncGenerator, AsyncIterator +from datetime import UTC, datetime from typing import Annotated, Any from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam @@ -25,14 +26,21 @@ from llama_stack.apis.inference import ( ChatCompletionResponseEventType, ChatCompletionResponseStreamChunk, CompletionMessage, + CompletionResponse, + CompletionResponseStreamChunk, EmbeddingsResponse, EmbeddingTaskType, Inference, ListOpenAIChatCompletionResponse, LogProbConfig, Message, + OpenAIAssistantMessageParam, OpenAIChatCompletion, OpenAIChatCompletionChunk, + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + OpenAIChoice, + OpenAIChoiceLogprobs, OpenAICompletion, OpenAICompletionWithInputMessages, OpenAIEmbeddingsResponse, @@ -55,7 +63,6 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable from llama_stack.providers.utils.inference.inference_store import InferenceStore -from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion from llama_stack.providers.utils.telemetry.tracing import get_current_span logger = get_logger(name=__name__, category="core") @@ -119,6 +126,7 @@ class InferenceRouter(Inference): if span is None: logger.warning("No span found for token usage metrics") return [] + metrics = [ ("prompt_tokens", prompt_tokens), ("completion_tokens", completion_tokens), @@ -132,7 +140,7 @@ class InferenceRouter(Inference): span_id=span.span_id, metric=metric_name, value=value, - timestamp=time.time(), + timestamp=datetime.now(UTC), unit="tokens", attributes={ "model_id": model.model_id, @@ -234,49 +242,26 @@ class InferenceRouter(Inference): prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) if stream: - - async def stream_generator(): - completion_text = "" - async for
chunk in await provider.chat_completion(**params): - if chunk.event.event_type == ChatCompletionResponseEventType.progress: - if chunk.event.delta.type == "text": - completion_text += chunk.event.delta.text - if chunk.event.event_type == ChatCompletionResponseEventType.complete: - completion_tokens = await self._count_tokens( - [ - CompletionMessage( - content=completion_text, - stop_reason=StopReason.end_of_turn, - ) - ], - tool_config.tool_prompt_format, - ) - total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) - metrics = await self._compute_and_log_token_usage( - prompt_tokens or 0, - completion_tokens or 0, - total_tokens, - model, - ) - chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics - yield chunk - - return stream_generator() - else: - response = await provider.chat_completion(**params) - completion_tokens = await self._count_tokens( - [response.completion_message], - tool_config.tool_prompt_format, + response_stream = await provider.chat_completion(**params) + return self.stream_tokens_and_compute_metrics( + response=response_stream, + prompt_tokens=prompt_tokens, + model=model, + tool_prompt_format=tool_config.tool_prompt_format, ) - total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) - metrics = await self._compute_and_log_token_usage( - prompt_tokens or 0, - completion_tokens or 0, - total_tokens, - model, - ) - response.metrics = metrics if response.metrics is None else response.metrics + metrics - return response + + response = await provider.chat_completion(**params) + metrics = await self.count_tokens_and_compute_metrics( + response=response, + prompt_tokens=prompt_tokens, + model=model, + tool_prompt_format=tool_config.tool_prompt_format, + ) + # these metrics will show up in the client response. 
+ response.metrics = ( + metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics + ) + return response async def batch_chat_completion( self, @@ -332,39 +317,20 @@ class InferenceRouter(Inference): ) prompt_tokens = await self._count_tokens(content) - + response = await provider.completion(**params) if stream: - - async def stream_generator(): - completion_text = "" - async for chunk in await provider.completion(**params): - if hasattr(chunk, "delta"): - completion_text += chunk.delta - if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: - completion_tokens = await self._count_tokens(completion_text) - total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) - metrics = await self._compute_and_log_token_usage( - prompt_tokens or 0, - completion_tokens or 0, - total_tokens, - model, - ) - chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics - yield chunk - - return stream_generator() - else: - response = await provider.completion(**params) - completion_tokens = await self._count_tokens(response.content) - total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) - metrics = await self._compute_and_log_token_usage( - prompt_tokens or 0, - completion_tokens or 0, - total_tokens, - model, + return self.stream_tokens_and_compute_metrics( + response=response, + prompt_tokens=prompt_tokens, + model=model, ) - response.metrics = metrics if response.metrics is None else response.metrics + metrics - return response + + metrics = await self.count_tokens_and_compute_metrics( + response=response, prompt_tokens=prompt_tokens, model=model + ) + response.metrics = metrics if response.metrics is None else response.metrics + metrics + + return response async def batch_completion( self, @@ -457,9 +423,29 @@ class InferenceRouter(Inference): prompt_logprobs=prompt_logprobs, suffix=suffix, ) + provider = self.routing_table.get_provider_impl(model_obj.identifier) + if stream: + return await provider.openai_completion(**params) + # TODO: Metrics do NOT work with openai_completion stream=True due to the fact + # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently. + # response_stream = await provider.openai_completion(**params) - provider = await self.routing_table.get_provider_impl(model_obj.identifier) - return await provider.openai_completion(**params) + response = await provider.openai_completion(**params) + if self.telemetry: + metrics = self._construct_metrics( + prompt_tokens=response.usage.prompt_tokens, + completion_tokens=response.usage.completion_tokens, + total_tokens=response.usage.total_tokens, + model=model_obj, + ) + for metric in metrics: + await self.telemetry.log_event(metric) + + # these metrics will show up in the client response. 
+ response.metrics = ( + metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics + ) + return response async def openai_chat_completion( self, @@ -537,18 +523,38 @@ class InferenceRouter(Inference): top_p=top_p, user=user, ) - - provider = await self.routing_table.get_provider_impl(model_obj.identifier) + provider = self.routing_table.get_provider_impl(model_obj.identifier) if stream: response_stream = await provider.openai_chat_completion(**params) - if self.store: - return stream_and_store_openai_completion(response_stream, model, self.store, messages) - return response_stream - else: - response = await self._nonstream_openai_chat_completion(provider, params) - if self.store: - await self.store.store_chat_completion(response, messages) - return response + + # For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk] + # We need to add metrics to each chunk and store the final completion + return self.stream_tokens_and_compute_metrics_openai_chat( + response=response_stream, + model=model_obj, + messages=messages, + ) + + response = await self._nonstream_openai_chat_completion(provider, params) + + # Store the response with the ID that will be returned to the client + if self.store: + await self.store.store_chat_completion(response, messages) + + if self.telemetry: + metrics = self._construct_metrics( + prompt_tokens=response.usage.prompt_tokens, + completion_tokens=response.usage.completion_tokens, + total_tokens=response.usage.total_tokens, + model=model_obj, + ) + for metric in metrics: + await self.telemetry.log_event(metric) + # these metrics will show up in the client response. + response.metrics = ( + metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics + ) + return response async def openai_embeddings( self, @@ -625,3 +631,244 @@ class InferenceRouter(Inference): status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}" ) return health_statuses + + async def stream_tokens_and_compute_metrics( + self, + response, + prompt_tokens, + model, + tool_prompt_format: ToolPromptFormat | None = None, + ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]: + completion_text = "" + async for chunk in response: + complete = False + if hasattr(chunk, "event"): # only ChatCompletions have .event + if chunk.event.event_type == ChatCompletionResponseEventType.progress: + if chunk.event.delta.type == "text": + completion_text += chunk.event.delta.text + if chunk.event.event_type == ChatCompletionResponseEventType.complete: + complete = True + completion_tokens = await self._count_tokens( + [ + CompletionMessage( + content=completion_text, + stop_reason=StopReason.end_of_turn, + ) + ], + tool_prompt_format=tool_prompt_format, + ) + else: + if hasattr(chunk, "delta"): + completion_text += chunk.delta + if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: + complete = True + completion_tokens = await self._count_tokens(completion_text) + # if we are done receiving tokens + if complete: + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + + # Create a separate span for streaming completion metrics + if self.telemetry: + # Log metrics in the new span context + completion_metrics = self._construct_metrics( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + model=model, + ) + for metric in completion_metrics: + if metric.metric in [ + 
"completion_tokens", + "total_tokens", + ]: # Only log completion and total tokens + await self.telemetry.log_event(metric) + + # Return metrics in response + async_metrics = [ + MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics + ] + chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics + else: + # Fallback if no telemetry + completion_metrics = self._construct_metrics( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + async_metrics = [ + MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics + ] + chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics + yield chunk + + async def count_tokens_and_compute_metrics( + self, + response: ChatCompletionResponse | CompletionResponse, + prompt_tokens, + model, + tool_prompt_format: ToolPromptFormat | None = None, + ): + if isinstance(response, ChatCompletionResponse): + content = [response.completion_message] + else: + content = response.content + completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format) + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + + # Create a separate span for completion metrics + if self.telemetry: + # Log metrics in the new span context + completion_metrics = self._construct_metrics( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + model=model, + ) + for metric in completion_metrics: + if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens + await self.telemetry.log_event(metric) + + # Return metrics in response + return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics] + + # Fallback if no telemetry + metrics = self._construct_metrics( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] + + async def stream_tokens_and_compute_metrics_openai_chat( + self, + response: AsyncIterator[OpenAIChatCompletionChunk], + model: Model, + messages: list[OpenAIMessageParam] | None = None, + ) -> AsyncIterator[OpenAIChatCompletionChunk]: + """Stream OpenAI chat completion chunks, compute metrics, and store the final completion.""" + id = None + created = None + choices_data: dict[int, dict[str, Any]] = {} + + try: + async for chunk in response: + # Skip None chunks + if chunk is None: + continue + + # Capture ID and created timestamp from first chunk + if id is None and chunk.id: + id = chunk.id + if created is None and chunk.created: + created = chunk.created + + # Accumulate choice data for final assembly + if chunk.choices: + for choice_delta in chunk.choices: + idx = choice_delta.index + if idx not in choices_data: + choices_data[idx] = { + "content_parts": [], + "tool_calls_builder": {}, + "finish_reason": None, + "logprobs_content_parts": [], + } + current_choice_data = choices_data[idx] + + if choice_delta.delta: + delta = choice_delta.delta + if delta.content: + current_choice_data["content_parts"].append(delta.content) + if delta.tool_calls: + for tool_call_delta in delta.tool_calls: + tc_idx = tool_call_delta.index + if tc_idx not in current_choice_data["tool_calls_builder"]: + current_choice_data["tool_calls_builder"][tc_idx] = { + "id": None, + "type": "function", + "function_name_parts": [], + "function_arguments_parts": [], + } + 
builder = current_choice_data["tool_calls_builder"][tc_idx] + if tool_call_delta.id: + builder["id"] = tool_call_delta.id + if tool_call_delta.type: + builder["type"] = tool_call_delta.type + if tool_call_delta.function: + if tool_call_delta.function.name: + builder["function_name_parts"].append(tool_call_delta.function.name) + if tool_call_delta.function.arguments: + builder["function_arguments_parts"].append( + tool_call_delta.function.arguments + ) + if choice_delta.finish_reason: + current_choice_data["finish_reason"] = choice_delta.finish_reason + if choice_delta.logprobs and choice_delta.logprobs.content: + current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content) + + # Compute metrics on final chunk + if chunk.choices and chunk.choices[0].finish_reason: + completion_text = "" + for choice_data in choices_data.values(): + completion_text += "".join(choice_data["content_parts"]) + + # Add metrics to the chunk + if self.telemetry and chunk.usage: + metrics = self._construct_metrics( + prompt_tokens=chunk.usage.prompt_tokens, + completion_tokens=chunk.usage.completion_tokens, + total_tokens=chunk.usage.total_tokens, + model=model, + ) + for metric in metrics: + await self.telemetry.log_event(metric) + + yield chunk + finally: + # Store the final assembled completion + if id and self.store and messages: + assembled_choices: list[OpenAIChoice] = [] + for choice_idx, choice_data in choices_data.items(): + content_str = "".join(choice_data["content_parts"]) + assembled_tool_calls: list[OpenAIChatCompletionToolCall] = [] + if choice_data["tool_calls_builder"]: + for tc_build_data in choice_data["tool_calls_builder"].values(): + if tc_build_data["id"]: + func_name = "".join(tc_build_data["function_name_parts"]) + func_args = "".join(tc_build_data["function_arguments_parts"]) + assembled_tool_calls.append( + OpenAIChatCompletionToolCall( + id=tc_build_data["id"], + type=tc_build_data["type"], + function=OpenAIChatCompletionToolCallFunction( + name=func_name, arguments=func_args + ), + ) + ) + message = OpenAIAssistantMessageParam( + role="assistant", + content=content_str if content_str else None, + tool_calls=assembled_tool_calls if assembled_tool_calls else None, + ) + logprobs_content = choice_data["logprobs_content_parts"] + final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None + + assembled_choices.append( + OpenAIChoice( + finish_reason=choice_data["finish_reason"], + index=choice_idx, + message=message, + logprobs=final_logprobs, + ) + ) + + final_response = OpenAIChatCompletion( + id=id, + choices=assembled_choices, + created=created or int(time.time()), + model=model.identifier, + object="chat.completion", + ) + await self.store.store_chat_completion(final_response, messages) diff --git a/llama_stack/providers/utils/inference/stream_utils.py b/llama_stack/providers/utils/inference/stream_utils.py deleted file mode 100644 index bbfac13a3..000000000 --- a/llama_stack/providers/utils/inference/stream_utils.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from collections.abc import AsyncIterator -from datetime import UTC, datetime -from typing import Any - -from llama_stack.apis.inference import ( - OpenAIAssistantMessageParam, - OpenAIChatCompletion, - OpenAIChatCompletionChunk, - OpenAIChatCompletionToolCall, - OpenAIChatCompletionToolCallFunction, - OpenAIChoice, - OpenAIChoiceLogprobs, - OpenAIMessageParam, -) -from llama_stack.providers.utils.inference.inference_store import InferenceStore - - -async def stream_and_store_openai_completion( - provider_stream: AsyncIterator[OpenAIChatCompletionChunk], - model: str, - store: InferenceStore, - input_messages: list[OpenAIMessageParam], -) -> AsyncIterator[OpenAIChatCompletionChunk]: - """ - Wraps a provider's stream, yields chunks, and stores the full completion at the end. - """ - id = None - created = None - choices_data: dict[int, dict[str, Any]] = {} - - try: - async for chunk in provider_stream: - if id is None and chunk.id: - id = chunk.id - if created is None and chunk.created: - created = chunk.created - - if chunk.choices: - for choice_delta in chunk.choices: - idx = choice_delta.index - if idx not in choices_data: - choices_data[idx] = { - "content_parts": [], - "tool_calls_builder": {}, - "finish_reason": None, - "logprobs_content_parts": [], - } - current_choice_data = choices_data[idx] - - if choice_delta.delta: - delta = choice_delta.delta - if delta.content: - current_choice_data["content_parts"].append(delta.content) - if delta.tool_calls: - for tool_call_delta in delta.tool_calls: - tc_idx = tool_call_delta.index - if tc_idx not in current_choice_data["tool_calls_builder"]: - # Initialize with correct structure for _ToolCallBuilderData - current_choice_data["tool_calls_builder"][tc_idx] = { - "id": None, - "type": "function", - "function_name_parts": [], - "function_arguments_parts": [], - } - builder = current_choice_data["tool_calls_builder"][tc_idx] - if tool_call_delta.id: - builder["id"] = tool_call_delta.id - if tool_call_delta.type: - builder["type"] = tool_call_delta.type - if tool_call_delta.function: - if tool_call_delta.function.name: - builder["function_name_parts"].append(tool_call_delta.function.name) - if tool_call_delta.function.arguments: - builder["function_arguments_parts"].append(tool_call_delta.function.arguments) - if choice_delta.finish_reason: - current_choice_data["finish_reason"] = choice_delta.finish_reason - if choice_delta.logprobs and choice_delta.logprobs.content: - # Ensure that we are extending with the correct type - current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content) - yield chunk - finally: - if id: - assembled_choices: list[OpenAIChoice] = [] - for choice_idx, choice_data in choices_data.items(): - content_str = "".join(choice_data["content_parts"]) - assembled_tool_calls: list[OpenAIChatCompletionToolCall] = [] - if choice_data["tool_calls_builder"]: - for tc_build_data in choice_data["tool_calls_builder"].values(): - if tc_build_data["id"]: - func_name = "".join(tc_build_data["function_name_parts"]) - func_args = "".join(tc_build_data["function_arguments_parts"]) - assembled_tool_calls.append( - OpenAIChatCompletionToolCall( - id=tc_build_data["id"], - type=tc_build_data["type"], # No or "function" needed, already set - function=OpenAIChatCompletionToolCallFunction(name=func_name, arguments=func_args), - ) - ) - message = OpenAIAssistantMessageParam( - role="assistant", - content=content_str if content_str else None, - tool_calls=assembled_tool_calls if assembled_tool_calls else None, - ) - 
logprobs_content = choice_data["logprobs_content_parts"] - final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None - - assembled_choices.append( - OpenAIChoice( - finish_reason=choice_data["finish_reason"], - index=choice_idx, - message=message, - logprobs=final_logprobs, - ) - ) - - final_response = OpenAIChatCompletion( - id=id, - choices=assembled_choices, - created=created or int(datetime.now(UTC).timestamp()), - model=model, - object="chat.completion", - ) - await store.store_chat_completion(final_response, input_messages) diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index c85722bdc..75b29cdce 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -81,7 +81,7 @@ BACKGROUND_LOGGER = None class BackgroundLogger: - def __init__(self, api: Telemetry, capacity: int = 1000): + def __init__(self, api: Telemetry, capacity: int = 100000): self.api = api self.log_queue = queue.Queue(maxsize=capacity) self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)
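The core pattern this patch applies in the non-streaming paths (read token counts from `response.usage`, log one telemetry event per metric, then attach the metrics to the response without clobbering any the provider already set) can be exercised in isolation. The sketch below uses simplified stand-in types (`Usage`, `MetricInResponse`, `ChatResponse`, `log_event`) rather than the real llama-stack classes, so treat it as an illustration of the flow, not the actual implementation.

# Standalone sketch (not llama-stack code): build metrics from usage, log them,
# then attach them to the response. All types here are simplified stand-ins.
import asyncio
from dataclasses import dataclass


@dataclass
class Usage:
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


@dataclass
class MetricInResponse:
    metric: str
    value: int


@dataclass
class ChatResponse:
    usage: Usage
    metrics: list[MetricInResponse] | None = None


async def log_event(metric: MetricInResponse) -> None:
    # Stand-in for telemetry.log_event(); the real event also carries
    # span/trace IDs, a timestamp, and model attributes.
    print(f"telemetry: {metric.metric}={metric.value}")


async def attach_usage_metrics(response: ChatResponse) -> ChatResponse:
    metrics = [
        MetricInResponse("prompt_tokens", response.usage.prompt_tokens),
        MetricInResponse("completion_tokens", response.usage.completion_tokens),
        MetricInResponse("total_tokens", response.usage.total_tokens),
    ]
    for metric in metrics:
        await log_event(metric)
    # As in the patch, extend any metrics already present on the response
    # rather than overwriting them.
    response.metrics = metrics if response.metrics is None else response.metrics + metrics
    return response


if __name__ == "__main__":
    resp = asyncio.run(attach_usage_metrics(ChatResponse(Usage(10, 5, 15))))
    print(resp.metrics)

Running the sketch prints one line per logged metric and then the metrics list attached to the response, which mirrors what a client should now see in the `metrics` field of a non-streaming completion.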