feat: enable streaming usage metrics for OpenAI-compatible providers (#4326)

Inject `stream_options={"include_usage": True}` when streaming and
OpenTelemetry telemetry is active. Active telemetry overrides any caller
preference to ensure complete and consistent observability metrics (see
the behavior sketch below).

Changes:
- Add conditional stream_options injection to OpenAIMixin (benefits the OpenAI, Bedrock, Runpod, Together, and Fireworks providers)
- Add conditional stream_options injection to LiteLLMOpenAIMixin (benefits WatsonX and other litellm-based providers)
- Check telemetry status using `trace.get_current_span().is_recording()`
- Override a caller-supplied `include_usage=False` when telemetry is active, to prevent metric gaps
- Add unit tests for this functionality

Fixes #3981

Note: this work originated in PR #4200, which I closed after rebasing on
the telemetry changes. This PR rebases those commits, incorporates the
Bedrock feedback, and carries forward the same scope described there.
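
For reviewers, here is a condensed sketch of the behavior the new helper is meant to guarantee. The helper name, signature, and import path come from the diff below; the mock-based scaffolding is illustrative and is not the PR's actual unit tests.

```python
# Illustrative only: exercises get_stream_options_for_telemetry as added in this PR.
from unittest import mock

from llama_stack.providers.utils.inference.openai_compat import get_stream_options_for_telemetry


def test_telemetry_overrides_caller_preference():
    recording_span = mock.MagicMock()
    recording_span.is_recording.return_value = True

    with mock.patch("opentelemetry.trace.get_current_span", return_value=recording_span):
        # An active (recording) span wins over an explicit caller opt-out.
        assert get_stream_options_for_telemetry({"include_usage": False}, is_streaming=True) == {
            "include_usage": True
        }
        # Providers that reject stream_options (e.g. Ollama, vLLM) are left untouched.
        assert get_stream_options_for_telemetry(None, is_streaming=True, supports_stream_options=False) is None

    with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
        # Without a recording span, caller options pass through unchanged.
        assert get_stream_options_for_telemetry(None, is_streaming=True) is None
```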
## Test Plan

#### OpenAIMixin + telemetry injection tests

`PYTHONPATH=src python -m pytest tests/unit/providers/utils/inference/test_openai_mixin.py`

#### LiteLLM OpenAIMixin tests

`PYTHONPATH=src python -m pytest tests/unit/providers/inference/test_litellm_openai_mixin.py -v`

#### Broader inference provider tests

`PYTHONPATH=src python -m pytest tests/unit/providers/inference/ --ignore=tests/unit/providers/inference/test_inference_client_caching.py -v`
```diff
@@ -81,14 +81,7 @@ class BedrockInferenceAdapter(OpenAIMixin):
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        """Override to enable streaming usage metrics and handle authentication errors."""
-        # Enable streaming usage metrics when telemetry is active
-        if params.stream:
-            if params.stream_options is None:
-                params.stream_options = {"include_usage": True}
-            elif "include_usage" not in params.stream_options:
-                params.stream_options = {**params.stream_options, "include_usage": True}
-
+        """Override to handle authentication errors and null responses."""
         try:
             logger.debug(f"Calling Bedrock OpenAI API with model={params.model}, stream={params.stream}")
             result = await super().openai_chat_completion(params=params)
```

```diff
@@ -28,6 +28,9 @@ class OllamaInferenceAdapter(OpenAIMixin):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
 
+    # Ollama does not support the stream_options parameter
+    supports_stream_options: bool = False
+
     embedding_model_metadata: dict[str, dict[str, int]] = {
         "all-minilm:l6-v2": {
             "embedding_dimension": 384,
```

```diff
@@ -4,14 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncIterator
-
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack_api import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAIChatCompletionRequestWithExtraBody,
-)
 
 from .config import RunpodImplConfig
@@ -29,15 +22,3 @@ class RunpodInferenceAdapter(OpenAIMixin):
     def get_base_url(self) -> str:
         """Get base URL for OpenAI client."""
         return str(self.config.base_url)
-
-    async def openai_chat_completion(
-        self,
-        params: OpenAIChatCompletionRequestWithExtraBody,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        """Override to add RunPod-specific stream_options requirement."""
-        params = params.model_copy()
-        if params.stream and not params.stream_options:
-            params.stream_options = {"include_usage": True}
-        return await super().openai_chat_completion(params)
```

```diff
@@ -30,6 +30,9 @@ class VLLMInferenceAdapter(OpenAIMixin):
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
+    # vLLM does not support the stream_options parameter
+    supports_stream_options: bool = False
+
     provider_data_api_key_field: str = "vllm_api_token"
 
     def get_api_key(self) -> str | None:
```

```diff
@@ -13,8 +13,6 @@ import requests
 from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
-from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
-from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
 from llama_stack_api import (
     Model,
     ModelType,
@@ -22,7 +20,6 @@ from llama_stack_api import (
     OpenAIChatCompletionChunk,
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAIChatCompletionUsage,
-    OpenAICompletion,
     OpenAICompletionRequestWithExtraBody,
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
@@ -48,57 +45,25 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
             openai_compat_api_base=self.get_base_url(),
         )
 
+    def _litellm_extra_request_params(
+        self,
+        params: OpenAIChatCompletionRequestWithExtraBody | OpenAICompletionRequestWithExtraBody,
+    ) -> dict[str, Any]:
+        # These are watsonx-specific parameters used by LiteLLM.
+        return {"timeout": self.config.timeout, "project_id": self.config.project_id}
+
     async def openai_chat_completion(
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         """
-        Override parent method to add timeout and inject usage object when missing.
+        Override parent method to inject usage object when missing.
 
         This works around a LiteLLM defect where usage block is sometimes dropped.
+
+        Note: request parameter construction (including telemetry-driven stream_options injection)
+        is handled by LiteLLMOpenAIMixin via _litellm_extra_request_params().
         """
-        # Add usage tracking for streaming when telemetry is active
-        stream_options = params.stream_options
-        if params.stream:
-            if stream_options is None:
-                stream_options = {"include_usage": True}
-            elif "include_usage" not in stream_options:
-                stream_options = {**stream_options, "include_usage": True}
-
-        model_obj = await self.model_store.get_model(params.model)
-        request_params = await prepare_openai_completion_params(
-            model=self.get_litellm_model_name(model_obj.provider_resource_id),
-            messages=params.messages,
-            frequency_penalty=params.frequency_penalty,
-            function_call=params.function_call,
-            functions=params.functions,
-            logit_bias=params.logit_bias,
-            logprobs=params.logprobs,
-            max_completion_tokens=params.max_completion_tokens,
-            max_tokens=params.max_tokens,
-            n=params.n,
-            parallel_tool_calls=params.parallel_tool_calls,
-            presence_penalty=params.presence_penalty,
-            response_format=params.response_format,
-            seed=params.seed,
-            stop=params.stop,
-            stream=params.stream,
-            stream_options=stream_options,
-            temperature=params.temperature,
-            tool_choice=params.tool_choice,
-            tools=params.tools,
-            top_logprobs=params.top_logprobs,
-            top_p=params.top_p,
-            user=params.user,
-            api_key=self.get_api_key(),
-            api_base=self.api_base,
-            # These are watsonx-specific parameters
-            timeout=self.config.timeout,
-            project_id=self.config.project_id,
-        )
-        result = await litellm.acompletion(**request_params)
+        result = await super().openai_chat_completion(params)
 
         # If not streaming, check and inject usage if missing
         if not params.stream:
@@ -175,49 +140,6 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
             logger.error(f"Error normalizing stream: {e}", exc_info=True)
             raise
 
-    async def openai_completion(
-        self,
-        params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
-        """
-        Override parent method to add watsonx-specific parameters.
-        """
-        from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
-
-        model_obj = await self.model_store.get_model(params.model)
-        request_params = await prepare_openai_completion_params(
-            model=self.get_litellm_model_name(model_obj.provider_resource_id),
-            prompt=params.prompt,
-            best_of=params.best_of,
-            echo=params.echo,
-            frequency_penalty=params.frequency_penalty,
-            logit_bias=params.logit_bias,
-            logprobs=params.logprobs,
-            max_tokens=params.max_tokens,
-            n=params.n,
-            presence_penalty=params.presence_penalty,
-            seed=params.seed,
-            stop=params.stop,
-            stream=params.stream,
-            stream_options=params.stream_options,
-            temperature=params.temperature,
-            top_p=params.top_p,
-            user=params.user,
-            suffix=params.suffix,
-            api_key=self.get_api_key(),
-            api_base=self.api_base,
-            # These are watsonx-specific parameters
-            timeout=self.config.timeout,
-            project_id=self.config.project_id,
-        )
-        result = await litellm.atext_completion(**request_params)
-        if params.stream:
-            return wrap_async_stream(result)  # type: ignore[arg-type]  # LiteLLM streaming types
-        return result  # type: ignore[return-value]  # external lib lacks type stubs
-
     async def openai_embeddings(
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
```

```diff
@@ -7,6 +7,7 @@
 import base64
 import struct
 from collections.abc import AsyncIterator
+from typing import Any
 
 import litellm
@@ -14,6 +15,7 @@ from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
 from llama_stack.providers.utils.inference.openai_compat import (
+    get_stream_options_for_telemetry,
     prepare_openai_completion_params,
 )
 from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
@@ -50,6 +52,7 @@ class LiteLLMOpenAIMixin(
         openai_compat_api_base: str | None = None,
         download_images: bool = False,
         json_schema_strict: bool = True,
+        supports_stream_options: bool = True,
     ):
         """
         Initialize the LiteLLMOpenAIMixin.
@@ -61,6 +64,7 @@
         :param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
         :param download_images: Whether to download images and convert to base64 for message conversion.
         :param json_schema_strict: Whether to use strict mode for JSON schema validation.
+        :param supports_stream_options: Whether the provider supports stream_options parameter.
         """
         ModelRegistryHelper.__init__(self, model_entries=model_entries)
@@ -70,6 +74,7 @@
         self.api_base = openai_compat_api_base
         self.download_images = download_images
         self.json_schema_strict = json_schema_strict
+        self.supports_stream_options = supports_stream_options
 
         if openai_compat_api_base:
             self.is_openai_compat = True
@@ -180,6 +185,11 @@
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
+        # Inject stream_options when streaming and telemetry is active
+        stream_options = get_stream_options_for_telemetry(
+            params.stream_options, params.stream, self.supports_stream_options
+        )
+
         if not self.model_store:
             raise ValueError("Model store is not initialized")
@@ -202,13 +212,14 @@
             seed=params.seed,
             stop=params.stop,
             stream=params.stream,
-            stream_options=params.stream_options,
+            stream_options=stream_options,
             temperature=params.temperature,
             top_p=params.top_p,
             user=params.user,
             suffix=params.suffix,
             api_key=self.get_api_key(),
             api_base=self.api_base,
+            **self._litellm_extra_request_params(params),
         )
         # LiteLLM returns compatible type but mypy can't verify external library
         result = await litellm.atext_completion(**request_params)
@@ -222,14 +233,10 @@
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        # Add usage tracking for streaming when telemetry is active
-        stream_options = params.stream_options
-        if params.stream:
-            if stream_options is None:
-                stream_options = {"include_usage": True}
-            elif "include_usage" not in stream_options:
-                stream_options = {**stream_options, "include_usage": True}
-
+        # Inject stream_options when streaming and telemetry is active
+        stream_options = get_stream_options_for_telemetry(
+            params.stream_options, params.stream, self.supports_stream_options
+        )
+
         if not self.model_store:
             raise ValueError("Model store is not initialized")
@@ -265,6 +272,7 @@
             user=params.user,
             api_key=self.get_api_key(),
             api_base=self.api_base,
+            **self._litellm_extra_request_params(params),
         )
         # LiteLLM returns compatible type but mypy can't verify external library
         result = await litellm.acompletion(**request_params)
@@ -288,6 +296,20 @@
         return model in litellm.models_by_provider[self.litellm_provider_name]
 
+    def _litellm_extra_request_params(
+        self,
+        params: OpenAIChatCompletionRequestWithExtraBody | OpenAICompletionRequestWithExtraBody,
+    ) -> dict[str, Any]:
+        """
+        Provider hook for extra LiteLLM/OpenAI-compat request params.
+
+        This is intentionally a narrow hook so provider adapters (e.g. WatsonX)
+        can add provider-specific kwargs (timeouts, project IDs, etc.) while the
+        mixin remains the single source of truth for telemetry-driven
+        stream_options injection.
+        """
+        return {}
+
 
 def b64_encode_openai_embeddings_response(
     response_data: list[dict], encoding_format: str | None = "float"
```
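
To make the new `_litellm_extra_request_params()` hook concrete: a hypothetical LiteLLM-based adapter (class name and kwargs below are invented for illustration; WatsonX's real override is in its diff above) would add provider-specific request kwargs like this, while the mixin keeps sole ownership of the telemetry-driven stream_options injection.

```python
# Hypothetical adapter sketch; constructor wiring and model registration omitted.
from typing import Any

from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack_api import (
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAICompletionRequestWithExtraBody,
)


class ExampleLiteLLMAdapter(LiteLLMOpenAIMixin):
    """Illustrative only; not a provider added by this PR."""

    def _litellm_extra_request_params(
        self,
        params: OpenAIChatCompletionRequestWithExtraBody | OpenAICompletionRequestWithExtraBody,
    ) -> dict[str, Any]:
        # Extra kwargs are splatted into the litellm.acompletion / atext_completion
        # call by the mixin; stream_options handling stays in the mixin.
        return {"timeout": 30}
```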

```diff
@@ -235,3 +235,40 @@ def prepare_openai_embeddings_params(
         params["user"] = user
 
     return params
+
+
+def get_stream_options_for_telemetry(
+    stream_options: dict[str, Any] | None,
+    is_streaming: bool,
+    supports_stream_options: bool = True,
+) -> dict[str, Any] | None:
+    """
+    Inject stream_options when streaming and telemetry is active.
+
+    Active telemetry takes precedence over caller preference to ensure
+    complete and consistent observability metrics.
+
+    Args:
+        stream_options: Existing stream options from the request
+        is_streaming: Whether this is a streaming request
+        supports_stream_options: Whether the provider supports stream_options parameter
+
+    Returns:
+        Updated stream_options with include_usage=True if conditions are met, otherwise original options
+    """
+    if not is_streaming:
+        return stream_options
+
+    if not supports_stream_options:
+        return stream_options
+
+    from opentelemetry import trace
+
+    span = trace.get_current_span()
+    if not span or not span.is_recording():
+        return stream_options
+
+    if stream_options is None:
+        return {"include_usage": True}
+
+    return {**stream_options, "include_usage": True}
```

```diff
@@ -16,7 +16,10 @@ from pydantic import BaseModel, ConfigDict
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
-from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_stream_options_for_telemetry,
+    prepare_openai_completion_params,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
 from llama_stack_api import (
     Model,
@@ -47,6 +50,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     The behavior of this class can be customized by child classes in the following ways:
     - overwrite_completion_id: If True, overwrites the 'id' field in OpenAI responses
     - download_images: If True, downloads images and converts to base64 for providers that require it
+    - supports_stream_options: If False, disables stream_options injection for providers that don't support it
     - embedding_model_metadata: A dictionary mapping model IDs to their embedding metadata
     - construct_model_from_identifier: Method to construct a Model instance corresponding to the given identifier
     - provider_data_api_key_field: Optional field name in provider data to look for API key
@@ -74,6 +78,10 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     # for providers that require base64 encoded images instead of URLs.
     download_images: bool = False
 
+    # Allow subclasses to control whether the provider supports stream_options parameter
+    # Set to False for providers that don't support stream_options (e.g., Ollama, vLLM)
+    supports_stream_options: bool = True
+
     # Embedding model metadata for this provider
     # Can be set by subclasses or instances to provide embedding models
     # Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
@@ -270,6 +278,11 @@
         """
         Direct OpenAI completion API call.
         """
+        # Inject stream_options when streaming and telemetry is active
+        stream_options = get_stream_options_for_telemetry(
+            params.stream_options, params.stream or False, self.supports_stream_options
+        )
+
         provider_model_id = await self._get_provider_model_id(params.model)
         self._validate_model_allowed(provider_model_id)
@@ -287,7 +300,7 @@
             seed=params.seed,
             stop=params.stop,
             stream=params.stream,
-            stream_options=params.stream_options,
+            stream_options=stream_options,
             temperature=params.temperature,
             top_p=params.top_p,
             user=params.user,
@@ -306,6 +319,11 @@
         """
         Direct OpenAI chat completion API call.
         """
+        # Inject stream_options when streaming and telemetry is active
+        stream_options = get_stream_options_for_telemetry(
+            params.stream_options, params.stream or False, self.supports_stream_options
+        )
+
         provider_model_id = await self._get_provider_model_id(params.model)
         self._validate_model_allowed(provider_model_id)
@@ -346,7 +364,7 @@
             seed=params.seed,
             stop=params.stop,
             stream=params.stream,
-            stream_options=params.stream_options,
+            stream_options=stream_options,
             temperature=params.temperature,
             tool_choice=params.tool_choice,
             tools=params.tools,
```
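
Finally, for provider authors building on `OpenAIMixin`: opting out of the injection works exactly as the Ollama and vLLM hunks above show, by flipping the class attribute. The adapter below is hypothetical and its `get_api_key`/`get_base_url` bodies are placeholders.

```python
# Hypothetical OpenAI-compatible adapter sketch; mirrors the Ollama/vLLM opt-out pattern.
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class ExampleCompatAdapter(OpenAIMixin):
    # Backend rejects stream_options, so telemetry-driven injection is skipped
    # and params.stream_options is forwarded untouched.
    supports_stream_options: bool = False

    def get_api_key(self) -> str | None:
        return None  # placeholder

    def get_base_url(self) -> str:
        return "http://localhost:8000/v1"  # placeholder
```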