feat: enable streaming usage metrics for OpenAI-compatible providers (#4326)

Inject `stream_options={"include_usage": True}` when streaming and
OpenTelemetry telemetry is active. Telemetry always overrides any caller
preference to ensure complete and consistent observability metrics.
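
For illustration, here is a minimal sketch of the decision rule this PR applies; the real implementation is the `get_stream_options_for_telemetry` helper shown in the diff below, and the function name here is just for the example:

```python
# Sketch only: the telemetry-driven stream_options injection rule added by this PR.
from opentelemetry import trace


def sketch_inject_usage(stream_options, is_streaming, supports_stream_options=True):
    # Only streaming requests to providers that accept stream_options are touched.
    if not is_streaming or not supports_stream_options:
        return stream_options
    # "Telemetry active" means the current OpenTelemetry span is recording.
    span = trace.get_current_span()
    if not span or not span.is_recording():
        return stream_options
    # Telemetry wins: include_usage is forced to True, even over an explicit False.
    return {**(stream_options or {}), "include_usage": True}
```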

Changes:
- Add conditional `stream_options` injection to `OpenAIMixin` (benefits the
  OpenAI, Bedrock, Runpod, Together, and Fireworks providers)
- Add conditional `stream_options` injection to `LiteLLMOpenAIMixin`
  (benefits WatsonX and other litellm-based providers)
- Check telemetry status using `trace.get_current_span().is_recording()`
- Override `include_usage=False` when telemetry is active to prevent metric
  gaps
- Add unit tests covering the injection behavior in both mixins

Fixes #3981

Note: this work originated in PR #4200, which I closed after rebasing on
the telemetry changes. This PR rebases those commits, incorporates the
Bedrock feedback, and carries forward the same scope described there.
## Test Plan
#### OpenAIMixin + telemetry injection tests
PYTHONPATH=src python -m pytest tests/unit/providers/utils/inference/test_openai_mixin.py

#### LiteLLM OpenAIMixin tests
PYTHONPATH=src python -m pytest tests/unit/providers/inference/test_litellm_openai_mixin.py -v

#### Broader inference provider tests
PYTHONPATH=src python -m pytest tests/unit/providers/inference/ --ignore=tests/unit/providers/inference/test_inference_client_caching.py -v
Sumanth Kamenani authored on 2025-12-19 18:53:53 -05:00 (committed by GitHub) · commit bd35aa4d78 · parent 5ebcde3042
11 changed files with 558 additions and 130 deletions


@@ -81,14 +81,7 @@ class BedrockInferenceAdapter(OpenAIMixin):
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        """Override to enable streaming usage metrics and handle authentication errors."""
-        # Enable streaming usage metrics when telemetry is active
-        if params.stream:
-            if params.stream_options is None:
-                params.stream_options = {"include_usage": True}
-            elif "include_usage" not in params.stream_options:
-                params.stream_options = {**params.stream_options, "include_usage": True}
-
+        """Override to handle authentication errors and null responses."""
         try:
             logger.debug(f"Calling Bedrock OpenAI API with model={params.model}, stream={params.stream}")
             result = await super().openai_chat_completion(params=params)


@@ -28,6 +28,9 @@ class OllamaInferenceAdapter(OpenAIMixin):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
 
+    # Ollama does not support the stream_options parameter
+    supports_stream_options: bool = False
+
     embedding_model_metadata: dict[str, dict[str, int]] = {
         "all-minilm:l6-v2": {
             "embedding_dimension": 384,


@@ -4,14 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncIterator
-
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack_api import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAIChatCompletionRequestWithExtraBody,
-)
 
 from .config import RunpodImplConfig
@@ -29,15 +22,3 @@ class RunpodInferenceAdapter(OpenAIMixin):
     def get_base_url(self) -> str:
         """Get base URL for OpenAI client."""
         return str(self.config.base_url)
-
-    async def openai_chat_completion(
-        self,
-        params: OpenAIChatCompletionRequestWithExtraBody,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        """Override to add RunPod-specific stream_options requirement."""
-        params = params.model_copy()
-
-        if params.stream and not params.stream_options:
-            params.stream_options = {"include_usage": True}
-
-        return await super().openai_chat_completion(params)


@@ -30,6 +30,9 @@ class VLLMInferenceAdapter(OpenAIMixin):
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
+    # vLLM does not support the stream_options parameter
+    supports_stream_options: bool = False
+
     provider_data_api_key_field: str = "vllm_api_token"
 
     def get_api_key(self) -> str | None:


@@ -13,8 +13,6 @@ import requests
 from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
-from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
-from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
 from llama_stack_api import (
     Model,
     ModelType,
@@ -22,7 +20,6 @@ from llama_stack_api import (
     OpenAIChatCompletionChunk,
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAIChatCompletionUsage,
-    OpenAICompletion,
     OpenAICompletionRequestWithExtraBody,
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
@@ -48,57 +45,25 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
             openai_compat_api_base=self.get_base_url(),
         )
 
+    def _litellm_extra_request_params(
+        self,
+        params: OpenAIChatCompletionRequestWithExtraBody | OpenAICompletionRequestWithExtraBody,
+    ) -> dict[str, Any]:
+        # These are watsonx-specific parameters used by LiteLLM.
+        return {"timeout": self.config.timeout, "project_id": self.config.project_id}
+
     async def openai_chat_completion(
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         """
-        Override parent method to add timeout and inject usage object when missing.
+        Override parent method to inject usage object when missing.
         This works around a LiteLLM defect where usage block is sometimes dropped.
+
+        Note: request parameter construction (including telemetry-driven stream_options injection)
+        is handled by LiteLLMOpenAIMixin via _litellm_extra_request_params().
         """
-        # Add usage tracking for streaming when telemetry is active
-        stream_options = params.stream_options
-        if params.stream:
-            if stream_options is None:
-                stream_options = {"include_usage": True}
-            elif "include_usage" not in stream_options:
-                stream_options = {**stream_options, "include_usage": True}
-
-        model_obj = await self.model_store.get_model(params.model)
-        request_params = await prepare_openai_completion_params(
-            model=self.get_litellm_model_name(model_obj.provider_resource_id),
-            messages=params.messages,
-            frequency_penalty=params.frequency_penalty,
-            function_call=params.function_call,
-            functions=params.functions,
-            logit_bias=params.logit_bias,
-            logprobs=params.logprobs,
-            max_completion_tokens=params.max_completion_tokens,
-            max_tokens=params.max_tokens,
-            n=params.n,
-            parallel_tool_calls=params.parallel_tool_calls,
-            presence_penalty=params.presence_penalty,
-            response_format=params.response_format,
-            seed=params.seed,
-            stop=params.stop,
-            stream=params.stream,
-            stream_options=stream_options,
-            temperature=params.temperature,
-            tool_choice=params.tool_choice,
-            tools=params.tools,
-            top_logprobs=params.top_logprobs,
-            top_p=params.top_p,
-            user=params.user,
-            api_key=self.get_api_key(),
-            api_base=self.api_base,
-            # These are watsonx-specific parameters
-            timeout=self.config.timeout,
-            project_id=self.config.project_id,
-        )
-        result = await litellm.acompletion(**request_params)
+        result = await super().openai_chat_completion(params)
 
         # If not streaming, check and inject usage if missing
         if not params.stream:
@@ -175,49 +140,6 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
             logger.error(f"Error normalizing stream: {e}", exc_info=True)
             raise
 
-    async def openai_completion(
-        self,
-        params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
-        """
-        Override parent method to add watsonx-specific parameters.
-        """
-        from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
-
-        model_obj = await self.model_store.get_model(params.model)
-        request_params = await prepare_openai_completion_params(
-            model=self.get_litellm_model_name(model_obj.provider_resource_id),
-            prompt=params.prompt,
-            best_of=params.best_of,
-            echo=params.echo,
-            frequency_penalty=params.frequency_penalty,
-            logit_bias=params.logit_bias,
-            logprobs=params.logprobs,
-            max_tokens=params.max_tokens,
-            n=params.n,
-            presence_penalty=params.presence_penalty,
-            seed=params.seed,
-            stop=params.stop,
-            stream=params.stream,
-            stream_options=params.stream_options,
-            temperature=params.temperature,
-            top_p=params.top_p,
-            user=params.user,
-            suffix=params.suffix,
-            api_key=self.get_api_key(),
-            api_base=self.api_base,
-            # These are watsonx-specific parameters
-            timeout=self.config.timeout,
-            project_id=self.config.project_id,
-        )
-        result = await litellm.atext_completion(**request_params)
-        if params.stream:
-            return wrap_async_stream(result)  # type: ignore[arg-type]  # LiteLLM streaming types
-        return result  # type: ignore[return-value]  # external lib lacks type stubs
-
     async def openai_embeddings(
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,


@@ -7,6 +7,7 @@
 import base64
 import struct
 from collections.abc import AsyncIterator
+from typing import Any
 
 import litellm
 
@@ -14,6 +15,7 @@ from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
 from llama_stack.providers.utils.inference.openai_compat import (
+    get_stream_options_for_telemetry,
     prepare_openai_completion_params,
 )
 from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
@@ -50,6 +52,7 @@ class LiteLLMOpenAIMixin(
         openai_compat_api_base: str | None = None,
         download_images: bool = False,
         json_schema_strict: bool = True,
+        supports_stream_options: bool = True,
     ):
         """
         Initialize the LiteLLMOpenAIMixin.
@@ -61,6 +64,7 @@
         :param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
         :param download_images: Whether to download images and convert to base64 for message conversion.
         :param json_schema_strict: Whether to use strict mode for JSON schema validation.
+        :param supports_stream_options: Whether the provider supports stream_options parameter.
         """
 
         ModelRegistryHelper.__init__(self, model_entries=model_entries)
@@ -70,6 +74,7 @@
         self.api_base = openai_compat_api_base
         self.download_images = download_images
         self.json_schema_strict = json_schema_strict
+        self.supports_stream_options = supports_stream_options
 
         if openai_compat_api_base:
             self.is_openai_compat = True
@@ -180,6 +185,11 @@
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
+        # Inject stream_options when streaming and telemetry is active
+        stream_options = get_stream_options_for_telemetry(
+            params.stream_options, params.stream, self.supports_stream_options
+        )
+
         if not self.model_store:
             raise ValueError("Model store is not initialized")
@@ -202,13 +212,14 @@
             seed=params.seed,
             stop=params.stop,
             stream=params.stream,
-            stream_options=params.stream_options,
+            stream_options=stream_options,
             temperature=params.temperature,
             top_p=params.top_p,
             user=params.user,
             suffix=params.suffix,
             api_key=self.get_api_key(),
             api_base=self.api_base,
+            **self._litellm_extra_request_params(params),
         )
         # LiteLLM returns compatible type but mypy can't verify external library
         result = await litellm.atext_completion(**request_params)
@@ -222,14 +233,10 @@
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        # Add usage tracking for streaming when telemetry is active
-        stream_options = params.stream_options
-        if params.stream:
-            if stream_options is None:
-                stream_options = {"include_usage": True}
-            elif "include_usage" not in stream_options:
-                stream_options = {**stream_options, "include_usage": True}
+        # Inject stream_options when streaming and telemetry is active
+        stream_options = get_stream_options_for_telemetry(
+            params.stream_options, params.stream, self.supports_stream_options
+        )
 
         if not self.model_store:
             raise ValueError("Model store is not initialized")
@@ -265,6 +272,7 @@
             user=params.user,
             api_key=self.get_api_key(),
             api_base=self.api_base,
+            **self._litellm_extra_request_params(params),
         )
         # LiteLLM returns compatible type but mypy can't verify external library
         result = await litellm.acompletion(**request_params)
@@ -288,6 +296,20 @@
 
         return model in litellm.models_by_provider[self.litellm_provider_name]
 
+    def _litellm_extra_request_params(
+        self,
+        params: OpenAIChatCompletionRequestWithExtraBody | OpenAICompletionRequestWithExtraBody,
+    ) -> dict[str, Any]:
+        """
+        Provider hook for extra LiteLLM/OpenAI-compat request params.
+
+        This is intentionally a narrow hook so provider adapters (e.g. WatsonX)
+        can add provider-specific kwargs (timeouts, project IDs, etc.) while the
+        mixin remains the single source of truth for telemetry-driven
+        stream_options injection.
+        """
+        return {}
+
 
 def b64_encode_openai_embeddings_response(
     response_data: list[dict], encoding_format: str | None = "float"
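
For illustration, a hedged sketch of how a litellm-based adapter could use the new `_litellm_extra_request_params` hook, modeled on the WatsonX change above; the adapter class and config field here are hypothetical:

```python
# Sketch only: a hypothetical LiteLLM-based adapter adding provider-specific kwargs
# through the new hook, while the mixin keeps owning stream_options injection.
from typing import Any

from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack_api import (
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAICompletionRequestWithExtraBody,
)


class ExampleLiteLLMAdapter(LiteLLMOpenAIMixin):
    def _litellm_extra_request_params(
        self,
        params: OpenAIChatCompletionRequestWithExtraBody | OpenAICompletionRequestWithExtraBody,
    ) -> dict[str, Any]:
        # Returned kwargs are merged into litellm.acompletion / litellm.atext_completion.
        return {"timeout": self.config.timeout}  # hypothetical config field
```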


@@ -235,3 +235,40 @@ def prepare_openai_embeddings_params(
         params["user"] = user
 
     return params
+
+
+def get_stream_options_for_telemetry(
+    stream_options: dict[str, Any] | None,
+    is_streaming: bool,
+    supports_stream_options: bool = True,
+) -> dict[str, Any] | None:
+    """
+    Inject stream_options when streaming and telemetry is active.
+
+    Active telemetry takes precedence over caller preference to ensure
+    complete and consistent observability metrics.
+
+    Args:
+        stream_options: Existing stream options from the request
+        is_streaming: Whether this is a streaming request
+        supports_stream_options: Whether the provider supports stream_options parameter
+
+    Returns:
+        Updated stream_options with include_usage=True if conditions are met, otherwise original options
+    """
+    if not is_streaming:
+        return stream_options
+
+    if not supports_stream_options:
+        return stream_options
+
+    from opentelemetry import trace
+
+    span = trace.get_current_span()
+    if not span or not span.is_recording():
+        return stream_options
+
+    if stream_options is None:
+        return {"include_usage": True}
+
+    return {**stream_options, "include_usage": True}
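
For context, a small usage sketch of the new helper; whether the last call injects `include_usage` depends on there being a recording OpenTelemetry span at runtime:

```python
# Illustrative only: exercising get_stream_options_for_telemetry directly.
from llama_stack.providers.utils.inference.openai_compat import get_stream_options_for_telemetry

# Non-streaming requests pass through untouched.
print(get_stream_options_for_telemetry({"include_usage": False}, is_streaming=False))

# Providers that reject stream_options (Ollama, vLLM) are never modified.
print(get_stream_options_for_telemetry(None, is_streaming=True, supports_stream_options=False))

# Streaming + supported: include_usage is forced to True only while a span is recording.
print(get_stream_options_for_telemetry({"include_usage": False}, is_streaming=True))
```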


@@ -16,7 +16,10 @@ from pydantic import BaseModel, ConfigDict
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
-from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_stream_options_for_telemetry,
+    prepare_openai_completion_params,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
 from llama_stack_api import (
     Model,
@@ -47,6 +50,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     The behavior of this class can be customized by child classes in the following ways:
     - overwrite_completion_id: If True, overwrites the 'id' field in OpenAI responses
     - download_images: If True, downloads images and converts to base64 for providers that require it
+    - supports_stream_options: If False, disables stream_options injection for providers that don't support it
     - embedding_model_metadata: A dictionary mapping model IDs to their embedding metadata
     - construct_model_from_identifier: Method to construct a Model instance corresponding to the given identifier
     - provider_data_api_key_field: Optional field name in provider data to look for API key
@@ -74,6 +78,10 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     # for providers that require base64 encoded images instead of URLs.
     download_images: bool = False
 
+    # Allow subclasses to control whether the provider supports stream_options parameter
+    # Set to False for providers that don't support stream_options (e.g., Ollama, vLLM)
+    supports_stream_options: bool = True
+
     # Embedding model metadata for this provider
     # Can be set by subclasses or instances to provide embedding models
     # Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
@@ -270,6 +278,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI completion API call.
         """
+        # Inject stream_options when streaming and telemetry is active
+        stream_options = get_stream_options_for_telemetry(
+            params.stream_options, params.stream or False, self.supports_stream_options
+        )
+
         provider_model_id = await self._get_provider_model_id(params.model)
         self._validate_model_allowed(provider_model_id)
@@ -287,7 +300,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
             seed=params.seed,
             stop=params.stop,
             stream=params.stream,
-            stream_options=params.stream_options,
+            stream_options=stream_options,
             temperature=params.temperature,
             top_p=params.top_p,
             user=params.user,
@@ -306,6 +319,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI chat completion API call.
         """
+        # Inject stream_options when streaming and telemetry is active
+        stream_options = get_stream_options_for_telemetry(
+            params.stream_options, params.stream or False, self.supports_stream_options
+        )
+
         provider_model_id = await self._get_provider_model_id(params.model)
         self._validate_model_allowed(provider_model_id)
@@ -346,7 +364,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
             seed=params.seed,
             stop=params.stop,
             stream=params.stream,
-            stream_options=params.stream_options,
+            stream_options=stream_options,
             temperature=params.temperature,
             tool_choice=params.tool_choice,
             tools=params.tools,
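
Similarly, an `OpenAIMixin`-based adapter opts out of injection with a single class attribute, as the Ollama and vLLM diffs above do; the class below is a hypothetical sketch with the usual adapter plumbing omitted:

```python
# Sketch only: disabling stream_options injection for a backend that rejects the parameter.
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class ExampleAdapter(OpenAIMixin):
    # Telemetry-driven injection is skipped entirely for this provider.
    supports_stream_options: bool = False

    def get_base_url(self) -> str:
        return "http://localhost:8000/v1"  # hypothetical endpoint
```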


@@ -5,13 +5,18 @@
 # the root directory of this source tree.
 
 import json
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 from pydantic import BaseModel, Field
 
 from llama_stack.core.request_headers import request_provider_data_context
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from llama_stack_api import (
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIUserMessageParam,
+)
 
 # Test fixtures and helper classes
@@ -109,3 +114,168 @@ def test_error_message_includes_correct_field_names(adapter_without_config_key):
     except ValueError as e:
         assert "test_api_key" in str(e)  # Should mention the correct field name
         assert "x-llamastack-provider-data" in str(e)  # Should mention header name
+
+
+class TestLiteLLMOpenAIMixinStreamOptionsInjection:
+    """Test cases for automatic stream_options injection in LiteLLMOpenAIMixin"""
+
+    @pytest.fixture
+    def mixin_with_model_store(self, adapter_with_config_key):
+        """Fixture to create adapter with mocked model store"""
+        mock_model_store = AsyncMock()
+        mock_model = MagicMock()
+        mock_model.provider_resource_id = "test-model-id"
+        mock_model_store.get_model = AsyncMock(return_value=mock_model)
+        adapter_with_config_key.model_store = mock_model_store
+        return adapter_with_config_key
+
+    async def test_chat_completion_injects_stream_options_when_telemetry_active(self, mixin_with_model_store):
+        """Test that stream_options is injected for streaming chat completion when telemetry is active"""
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
+                mock_acompletion.return_value = MagicMock()
+
+                await mixin_with_model_store.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="test-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
+                    )
+                )
+
+                mock_acompletion.assert_called_once()
+                call_kwargs = mock_acompletion.call_args[1]
+                assert call_kwargs["stream_options"] == {"include_usage": True}
+
+    async def test_chat_completion_preserves_existing_stream_options(self, mixin_with_model_store):
+        """Test that existing stream_options are preserved with include_usage added"""
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
+                mock_acompletion.return_value = MagicMock()
+
+                await mixin_with_model_store.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="test-model",
+                        messages=[OpenAIUserMessageParam(role="user", content="Hello")],
+                        stream=True,
+                        stream_options={"other_option": True},
+                    )
+                )
+
+                call_kwargs = mock_acompletion.call_args[1]
+                assert call_kwargs["stream_options"] == {"other_option": True, "include_usage": True}
+
+    async def test_chat_completion_no_injection_when_telemetry_inactive(self, mixin_with_model_store):
+        """Test that stream_options is NOT injected when telemetry is inactive"""
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = False
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
+                mock_acompletion.return_value = MagicMock()
+
+                await mixin_with_model_store.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="test-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
+                    )
+                )
+
+                call_kwargs = mock_acompletion.call_args[1]
+                assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None
+
+    async def test_chat_completion_no_injection_when_not_streaming(self, mixin_with_model_store):
+        """Test that stream_options is NOT injected for non-streaming requests"""
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
+                mock_acompletion.return_value = MagicMock()
+
+                await mixin_with_model_store.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="test-model",
+                        messages=[OpenAIUserMessageParam(role="user", content="Hello")],
+                        stream=False,
+                    )
+                )
+
+                call_kwargs = mock_acompletion.call_args[1]
+                assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None
+
+    async def test_completion_injects_stream_options_when_telemetry_active(self, mixin_with_model_store):
+        """Test that stream_options is injected for streaming completion when telemetry is active"""
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            with patch("litellm.atext_completion", new_callable=AsyncMock) as mock_atext_completion:
+                mock_atext_completion.return_value = MagicMock()
+
+                await mixin_with_model_store.openai_completion(
+                    OpenAICompletionRequestWithExtraBody(model="test-model", prompt="Hello", stream=True)
+                )
+
+                mock_atext_completion.assert_called_once()
+                call_kwargs = mock_atext_completion.call_args[1]
+                assert call_kwargs["stream_options"] == {"include_usage": True}
+
+    async def test_completion_no_injection_when_telemetry_inactive(self, mixin_with_model_store):
+        """Test that stream_options is NOT injected for completion when telemetry is inactive"""
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = False
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            with patch("litellm.atext_completion", new_callable=AsyncMock) as mock_atext_completion:
+                mock_atext_completion.return_value = MagicMock()
+
+                await mixin_with_model_store.openai_completion(
+                    OpenAICompletionRequestWithExtraBody(model="test-model", prompt="Hello", stream=True)
+                )
+
+                call_kwargs = mock_atext_completion.call_args[1]
+                assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None
+
+    async def test_original_params_not_mutated(self, mixin_with_model_store):
+        """Test that original params object is not mutated when stream_options is injected"""
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        original_params = OpenAIChatCompletionRequestWithExtraBody(
+            model="test-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
+        )
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
+                mock_acompletion.return_value = MagicMock()
+
+                await mixin_with_model_store.openai_chat_completion(original_params)
+
+                # Original params should not be modified
+                assert original_params.stream_options is None
+
+    async def test_chat_completion_overrides_include_usage_false(self, mixin_with_model_store):
+        """Test that include_usage=False is overridden when telemetry is active"""
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
+                mock_acompletion.return_value = MagicMock()
+
+                await mixin_with_model_store.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="test-model",
+                        messages=[OpenAIUserMessageParam(role="user", content="Hello")],
+                        stream=True,
+                        stream_options={"include_usage": False},
+                    )
+                )
+
+                call_kwargs = mock_acompletion.call_args[1]
+                # Telemetry must override False to ensure complete metrics
+                assert call_kwargs["stream_options"]["include_usage"] is True


@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from unittest.mock import MagicMock, patch
+
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_stream_options_for_telemetry,
+)
+
+
+class TestGetStreamOptionsForTelemetry:
+    def test_returns_original_when_not_streaming(self):
+        stream_options = {"keep": True}
+        result = get_stream_options_for_telemetry(stream_options, False)
+        assert result is stream_options
+
+    def test_streaming_without_active_span_returns_original(self):
+        stream_options = {"keep": True}
+        with patch("opentelemetry.trace.get_current_span", return_value=None):
+            result = get_stream_options_for_telemetry(stream_options, True)
+        assert result is stream_options
+
+    def test_streaming_with_inactive_span_returns_original(self):
+        stream_options = {"keep": True}
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = False
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            result = get_stream_options_for_telemetry(stream_options, True)
+        assert result is stream_options
+
+    def test_streaming_with_active_span_adds_usage_when_missing(self):
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            result = get_stream_options_for_telemetry(None, True)
+        assert result == {"include_usage": True}
+
+    def test_streaming_with_active_span_merges_existing_options(self):
+        stream_options = {"other_option": "value"}
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            result = get_stream_options_for_telemetry(stream_options, True)
+        assert result == {"other_option": "value", "include_usage": True}
+        assert stream_options == {"other_option": "value"}
+
+    def test_streaming_with_active_span_overrides_include_usage_false(self):
+        stream_options = {"include_usage": False}
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            result = get_stream_options_for_telemetry(stream_options, True)
+        assert result["include_usage"] is True


@@ -934,3 +934,214 @@ class TestOpenAIMixinAllowedModelsInference:
                         model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
                     )
                 )
+
+
+class TestOpenAIMixinStreamOptionsInjection:
+    """Test cases for automatic stream_options injection when telemetry is active"""
+
+    async def test_chat_completion_injects_stream_options_when_telemetry_active(self, mixin, mock_client_context):
+        """Test that stream_options is injected for streaming chat completion when telemetry is active"""
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        # Mock OpenTelemetry span as recording
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
+                    )
+                )
+
+                mock_client.chat.completions.create.assert_called_once()
+                call_kwargs = mock_client.chat.completions.create.call_args[1]
+                assert call_kwargs["stream_options"] == {"include_usage": True}
+
+    async def test_chat_completion_preserves_existing_stream_options(self, mixin, mock_client_context):
+        """Test that existing stream_options are preserved with include_usage added"""
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4",
+                        messages=[OpenAIUserMessageParam(role="user", content="Hello")],
+                        stream=True,
+                        stream_options={"other_option": True},
+                    )
+                )
+
+                call_kwargs = mock_client.chat.completions.create.call_args[1]
+                assert call_kwargs["stream_options"] == {"other_option": True, "include_usage": True}
+
+    async def test_chat_completion_no_injection_when_telemetry_inactive(self, mixin, mock_client_context):
+        """Test that stream_options is NOT injected when telemetry is inactive"""
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        # Mock OpenTelemetry span as not recording
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = False
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
+                    )
+                )
+
+                call_kwargs = mock_client.chat.completions.create.call_args[1]
+                assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None
+
+    async def test_chat_completion_no_injection_when_not_streaming(self, mixin, mock_client_context):
+        """Test that stream_options is NOT injected for non-streaming requests"""
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=False
+                    )
+                )
+
+                call_kwargs = mock_client.chat.completions.create.call_args[1]
+                assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None
+
+    async def test_completion_injects_stream_options_when_telemetry_active(self, mixin, mock_client_context):
+        """Test that stream_options is injected for streaming completion when telemetry is active"""
+        mock_client = MagicMock()
+        mock_client.completions.create = AsyncMock(return_value=MagicMock())
+
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_completion(
+                    OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello", stream=True)
+                )
+
+                mock_client.completions.create.assert_called_once()
+                call_kwargs = mock_client.completions.create.call_args[1]
+                assert call_kwargs["stream_options"] == {"include_usage": True}
+
+    async def test_completion_no_injection_when_telemetry_inactive(self, mixin, mock_client_context):
+        """Test that stream_options is NOT injected for completion when telemetry is inactive"""
+        mock_client = MagicMock()
+        mock_client.completions.create = AsyncMock(return_value=MagicMock())
+
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = False
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_completion(
+                    OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello", stream=True)
+                )
+
+                call_kwargs = mock_client.completions.create.call_args[1]
+                assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None
+
+    async def test_params_not_mutated(self, mixin, mock_client_context):
+        """Test that original params object is not mutated when stream_options is injected"""
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        original_params = OpenAIChatCompletionRequestWithExtraBody(
+            model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
+        )
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_chat_completion(original_params)
+
+                # Original params should not be modified
+                assert original_params.stream_options is None
+
+    async def test_chat_completion_overrides_include_usage_false(self, mixin, mock_client_context):
+        """Test that include_usage=False is overridden when telemetry is active"""
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4",
+                        messages=[OpenAIUserMessageParam(role="user", content="Hello")],
+                        stream=True,
+                        stream_options={"include_usage": False},
+                    )
+                )
+
+                call_kwargs = mock_client.chat.completions.create.call_args[1]
+                # Telemetry must override False to ensure complete metrics
+                assert call_kwargs["stream_options"]["include_usage"] is True
+
+    async def test_no_injection_when_provider_doesnt_support_stream_options(self, mixin, mock_client_context):
+        """Test that stream_options is NOT injected when provider doesn't support it"""
+        # Set supports_stream_options to False (like Ollama/vLLM)
+        mixin.supports_stream_options = False
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        # Mock OpenTelemetry span as recording (telemetry is active)
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
+                    )
+                )
+
+                call_kwargs = mock_client.chat.completions.create.call_args[1]
+                # Should NOT inject stream_options even though telemetry is active
+                assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None
+
+    async def test_completion_no_injection_when_provider_doesnt_support_stream_options(
+        self, mixin, mock_client_context
+    ):
+        """Test that stream_options is NOT injected for completion when provider doesn't support it"""
+        # Set supports_stream_options to False (like Ollama/vLLM)
+        mixin.supports_stream_options = False
+
+        mock_client = MagicMock()
+        mock_client.completions.create = AsyncMock(return_value=MagicMock())
+
+        # Mock OpenTelemetry span as recording (telemetry is active)
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_completion(
+                    OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello", stream=True)
+                )
+
+                call_kwargs = mock_client.completions.create.call_args[1]
+                # Should NOT inject stream_options even though telemetry is active
+                assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None