diff --git a/src/llama_stack/providers/remote/inference/bedrock/bedrock.py b/src/llama_stack/providers/remote/inference/bedrock/bedrock.py
index cdc6b5f25..532912080 100644
--- a/src/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/src/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -81,14 +81,7 @@ class BedrockInferenceAdapter(OpenAIMixin):
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        """Override to enable streaming usage metrics and handle authentication errors."""
-        # Enable streaming usage metrics when telemetry is active
-        if params.stream:
-            if params.stream_options is None:
-                params.stream_options = {"include_usage": True}
-            elif "include_usage" not in params.stream_options:
-                params.stream_options = {**params.stream_options, "include_usage": True}
-
+        """Override to handle authentication errors and null responses."""
         try:
             logger.debug(f"Calling Bedrock OpenAI API with model={params.model}, stream={params.stream}")
             result = await super().openai_chat_completion(params=params)
diff --git a/src/llama_stack/providers/remote/inference/ollama/ollama.py b/src/llama_stack/providers/remote/inference/ollama/ollama.py
index e8b872384..72897d22a 100644
--- a/src/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/src/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -28,6 +28,9 @@ class OllamaInferenceAdapter(OpenAIMixin):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
 
+    # Ollama does not support the stream_options parameter
+    supports_stream_options: bool = False
+
     embedding_model_metadata: dict[str, dict[str, int]] = {
         "all-minilm:l6-v2": {
             "embedding_dimension": 384,
diff --git a/src/llama_stack/providers/remote/inference/runpod/runpod.py b/src/llama_stack/providers/remote/inference/runpod/runpod.py
index 04ad12851..6a0dbd82d 100644
--- a/src/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/src/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -4,14 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncIterator
-
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack_api import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAIChatCompletionRequestWithExtraBody,
-)
 
 from .config import RunpodImplConfig
 
@@ -29,15 +22,3 @@ class RunpodInferenceAdapter(OpenAIMixin):
     def get_base_url(self) -> str:
         """Get base URL for OpenAI client."""
         return str(self.config.base_url)
-
-    async def openai_chat_completion(
-        self,
-        params: OpenAIChatCompletionRequestWithExtraBody,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        """Override to add RunPod-specific stream_options requirement."""
-        params = params.model_copy()
-
-        if params.stream and not params.stream_options:
-            params.stream_options = {"include_usage": True}
-
-        return await super().openai_chat_completion(params)
diff --git a/src/llama_stack/providers/remote/inference/vllm/vllm.py b/src/llama_stack/providers/remote/inference/vllm/vllm.py
index 45d9176aa..4874233b5 100644
--- a/src/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/src/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -30,6 +30,9 @@ class VLLMInferenceAdapter(OpenAIMixin):
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
+    # vLLM does not support the stream_options parameter
+    supports_stream_options: bool = False
+
     provider_data_api_key_field: str = "vllm_api_token"
 
     def get_api_key(self) -> str | None:
diff --git a/src/llama_stack/providers/remote/inference/watsonx/watsonx.py b/src/llama_stack/providers/remote/inference/watsonx/watsonx.py
index e7d99fb2c..eb7be9d90 100644
--- a/src/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/src/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -13,8 +13,6 @@ import requests
 from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
-from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
-from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
 from llama_stack_api import (
     Model,
     ModelType,
@@ -22,7 +20,6 @@ from llama_stack_api import (
     OpenAIChatCompletionChunk,
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAIChatCompletionUsage,
-    OpenAICompletion,
     OpenAICompletionRequestWithExtraBody,
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
@@ -48,57 +45,25 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
             openai_compat_api_base=self.get_base_url(),
         )
 
+    def _litellm_extra_request_params(
+        self,
+        params: OpenAIChatCompletionRequestWithExtraBody | OpenAICompletionRequestWithExtraBody,
+    ) -> dict[str, Any]:
+        # These are watsonx-specific parameters used by LiteLLM.
+        return {"timeout": self.config.timeout, "project_id": self.config.project_id}
+
     async def openai_chat_completion(
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         """
-        Override parent method to add timeout and inject usage object when missing.
+        Override parent method to inject usage object when missing.
+        This works around a LiteLLM defect where the usage block is sometimes dropped.
+        Note: request parameter construction (including telemetry-driven stream_options injection)
+        is handled by LiteLLMOpenAIMixin; watsonx-specific parameters are supplied via _litellm_extra_request_params().
""" - - # Add usage tracking for streaming when telemetry is active - stream_options = params.stream_options - if params.stream: - if stream_options is None: - stream_options = {"include_usage": True} - elif "include_usage" not in stream_options: - stream_options = {**stream_options, "include_usage": True} - - model_obj = await self.model_store.get_model(params.model) - - request_params = await prepare_openai_completion_params( - model=self.get_litellm_model_name(model_obj.provider_resource_id), - messages=params.messages, - frequency_penalty=params.frequency_penalty, - function_call=params.function_call, - functions=params.functions, - logit_bias=params.logit_bias, - logprobs=params.logprobs, - max_completion_tokens=params.max_completion_tokens, - max_tokens=params.max_tokens, - n=params.n, - parallel_tool_calls=params.parallel_tool_calls, - presence_penalty=params.presence_penalty, - response_format=params.response_format, - seed=params.seed, - stop=params.stop, - stream=params.stream, - stream_options=stream_options, - temperature=params.temperature, - tool_choice=params.tool_choice, - tools=params.tools, - top_logprobs=params.top_logprobs, - top_p=params.top_p, - user=params.user, - api_key=self.get_api_key(), - api_base=self.api_base, - # These are watsonx-specific parameters - timeout=self.config.timeout, - project_id=self.config.project_id, - ) - - result = await litellm.acompletion(**request_params) + result = await super().openai_chat_completion(params) # If not streaming, check and inject usage if missing if not params.stream: @@ -175,49 +140,6 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin): logger.error(f"Error normalizing stream: {e}", exc_info=True) raise - async def openai_completion( - self, - params: OpenAICompletionRequestWithExtraBody, - ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]: - """ - Override parent method to add watsonx-specific parameters. 
- """ - from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params - - model_obj = await self.model_store.get_model(params.model) - - request_params = await prepare_openai_completion_params( - model=self.get_litellm_model_name(model_obj.provider_resource_id), - prompt=params.prompt, - best_of=params.best_of, - echo=params.echo, - frequency_penalty=params.frequency_penalty, - logit_bias=params.logit_bias, - logprobs=params.logprobs, - max_tokens=params.max_tokens, - n=params.n, - presence_penalty=params.presence_penalty, - seed=params.seed, - stop=params.stop, - stream=params.stream, - stream_options=params.stream_options, - temperature=params.temperature, - top_p=params.top_p, - user=params.user, - suffix=params.suffix, - api_key=self.get_api_key(), - api_base=self.api_base, - # These are watsonx-specific parameters - timeout=self.config.timeout, - project_id=self.config.project_id, - ) - result = await litellm.atext_completion(**request_params) - - if params.stream: - return wrap_async_stream(result) # type: ignore[arg-type] # LiteLLM streaming types - - return result # type: ignore[return-value] # external lib lacks type stubs - async def openai_embeddings( self, params: OpenAIEmbeddingsRequestWithExtraBody, diff --git a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py index bb7f972d4..b1190fc77 100644 --- a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -7,6 +7,7 @@ import base64 import struct from collections.abc import AsyncIterator +from typing import Any import litellm @@ -14,6 +15,7 @@ from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry from llama_stack.providers.utils.inference.openai_compat import ( + get_stream_options_for_telemetry, prepare_openai_completion_params, ) from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream @@ -50,6 +52,7 @@ class LiteLLMOpenAIMixin( openai_compat_api_base: str | None = None, download_images: bool = False, json_schema_strict: bool = True, + supports_stream_options: bool = True, ): """ Initialize the LiteLLMOpenAIMixin. @@ -61,6 +64,7 @@ class LiteLLMOpenAIMixin( :param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility. :param download_images: Whether to download images and convert to base64 for message conversion. :param json_schema_strict: Whether to use strict mode for JSON schema validation. + :param supports_stream_options: Whether the provider supports stream_options parameter. 
""" ModelRegistryHelper.__init__(self, model_entries=model_entries) @@ -70,6 +74,7 @@ class LiteLLMOpenAIMixin( self.api_base = openai_compat_api_base self.download_images = download_images self.json_schema_strict = json_schema_strict + self.supports_stream_options = supports_stream_options if openai_compat_api_base: self.is_openai_compat = True @@ -180,6 +185,11 @@ class LiteLLMOpenAIMixin( self, params: OpenAICompletionRequestWithExtraBody, ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]: + # Inject stream_options when streaming and telemetry is active + stream_options = get_stream_options_for_telemetry( + params.stream_options, params.stream, self.supports_stream_options + ) + if not self.model_store: raise ValueError("Model store is not initialized") @@ -202,13 +212,14 @@ class LiteLLMOpenAIMixin( seed=params.seed, stop=params.stop, stream=params.stream, - stream_options=params.stream_options, + stream_options=stream_options, temperature=params.temperature, top_p=params.top_p, user=params.user, suffix=params.suffix, api_key=self.get_api_key(), api_base=self.api_base, + **self._litellm_extra_request_params(params), ) # LiteLLM returns compatible type but mypy can't verify external library result = await litellm.atext_completion(**request_params) @@ -222,14 +233,10 @@ class LiteLLMOpenAIMixin( self, params: OpenAIChatCompletionRequestWithExtraBody, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - # Add usage tracking for streaming when telemetry is active - - stream_options = params.stream_options - if params.stream: - if stream_options is None: - stream_options = {"include_usage": True} - elif "include_usage" not in stream_options: - stream_options = {**stream_options, "include_usage": True} + # Inject stream_options when streaming and telemetry is active + stream_options = get_stream_options_for_telemetry( + params.stream_options, params.stream, self.supports_stream_options + ) if not self.model_store: raise ValueError("Model store is not initialized") @@ -265,6 +272,7 @@ class LiteLLMOpenAIMixin( user=params.user, api_key=self.get_api_key(), api_base=self.api_base, + **self._litellm_extra_request_params(params), ) # LiteLLM returns compatible type but mypy can't verify external library result = await litellm.acompletion(**request_params) @@ -288,6 +296,20 @@ class LiteLLMOpenAIMixin( return model in litellm.models_by_provider[self.litellm_provider_name] + def _litellm_extra_request_params( + self, + params: OpenAIChatCompletionRequestWithExtraBody | OpenAICompletionRequestWithExtraBody, + ) -> dict[str, Any]: + """ + Provider hook for extra LiteLLM/OpenAI-compat request params. + + This is intentionally a narrow hook so provider adapters (e.g. WatsonX) + can add provider-specific kwargs (timeouts, project IDs, etc.) while the + mixin remains the single source of truth for telemetry-driven + stream_options injection. 
+ """ + return {} + def b64_encode_openai_embeddings_response( response_data: list[dict], encoding_format: str | None = "float" diff --git a/src/llama_stack/providers/utils/inference/openai_compat.py b/src/llama_stack/providers/utils/inference/openai_compat.py index 3ce7d361d..8a63fc519 100644 --- a/src/llama_stack/providers/utils/inference/openai_compat.py +++ b/src/llama_stack/providers/utils/inference/openai_compat.py @@ -235,3 +235,40 @@ def prepare_openai_embeddings_params( params["user"] = user return params + + +def get_stream_options_for_telemetry( + stream_options: dict[str, Any] | None, + is_streaming: bool, + supports_stream_options: bool = True, +) -> dict[str, Any] | None: + """ + Inject stream_options when streaming and telemetry is active. + + Active telemetry takes precedence over caller preference to ensure + complete and consistent observability metrics. + + Args: + stream_options: Existing stream options from the request + is_streaming: Whether this is a streaming request + supports_stream_options: Whether the provider supports stream_options parameter + + Returns: + Updated stream_options with include_usage=True if conditions are met, otherwise original options + """ + if not is_streaming: + return stream_options + + if not supports_stream_options: + return stream_options + + from opentelemetry import trace + + span = trace.get_current_span() + if not span or not span.is_recording(): + return stream_options + + if stream_options is None: + return {"include_usage": True} + + return {**stream_options, "include_usage": True} diff --git a/src/llama_stack/providers/utils/inference/openai_mixin.py b/src/llama_stack/providers/utils/inference/openai_mixin.py index 499379214..d1983d6f0 100644 --- a/src/llama_stack/providers/utils/inference/openai_mixin.py +++ b/src/llama_stack/providers/utils/inference/openai_mixin.py @@ -16,7 +16,10 @@ from pydantic import BaseModel, ConfigDict from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig -from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params +from llama_stack.providers.utils.inference.openai_compat import ( + get_stream_options_for_telemetry, + prepare_openai_completion_params, +) from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content from llama_stack_api import ( Model, @@ -47,6 +50,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): The behavior of this class can be customized by child classes in the following ways: - overwrite_completion_id: If True, overwrites the 'id' field in OpenAI responses - download_images: If True, downloads images and converts to base64 for providers that require it + - supports_stream_options: If False, disables stream_options injection for providers that don't support it - embedding_model_metadata: A dictionary mapping model IDs to their embedding metadata - construct_model_from_identifier: Method to construct a Model instance corresponding to the given identifier - provider_data_api_key_field: Optional field name in provider data to look for API key @@ -74,6 +78,10 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): # for providers that require base64 encoded images instead of URLs. 
     download_images: bool = False
 
+    # Allow subclasses to control whether the provider supports the stream_options parameter
+    # Set to False for providers that don't support stream_options (e.g., Ollama, vLLM)
+    supports_stream_options: bool = True
+
     # Embedding model metadata for this provider
     # Can be set by subclasses or instances to provide embedding models
     # Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
@@ -270,6 +278,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI completion API call.
         """
+        # Inject stream_options when streaming and telemetry is active
+        stream_options = get_stream_options_for_telemetry(
+            params.stream_options, params.stream or False, self.supports_stream_options
+        )
+
         provider_model_id = await self._get_provider_model_id(params.model)
         self._validate_model_allowed(provider_model_id)
 
@@ -287,7 +300,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
             seed=params.seed,
             stop=params.stop,
             stream=params.stream,
-            stream_options=params.stream_options,
+            stream_options=stream_options,
             temperature=params.temperature,
             top_p=params.top_p,
             user=params.user,
@@ -306,6 +319,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI chat completion API call.
         """
+        # Inject stream_options when streaming and telemetry is active
+        stream_options = get_stream_options_for_telemetry(
+            params.stream_options, params.stream or False, self.supports_stream_options
+        )
+
         provider_model_id = await self._get_provider_model_id(params.model)
         self._validate_model_allowed(provider_model_id)
 
@@ -346,7 +364,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
             seed=params.seed,
             stop=params.stop,
             stream=params.stream,
-            stream_options=params.stream_options,
+            stream_options=stream_options,
             temperature=params.temperature,
             tool_choice=params.tool_choice,
             tools=params.tools,
diff --git a/tests/unit/providers/inference/test_litellm_openai_mixin.py b/tests/unit/providers/inference/test_litellm_openai_mixin.py
index 1f6a687d6..1e03ea0cf 100644
--- a/tests/unit/providers/inference/test_litellm_openai_mixin.py
+++ b/tests/unit/providers/inference/test_litellm_openai_mixin.py
@@ -5,13 +5,18 @@
 # the root directory of this source tree.
 
 import json
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 from pydantic import BaseModel, Field
 
 from llama_stack.core.request_headers import request_provider_data_context
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from llama_stack_api import (
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIUserMessageParam,
+)
 
 # Test fixtures and helper classes
@@ -109,3 +114,168 @@ def test_error_message_includes_correct_field_names(adapter_without_config_key):
     except ValueError as e:
         assert "test_api_key" in str(e)  # Should mention the correct field name
         assert "x-llamastack-provider-data" in str(e)  # Should mention header name
+
+
+class TestLiteLLMOpenAIMixinStreamOptionsInjection:
+    """Test cases for automatic stream_options injection in LiteLLMOpenAIMixin"""
+
+    @pytest.fixture
+    def mixin_with_model_store(self, adapter_with_config_key):
+        """Fixture to create adapter with mocked model store"""
+        mock_model_store = AsyncMock()
+        mock_model = MagicMock()
+        mock_model.provider_resource_id = "test-model-id"
+        mock_model_store.get_model = AsyncMock(return_value=mock_model)
+        adapter_with_config_key.model_store = mock_model_store
+        return adapter_with_config_key
+
+    async def test_chat_completion_injects_stream_options_when_telemetry_active(self, mixin_with_model_store):
+        """Test that stream_options is injected for streaming chat completion when telemetry is active"""
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
+                mock_acompletion.return_value = MagicMock()
+
+                await mixin_with_model_store.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="test-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
+                    )
+                )
+
+                mock_acompletion.assert_called_once()
+                call_kwargs = mock_acompletion.call_args[1]
+                assert call_kwargs["stream_options"] == {"include_usage": True}
+
+    async def test_chat_completion_preserves_existing_stream_options(self, mixin_with_model_store):
+        """Test that existing stream_options are preserved with include_usage added"""
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
+                mock_acompletion.return_value = MagicMock()
+
+                await mixin_with_model_store.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="test-model",
+                        messages=[OpenAIUserMessageParam(role="user", content="Hello")],
+                        stream=True,
+                        stream_options={"other_option": True},
+                    )
+                )
+
+                call_kwargs = mock_acompletion.call_args[1]
+                assert call_kwargs["stream_options"] == {"other_option": True, "include_usage": True}
+
+    async def test_chat_completion_no_injection_when_telemetry_inactive(self, mixin_with_model_store):
+        """Test that stream_options is NOT injected when telemetry is inactive"""
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = False
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
+                mock_acompletion.return_value = MagicMock()
+
+                await mixin_with_model_store.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="test-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
+                    )
+                )
+
+                call_kwargs = mock_acompletion.call_args[1]
+                assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None
+
+    async def test_chat_completion_no_injection_when_not_streaming(self, mixin_with_model_store):
+        """Test that stream_options is NOT injected for non-streaming requests"""
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
+                mock_acompletion.return_value = MagicMock()
+
+                await mixin_with_model_store.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="test-model",
+                        messages=[OpenAIUserMessageParam(role="user", content="Hello")],
+                        stream=False,
+                    )
+                )
+
+                call_kwargs = mock_acompletion.call_args[1]
+                assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None
+
+    async def test_completion_injects_stream_options_when_telemetry_active(self, mixin_with_model_store):
+        """Test that stream_options is injected for streaming completion when telemetry is active"""
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            with patch("litellm.atext_completion", new_callable=AsyncMock) as mock_atext_completion:
+                mock_atext_completion.return_value = MagicMock()
+
+                await mixin_with_model_store.openai_completion(
+                    OpenAICompletionRequestWithExtraBody(model="test-model", prompt="Hello", stream=True)
+                )
+
+                mock_atext_completion.assert_called_once()
+                call_kwargs = mock_atext_completion.call_args[1]
+                assert call_kwargs["stream_options"] == {"include_usage": True}
+
+    async def test_completion_no_injection_when_telemetry_inactive(self, mixin_with_model_store):
+        """Test that stream_options is NOT injected for completion when telemetry is inactive"""
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = False
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            with patch("litellm.atext_completion", new_callable=AsyncMock) as mock_atext_completion:
+                mock_atext_completion.return_value = MagicMock()
+
+                await mixin_with_model_store.openai_completion(
+                    OpenAICompletionRequestWithExtraBody(model="test-model", prompt="Hello", stream=True)
+                )
+
+                call_kwargs = mock_atext_completion.call_args[1]
+                assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None
+
+    async def test_original_params_not_mutated(self, mixin_with_model_store):
+        """Test that original params object is not mutated when stream_options is injected"""
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        original_params = OpenAIChatCompletionRequestWithExtraBody(
+            model="test-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
+        )
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
+                mock_acompletion.return_value = MagicMock()
+
+                await mixin_with_model_store.openai_chat_completion(original_params)
+
+                # Original params should not be modified
+                assert original_params.stream_options is None
+
+    async def test_chat_completion_overrides_include_usage_false(self, mixin_with_model_store):
+        """Test that include_usage=False is overridden when telemetry is active"""
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
+                mock_acompletion.return_value = MagicMock()
+
+                await mixin_with_model_store.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="test-model",
+                        messages=[OpenAIUserMessageParam(role="user", content="Hello")],
+                        stream=True,
+                        stream_options={"include_usage": False},
+                    )
+                )
+
+                call_kwargs = mock_acompletion.call_args[1]
+                # Telemetry must override False to ensure complete metrics
+                assert call_kwargs["stream_options"]["include_usage"] is True
diff --git a/tests/unit/providers/utils/inference/test_openai_compat.py b/tests/unit/providers/utils/inference/test_openai_compat.py
new file mode 100644
index 000000000..0d4031f54
--- /dev/null
+++ b/tests/unit/providers/utils/inference/test_openai_compat.py
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from unittest.mock import MagicMock, patch
+
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_stream_options_for_telemetry,
+)
+
+
+class TestGetStreamOptionsForTelemetry:
+    def test_returns_original_when_not_streaming(self):
+        stream_options = {"keep": True}
+
+        result = get_stream_options_for_telemetry(stream_options, False)
+
+        assert result is stream_options
+
+    def test_streaming_without_active_span_returns_original(self):
+        stream_options = {"keep": True}
+
+        with patch("opentelemetry.trace.get_current_span", return_value=None):
+            result = get_stream_options_for_telemetry(stream_options, True)
+
+        assert result is stream_options
+
+    def test_streaming_with_inactive_span_returns_original(self):
+        stream_options = {"keep": True}
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = False
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            result = get_stream_options_for_telemetry(stream_options, True)
+
+        assert result is stream_options
+
+    def test_streaming_with_active_span_adds_usage_when_missing(self):
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            result = get_stream_options_for_telemetry(None, True)
+
+        assert result == {"include_usage": True}
+
+    def test_streaming_with_active_span_merges_existing_options(self):
+        stream_options = {"other_option": "value"}
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            result = get_stream_options_for_telemetry(stream_options, True)
+
+        assert result == {"other_option": "value", "include_usage": True}
+        assert stream_options == {"other_option": "value"}
+
+    def test_streaming_with_active_span_overrides_include_usage_false(self):
+        stream_options = {"include_usage": False}
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            result = get_stream_options_for_telemetry(stream_options, True)
+
+        assert result["include_usage"] is True
diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py
index 02d44f2ba..b84635c10 100644
--- a/tests/unit/providers/utils/inference/test_openai_mixin.py
+++ b/tests/unit/providers/utils/inference/test_openai_mixin.py
@@ -934,3 +934,214 @@ class TestOpenAIMixinAllowedModelsInference:
                     model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
                 )
             )
+
+
+class TestOpenAIMixinStreamOptionsInjection:
+    """Test cases for automatic stream_options injection when telemetry is active"""
+
+    async def test_chat_completion_injects_stream_options_when_telemetry_active(self, mixin, mock_client_context):
+        """Test that stream_options is injected for streaming chat completion when telemetry is active"""
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        # Mock OpenTelemetry span as recording
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
+                    )
+                )
+
+                mock_client.chat.completions.create.assert_called_once()
+                call_kwargs = mock_client.chat.completions.create.call_args[1]
+                assert call_kwargs["stream_options"] == {"include_usage": True}
+
+    async def test_chat_completion_preserves_existing_stream_options(self, mixin, mock_client_context):
+        """Test that existing stream_options are preserved with include_usage added"""
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4",
+                        messages=[OpenAIUserMessageParam(role="user", content="Hello")],
+                        stream=True,
+                        stream_options={"other_option": True},
+                    )
+                )
+
+                call_kwargs = mock_client.chat.completions.create.call_args[1]
+                assert call_kwargs["stream_options"] == {"other_option": True, "include_usage": True}
+
+    async def test_chat_completion_no_injection_when_telemetry_inactive(self, mixin, mock_client_context):
+        """Test that stream_options is NOT injected when telemetry is inactive"""
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        # Mock OpenTelemetry span as not recording
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = False
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
+                    )
+                )
+
+                call_kwargs = mock_client.chat.completions.create.call_args[1]
+                assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None
+
+    async def test_chat_completion_no_injection_when_not_streaming(self, mixin, mock_client_context):
+        """Test that stream_options is NOT injected for non-streaming requests"""
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=False
+                    )
+                )
+
+                call_kwargs = mock_client.chat.completions.create.call_args[1]
+                assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None
+
+    async def test_completion_injects_stream_options_when_telemetry_active(self, mixin, mock_client_context):
+        """Test that stream_options is injected for streaming completion when telemetry is active"""
+        mock_client = MagicMock()
+        mock_client.completions.create = AsyncMock(return_value=MagicMock())
+
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_completion(
+                    OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello", stream=True)
+                )
+
+                mock_client.completions.create.assert_called_once()
+                call_kwargs = mock_client.completions.create.call_args[1]
+                assert call_kwargs["stream_options"] == {"include_usage": True}
+
+    async def test_completion_no_injection_when_telemetry_inactive(self, mixin, mock_client_context):
+        """Test that stream_options is NOT injected for completion when telemetry is inactive"""
+        mock_client = MagicMock()
+        mock_client.completions.create = AsyncMock(return_value=MagicMock())
+
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = False
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_completion(
+                    OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello", stream=True)
+                )
+
+                call_kwargs = mock_client.completions.create.call_args[1]
+                assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None
+
+    async def test_params_not_mutated(self, mixin, mock_client_context):
+        """Test that original params object is not mutated when stream_options is injected"""
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        original_params = OpenAIChatCompletionRequestWithExtraBody(
+            model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
+        )
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_chat_completion(original_params)
+
+                # Original params should not be modified
+                assert original_params.stream_options is None
+
+    async def test_chat_completion_overrides_include_usage_false(self, mixin, mock_client_context):
+        """Test that include_usage=False is overridden when telemetry is active"""
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4",
+                        messages=[OpenAIUserMessageParam(role="user", content="Hello")],
+                        stream=True,
+                        stream_options={"include_usage": False},
+                    )
+                )
+
+                call_kwargs = mock_client.chat.completions.create.call_args[1]
+                # Telemetry must override False to ensure complete metrics
+                assert call_kwargs["stream_options"]["include_usage"] is True
+
+    async def test_no_injection_when_provider_doesnt_support_stream_options(self, mixin, mock_client_context):
+        """Test that stream_options is NOT injected when provider doesn't support it"""
+        # Set supports_stream_options to False (like Ollama/vLLM)
+        mixin.supports_stream_options = False
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        # Mock OpenTelemetry span as recording (telemetry is active)
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
+                    )
+                )
+
+                call_kwargs = mock_client.chat.completions.create.call_args[1]
+                # Should NOT inject stream_options even though telemetry is active
+                assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None
+
+    async def test_completion_no_injection_when_provider_doesnt_support_stream_options(
+        self, mixin, mock_client_context
+    ):
+        """Test that stream_options is NOT injected for completion when provider doesn't support it"""
+        # Set supports_stream_options to False (like Ollama/vLLM)
+        mixin.supports_stream_options = False
+
+        mock_client = MagicMock()
+        mock_client.completions.create = AsyncMock(return_value=MagicMock())
+
+        # Mock OpenTelemetry span as recording (telemetry is active)
+        mock_span = MagicMock()
+        mock_span.is_recording.return_value = True
+
+        with mock_client_context(mixin, mock_client):
+            with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+                await mixin.openai_completion(
+                    OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello", stream=True)
+                )
+
+                call_kwargs = mock_client.completions.create.call_args[1]
+                # Should NOT inject stream_options even though telemetry is active
+                assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None
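
Reviewer note (not part of the patch): a minimal sketch of how the new get_stream_options_for_telemetry helper behaves end to end. The helper, its signature, and the supports_stream_options opt-out are taken from the diff above; the in-memory TracerProvider setup is illustrative and assumes the opentelemetry-sdk package is installed alongside llama_stack.

    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider

    from llama_stack.providers.utils.inference.openai_compat import get_stream_options_for_telemetry

    # Install a recording tracer provider so spans started below count as "telemetry active".
    trace.set_tracer_provider(TracerProvider())
    tracer = trace.get_tracer(__name__)

    # No active span: caller-provided options pass through untouched.
    print(get_stream_options_for_telemetry({"other_option": True}, True))  # {'other_option': True}

    with tracer.start_as_current_span("inference"):
        # Recording span + streaming request: include_usage is forced on ...
        print(get_stream_options_for_telemetry(None, True))  # {'include_usage': True}
        # ... unless the provider opted out, as Ollama/vLLM do via supports_stream_options=False.
        print(get_stream_options_for_telemetry(None, True, supports_stream_options=False))  # None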