Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-21 07:38:41 +00:00
feat: enable streaming usage metrics for OpenAI-compatible providers (#4326)
Inject `stream_options={"include_usage": True}` when streaming and
OpenTelemetry telemetry is active. Active telemetry always overrides any
caller preference to ensure complete and consistent observability metrics.
Changes:
- Add conditional stream_options injection to OpenAIMixin (benefits OpenAI, Bedrock, Runpod, Together, and Fireworks providers)
- Add conditional stream_options injection to LiteLLMOpenAIMixin (benefits WatsonX and other litellm-based providers)
- Check telemetry status using trace.get_current_span().is_recording() (see the condensed sketch below)
- Override include_usage=False when telemetry is active to prevent metric gaps
- Unit tests for this functionality
Fixes #3981
Note: this work originated in PR #4200, which I closed after rebasing on
the telemetry changes. This PR rebases those commits, incorporates the
Bedrock feedback, and carries forward the same scope described there.
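For reviewers, a condensed sketch of the mechanism: both mixins delegate to a single helper added in `openai_compat.py` (the full version appears in the diff below), so nothing in this sketch goes beyond what the diff already contains.

```python
from typing import Any

from opentelemetry import trace


def get_stream_options_for_telemetry(
    stream_options: dict[str, Any] | None,
    is_streaming: bool,
    supports_stream_options: bool = True,
) -> dict[str, Any] | None:
    """Return stream_options with include_usage=True when telemetry should record usage."""
    # Leave non-streaming requests and providers that reject stream_options untouched.
    if not is_streaming or not supports_stream_options:
        return stream_options

    # "Telemetry is active" means the current OpenTelemetry span is recording.
    span = trace.get_current_span()
    if not span or not span.is_recording():
        return stream_options

    # Active telemetry wins over caller preference, including an explicit include_usage=False.
    if stream_options is None:
        return {"include_usage": True}
    return {**stream_options, "include_usage": True}
```

Providers whose backends reject the parameter (Ollama, vLLM) opt out by setting `supports_stream_options = False` on the adapter class.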
## Test Plan
#### OpenAIMixin + telemetry injection tests
PYTHONPATH=src python -m pytest tests/unit/providers/utils/inference/test_openai_mixin.py
#### LiteLLM OpenAIMixin tests
PYTHONPATH=src python -m pytest tests/unit/providers/inference/test_litellm_openai_mixin.py -v
#### Broader inference provider tests
PYTHONPATH=src python -m pytest tests/unit/providers/inference/ --ignore=tests/unit/providers/inference/test_inference_client_caching.py -v
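All of the new tests simulate "telemetry active" the same way: by patching `opentelemetry.trace.get_current_span` with a recording span. A condensed example, mirroring `test_streaming_with_active_span_overrides_include_usage_false` in the new `test_openai_compat.py` (not an additional test in this PR):

```python
from unittest.mock import MagicMock, patch

from llama_stack.providers.utils.inference.openai_compat import get_stream_options_for_telemetry


def test_active_span_overrides_include_usage_false():
    # A recording span stands in for "OpenTelemetry telemetry is active".
    mock_span = MagicMock()
    mock_span.is_recording.return_value = True

    with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
        result = get_stream_options_for_telemetry({"include_usage": False}, True)

    # Active telemetry overrides the caller's include_usage=False.
    assert result["include_usage"] is True
```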
This commit is contained in:
parent 5ebcde3042
commit bd35aa4d78
11 changed files with 558 additions and 130 deletions
@@ -81,14 +81,7 @@ class BedrockInferenceAdapter(OpenAIMixin):
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Override to enable streaming usage metrics and handle authentication errors."""
# Enable streaming usage metrics when telemetry is active
if params.stream:
if params.stream_options is None:
params.stream_options = {"include_usage": True}
elif "include_usage" not in params.stream_options:
params.stream_options = {**params.stream_options, "include_usage": True}

"""Override to handle authentication errors and null responses."""
try:
logger.debug(f"Calling Bedrock OpenAI API with model={params.model}, stream={params.stream}")
result = await super().openai_chat_completion(params=params)
@@ -28,6 +28,9 @@ class OllamaInferenceAdapter(OpenAIMixin):
# automatically set by the resolver when instantiating the provider
__provider_id__: str

# Ollama does not support the stream_options parameter
supports_stream_options: bool = False

embedding_model_metadata: dict[str, dict[str, int]] = {
"all-minilm:l6-v2": {
"embedding_dimension": 384,
@@ -4,14 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncIterator

from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack_api import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
)

from .config import RunpodImplConfig

@@ -29,15 +22,3 @@ class RunpodInferenceAdapter(OpenAIMixin):
def get_base_url(self) -> str:
"""Get base URL for OpenAI client."""
return str(self.config.base_url)

async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Override to add RunPod-specific stream_options requirement."""
params = params.model_copy()

if params.stream and not params.stream_options:
params.stream_options = {"include_usage": True}

return await super().openai_chat_completion(params)
@@ -30,6 +30,9 @@ class VLLMInferenceAdapter(OpenAIMixin):

model_config = ConfigDict(arbitrary_types_allowed=True)

# vLLM does not support the stream_options parameter
supports_stream_options: bool = False

provider_data_api_key_field: str = "vllm_api_token"

def get_api_key(self) -> str | None:
@@ -13,8 +13,6 @@ import requests
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
from llama_stack_api import (
Model,
ModelType,

@@ -22,7 +20,6 @@ from llama_stack_api import (
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionUsage,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,

@@ -48,57 +45,25 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
openai_compat_api_base=self.get_base_url(),
)

def _litellm_extra_request_params(
self,
params: OpenAIChatCompletionRequestWithExtraBody | OpenAICompletionRequestWithExtraBody,
) -> dict[str, Any]:
# These are watsonx-specific parameters used by LiteLLM.
return {"timeout": self.config.timeout, "project_id": self.config.project_id}

async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""
Override parent method to add timeout and inject usage object when missing.
Override parent method to inject usage object when missing.

This works around a LiteLLM defect where usage block is sometimes dropped.
Note: request parameter construction (including telemetry-driven stream_options injection)
is handled by LiteLLMOpenAIMixin via _litellm_extra_request_params().
"""

# Add usage tracking for streaming when telemetry is active
stream_options = params.stream_options
if params.stream:
if stream_options is None:
stream_options = {"include_usage": True}
elif "include_usage" not in stream_options:
stream_options = {**stream_options, "include_usage": True}

model_obj = await self.model_store.get_model(params.model)

request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
messages=params.messages,
frequency_penalty=params.frequency_penalty,
function_call=params.function_call,
functions=params.functions,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_completion_tokens=params.max_completion_tokens,
max_tokens=params.max_tokens,
n=params.n,
parallel_tool_calls=params.parallel_tool_calls,
presence_penalty=params.presence_penalty,
response_format=params.response_format,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=stream_options,
temperature=params.temperature,
tool_choice=params.tool_choice,
tools=params.tools,
top_logprobs=params.top_logprobs,
top_p=params.top_p,
user=params.user,
api_key=self.get_api_key(),
api_base=self.api_base,
# These are watsonx-specific parameters
timeout=self.config.timeout,
project_id=self.config.project_id,
)

result = await litellm.acompletion(**request_params)
result = await super().openai_chat_completion(params)

# If not streaming, check and inject usage if missing
if not params.stream:

@@ -175,49 +140,6 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
logger.error(f"Error normalizing stream: {e}", exc_info=True)
raise

async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
"""
Override parent method to add watsonx-specific parameters.
"""
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params

model_obj = await self.model_store.get_model(params.model)

request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
prompt=params.prompt,
best_of=params.best_of,
echo=params.echo,
frequency_penalty=params.frequency_penalty,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_tokens=params.max_tokens,
n=params.n,
presence_penalty=params.presence_penalty,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=params.stream_options,
temperature=params.temperature,
top_p=params.top_p,
user=params.user,
suffix=params.suffix,
api_key=self.get_api_key(),
api_base=self.api_base,
# These are watsonx-specific parameters
timeout=self.config.timeout,
project_id=self.config.project_id,
)
result = await litellm.atext_completion(**request_params)

if params.stream:
return wrap_async_stream(result) # type: ignore[arg-type] # LiteLLM streaming types

return result # type: ignore[return-value] # external lib lacks type stubs

async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
@@ -7,6 +7,7 @@
import base64
import struct
from collections.abc import AsyncIterator
from typing import Any

import litellm

@@ -14,6 +15,7 @@ from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
from llama_stack.providers.utils.inference.openai_compat import (
get_stream_options_for_telemetry,
prepare_openai_completion_params,
)
from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream

@@ -50,6 +52,7 @@ class LiteLLMOpenAIMixin(
openai_compat_api_base: str | None = None,
download_images: bool = False,
json_schema_strict: bool = True,
supports_stream_options: bool = True,
):
"""
Initialize the LiteLLMOpenAIMixin.

@@ -61,6 +64,7 @@ class LiteLLMOpenAIMixin(
:param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
:param download_images: Whether to download images and convert to base64 for message conversion.
:param json_schema_strict: Whether to use strict mode for JSON schema validation.
:param supports_stream_options: Whether the provider supports stream_options parameter.
"""
ModelRegistryHelper.__init__(self, model_entries=model_entries)

@@ -70,6 +74,7 @@ class LiteLLMOpenAIMixin(
self.api_base = openai_compat_api_base
self.download_images = download_images
self.json_schema_strict = json_schema_strict
self.supports_stream_options = supports_stream_options

if openai_compat_api_base:
self.is_openai_compat = True

@@ -180,6 +185,11 @@ class LiteLLMOpenAIMixin(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
# Inject stream_options when streaming and telemetry is active
stream_options = get_stream_options_for_telemetry(
params.stream_options, params.stream, self.supports_stream_options
)

if not self.model_store:
raise ValueError("Model store is not initialized")

@@ -202,13 +212,14 @@ class LiteLLMOpenAIMixin(
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=params.stream_options,
stream_options=stream_options,
temperature=params.temperature,
top_p=params.top_p,
user=params.user,
suffix=params.suffix,
api_key=self.get_api_key(),
api_base=self.api_base,
**self._litellm_extra_request_params(params),
)
# LiteLLM returns compatible type but mypy can't verify external library
result = await litellm.atext_completion(**request_params)

@@ -222,14 +233,10 @@ class LiteLLMOpenAIMixin(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
# Add usage tracking for streaming when telemetry is active

stream_options = params.stream_options
if params.stream:
if stream_options is None:
stream_options = {"include_usage": True}
elif "include_usage" not in stream_options:
stream_options = {**stream_options, "include_usage": True}
# Inject stream_options when streaming and telemetry is active
stream_options = get_stream_options_for_telemetry(
params.stream_options, params.stream, self.supports_stream_options
)

if not self.model_store:
raise ValueError("Model store is not initialized")

@@ -265,6 +272,7 @@ class LiteLLMOpenAIMixin(
user=params.user,
api_key=self.get_api_key(),
api_base=self.api_base,
**self._litellm_extra_request_params(params),
)
# LiteLLM returns compatible type but mypy can't verify external library
result = await litellm.acompletion(**request_params)

@@ -288,6 +296,20 @@ class LiteLLMOpenAIMixin(

return model in litellm.models_by_provider[self.litellm_provider_name]

def _litellm_extra_request_params(
self,
params: OpenAIChatCompletionRequestWithExtraBody | OpenAICompletionRequestWithExtraBody,
) -> dict[str, Any]:
"""
Provider hook for extra LiteLLM/OpenAI-compat request params.

This is intentionally a narrow hook so provider adapters (e.g. WatsonX)
can add provider-specific kwargs (timeouts, project IDs, etc.) while the
mixin remains the single source of truth for telemetry-driven
stream_options injection.
"""
return {}


def b64_encode_openai_embeddings_response(
response_data: list[dict], encoding_format: str | None = "float"
@@ -235,3 +235,40 @@ def prepare_openai_embeddings_params(
params["user"] = user

return params


def get_stream_options_for_telemetry(
stream_options: dict[str, Any] | None,
is_streaming: bool,
supports_stream_options: bool = True,
) -> dict[str, Any] | None:
"""
Inject stream_options when streaming and telemetry is active.

Active telemetry takes precedence over caller preference to ensure
complete and consistent observability metrics.

Args:
stream_options: Existing stream options from the request
is_streaming: Whether this is a streaming request
supports_stream_options: Whether the provider supports stream_options parameter

Returns:
Updated stream_options with include_usage=True if conditions are met, otherwise original options
"""
if not is_streaming:
return stream_options

if not supports_stream_options:
return stream_options

from opentelemetry import trace

span = trace.get_current_span()
if not span or not span.is_recording():
return stream_options

if stream_options is None:
return {"include_usage": True}

return {**stream_options, "include_usage": True}
@@ -16,7 +16,10 @@ from pydantic import BaseModel, ConfigDict
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from llama_stack.providers.utils.inference.openai_compat import (
get_stream_options_for_telemetry,
prepare_openai_completion_params,
)
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
from llama_stack_api import (
Model,

@@ -47,6 +50,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
The behavior of this class can be customized by child classes in the following ways:
- overwrite_completion_id: If True, overwrites the 'id' field in OpenAI responses
- download_images: If True, downloads images and converts to base64 for providers that require it
- supports_stream_options: If False, disables stream_options injection for providers that don't support it
- embedding_model_metadata: A dictionary mapping model IDs to their embedding metadata
- construct_model_from_identifier: Method to construct a Model instance corresponding to the given identifier
- provider_data_api_key_field: Optional field name in provider data to look for API key

@@ -74,6 +78,10 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
# for providers that require base64 encoded images instead of URLs.
download_images: bool = False

# Allow subclasses to control whether the provider supports stream_options parameter
# Set to False for providers that don't support stream_options (e.g., Ollama, vLLM)
supports_stream_options: bool = True

# Embedding model metadata for this provider
# Can be set by subclasses or instances to provide embedding models
# Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}

@@ -270,6 +278,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
"""
Direct OpenAI completion API call.
"""
# Inject stream_options when streaming and telemetry is active
stream_options = get_stream_options_for_telemetry(
params.stream_options, params.stream or False, self.supports_stream_options
)

provider_model_id = await self._get_provider_model_id(params.model)
self._validate_model_allowed(provider_model_id)

@@ -287,7 +300,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=params.stream_options,
stream_options=stream_options,
temperature=params.temperature,
top_p=params.top_p,
user=params.user,

@@ -306,6 +319,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
"""
Direct OpenAI chat completion API call.
"""
# Inject stream_options when streaming and telemetry is active
stream_options = get_stream_options_for_telemetry(
params.stream_options, params.stream or False, self.supports_stream_options
)

provider_model_id = await self._get_provider_model_id(params.model)
self._validate_model_allowed(provider_model_id)

@@ -346,7 +364,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=params.stream_options,
stream_options=stream_options,
temperature=params.temperature,
tool_choice=params.tool_choice,
tools=params.tools,
@@ -5,13 +5,18 @@
# the root directory of this source tree.

import json
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import BaseModel, Field

from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack_api import (
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAIUserMessageParam,
)


# Test fixtures and helper classes

@@ -109,3 +114,168 @@ def test_error_message_includes_correct_field_names(adapter_without_config_key):
except ValueError as e:
assert "test_api_key" in str(e) # Should mention the correct field name
assert "x-llamastack-provider-data" in str(e) # Should mention header name


class TestLiteLLMOpenAIMixinStreamOptionsInjection:
"""Test cases for automatic stream_options injection in LiteLLMOpenAIMixin"""

@pytest.fixture
def mixin_with_model_store(self, adapter_with_config_key):
"""Fixture to create adapter with mocked model store"""
mock_model_store = AsyncMock()
mock_model = MagicMock()
mock_model.provider_resource_id = "test-model-id"
mock_model_store.get_model = AsyncMock(return_value=mock_model)
adapter_with_config_key.model_store = mock_model_store
return adapter_with_config_key

async def test_chat_completion_injects_stream_options_when_telemetry_active(self, mixin_with_model_store):
"""Test that stream_options is injected for streaming chat completion when telemetry is active"""
mock_span = MagicMock()
mock_span.is_recording.return_value = True

with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
mock_acompletion.return_value = MagicMock()

await mixin_with_model_store.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="test-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
)
)

mock_acompletion.assert_called_once()
call_kwargs = mock_acompletion.call_args[1]
assert call_kwargs["stream_options"] == {"include_usage": True}

async def test_chat_completion_preserves_existing_stream_options(self, mixin_with_model_store):
"""Test that existing stream_options are preserved with include_usage added"""
mock_span = MagicMock()
mock_span.is_recording.return_value = True

with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
mock_acompletion.return_value = MagicMock()

await mixin_with_model_store.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="test-model",
messages=[OpenAIUserMessageParam(role="user", content="Hello")],
stream=True,
stream_options={"other_option": True},
)
)

call_kwargs = mock_acompletion.call_args[1]
assert call_kwargs["stream_options"] == {"other_option": True, "include_usage": True}

async def test_chat_completion_no_injection_when_telemetry_inactive(self, mixin_with_model_store):
"""Test that stream_options is NOT injected when telemetry is inactive"""
mock_span = MagicMock()
mock_span.is_recording.return_value = False

with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
mock_acompletion.return_value = MagicMock()

await mixin_with_model_store.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="test-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
)
)

call_kwargs = mock_acompletion.call_args[1]
assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None

async def test_chat_completion_no_injection_when_not_streaming(self, mixin_with_model_store):
"""Test that stream_options is NOT injected for non-streaming requests"""
mock_span = MagicMock()
mock_span.is_recording.return_value = True

with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
mock_acompletion.return_value = MagicMock()

await mixin_with_model_store.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="test-model",
messages=[OpenAIUserMessageParam(role="user", content="Hello")],
stream=False,
)
)

call_kwargs = mock_acompletion.call_args[1]
assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None

async def test_completion_injects_stream_options_when_telemetry_active(self, mixin_with_model_store):
"""Test that stream_options is injected for streaming completion when telemetry is active"""
mock_span = MagicMock()
mock_span.is_recording.return_value = True

with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with patch("litellm.atext_completion", new_callable=AsyncMock) as mock_atext_completion:
mock_atext_completion.return_value = MagicMock()

await mixin_with_model_store.openai_completion(
OpenAICompletionRequestWithExtraBody(model="test-model", prompt="Hello", stream=True)
)

mock_atext_completion.assert_called_once()
call_kwargs = mock_atext_completion.call_args[1]
assert call_kwargs["stream_options"] == {"include_usage": True}

async def test_completion_no_injection_when_telemetry_inactive(self, mixin_with_model_store):
"""Test that stream_options is NOT injected for completion when telemetry is inactive"""
mock_span = MagicMock()
mock_span.is_recording.return_value = False

with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with patch("litellm.atext_completion", new_callable=AsyncMock) as mock_atext_completion:
mock_atext_completion.return_value = MagicMock()

await mixin_with_model_store.openai_completion(
OpenAICompletionRequestWithExtraBody(model="test-model", prompt="Hello", stream=True)
)

call_kwargs = mock_atext_completion.call_args[1]
assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None

async def test_original_params_not_mutated(self, mixin_with_model_store):
"""Test that original params object is not mutated when stream_options is injected"""
mock_span = MagicMock()
mock_span.is_recording.return_value = True

original_params = OpenAIChatCompletionRequestWithExtraBody(
model="test-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
)

with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
mock_acompletion.return_value = MagicMock()

await mixin_with_model_store.openai_chat_completion(original_params)

# Original params should not be modified
assert original_params.stream_options is None

async def test_chat_completion_overrides_include_usage_false(self, mixin_with_model_store):
"""Test that include_usage=False is overridden when telemetry is active"""
mock_span = MagicMock()
mock_span.is_recording.return_value = True

with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
mock_acompletion.return_value = MagicMock()

await mixin_with_model_store.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="test-model",
messages=[OpenAIUserMessageParam(role="user", content="Hello")],
stream=True,
stream_options={"include_usage": False},
)
)

call_kwargs = mock_acompletion.call_args[1]
# Telemetry must override False to ensure complete metrics
assert call_kwargs["stream_options"]["include_usage"] is True
tests/unit/providers/utils/inference/test_openai_compat.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import MagicMock, patch

from llama_stack.providers.utils.inference.openai_compat import (
get_stream_options_for_telemetry,
)


class TestGetStreamOptionsForTelemetry:
def test_returns_original_when_not_streaming(self):
stream_options = {"keep": True}

result = get_stream_options_for_telemetry(stream_options, False)

assert result is stream_options

def test_streaming_without_active_span_returns_original(self):
stream_options = {"keep": True}

with patch("opentelemetry.trace.get_current_span", return_value=None):
result = get_stream_options_for_telemetry(stream_options, True)

assert result is stream_options

def test_streaming_with_inactive_span_returns_original(self):
stream_options = {"keep": True}
mock_span = MagicMock()
mock_span.is_recording.return_value = False

with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
result = get_stream_options_for_telemetry(stream_options, True)

assert result is stream_options

def test_streaming_with_active_span_adds_usage_when_missing(self):
mock_span = MagicMock()
mock_span.is_recording.return_value = True

with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
result = get_stream_options_for_telemetry(None, True)

assert result == {"include_usage": True}

def test_streaming_with_active_span_merges_existing_options(self):
stream_options = {"other_option": "value"}
mock_span = MagicMock()
mock_span.is_recording.return_value = True

with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
result = get_stream_options_for_telemetry(stream_options, True)

assert result == {"other_option": "value", "include_usage": True}
assert stream_options == {"other_option": "value"}

def test_streaming_with_active_span_overrides_include_usage_false(self):
stream_options = {"include_usage": False}
mock_span = MagicMock()
mock_span.is_recording.return_value = True

with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
result = get_stream_options_for_telemetry(stream_options, True)

assert result["include_usage"] is True
@@ -934,3 +934,214 @@ class TestOpenAIMixinAllowedModelsInference:
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)


class TestOpenAIMixinStreamOptionsInjection:
"""Test cases for automatic stream_options injection when telemetry is active"""

async def test_chat_completion_injects_stream_options_when_telemetry_active(self, mixin, mock_client_context):
"""Test that stream_options is injected for streaming chat completion when telemetry is active"""
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())

# Mock OpenTelemetry span as recording
mock_span = MagicMock()
mock_span.is_recording.return_value = True

with mock_client_context(mixin, mock_client):
with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
)
)

mock_client.chat.completions.create.assert_called_once()
call_kwargs = mock_client.chat.completions.create.call_args[1]
assert call_kwargs["stream_options"] == {"include_usage": True}

async def test_chat_completion_preserves_existing_stream_options(self, mixin, mock_client_context):
"""Test that existing stream_options are preserved with include_usage added"""
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())

mock_span = MagicMock()
mock_span.is_recording.return_value = True

with mock_client_context(mixin, mock_client):
with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4",
messages=[OpenAIUserMessageParam(role="user", content="Hello")],
stream=True,
stream_options={"other_option": True},
)
)

call_kwargs = mock_client.chat.completions.create.call_args[1]
assert call_kwargs["stream_options"] == {"other_option": True, "include_usage": True}

async def test_chat_completion_no_injection_when_telemetry_inactive(self, mixin, mock_client_context):
"""Test that stream_options is NOT injected when telemetry is inactive"""
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())

# Mock OpenTelemetry span as not recording
mock_span = MagicMock()
mock_span.is_recording.return_value = False

with mock_client_context(mixin, mock_client):
with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
)
)

call_kwargs = mock_client.chat.completions.create.call_args[1]
assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None

async def test_chat_completion_no_injection_when_not_streaming(self, mixin, mock_client_context):
"""Test that stream_options is NOT injected for non-streaming requests"""
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())

mock_span = MagicMock()
mock_span.is_recording.return_value = True

with mock_client_context(mixin, mock_client):
with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=False
)
)

call_kwargs = mock_client.chat.completions.create.call_args[1]
assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None

async def test_completion_injects_stream_options_when_telemetry_active(self, mixin, mock_client_context):
"""Test that stream_options is injected for streaming completion when telemetry is active"""
mock_client = MagicMock()
mock_client.completions.create = AsyncMock(return_value=MagicMock())

mock_span = MagicMock()
mock_span.is_recording.return_value = True

with mock_client_context(mixin, mock_client):
with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
await mixin.openai_completion(
OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello", stream=True)
)

mock_client.completions.create.assert_called_once()
call_kwargs = mock_client.completions.create.call_args[1]
assert call_kwargs["stream_options"] == {"include_usage": True}

async def test_completion_no_injection_when_telemetry_inactive(self, mixin, mock_client_context):
"""Test that stream_options is NOT injected for completion when telemetry is inactive"""
mock_client = MagicMock()
mock_client.completions.create = AsyncMock(return_value=MagicMock())

mock_span = MagicMock()
mock_span.is_recording.return_value = False

with mock_client_context(mixin, mock_client):
with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
await mixin.openai_completion(
OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello", stream=True)
)

call_kwargs = mock_client.completions.create.call_args[1]
assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None

async def test_params_not_mutated(self, mixin, mock_client_context):
"""Test that original params object is not mutated when stream_options is injected"""
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())

mock_span = MagicMock()
mock_span.is_recording.return_value = True

original_params = OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
)

with mock_client_context(mixin, mock_client):
with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
await mixin.openai_chat_completion(original_params)

# Original params should not be modified
assert original_params.stream_options is None

async def test_chat_completion_overrides_include_usage_false(self, mixin, mock_client_context):
"""Test that include_usage=False is overridden when telemetry is active"""
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())

mock_span = MagicMock()
mock_span.is_recording.return_value = True

with mock_client_context(mixin, mock_client):
with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4",
messages=[OpenAIUserMessageParam(role="user", content="Hello")],
stream=True,
stream_options={"include_usage": False},
)
)

call_kwargs = mock_client.chat.completions.create.call_args[1]
# Telemetry must override False to ensure complete metrics
assert call_kwargs["stream_options"]["include_usage"] is True

async def test_no_injection_when_provider_doesnt_support_stream_options(self, mixin, mock_client_context):
"""Test that stream_options is NOT injected when provider doesn't support it"""
# Set supports_stream_options to False (like Ollama/vLLM)
mixin.supports_stream_options = False

mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())

# Mock OpenTelemetry span as recording (telemetry is active)
mock_span = MagicMock()
mock_span.is_recording.return_value = True

with mock_client_context(mixin, mock_client):
with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True
)
)

call_kwargs = mock_client.chat.completions.create.call_args[1]
# Should NOT inject stream_options even though telemetry is active
assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None

async def test_completion_no_injection_when_provider_doesnt_support_stream_options(
self, mixin, mock_client_context
):
"""Test that stream_options is NOT injected for completion when provider doesn't support it"""
# Set supports_stream_options to False (like Ollama/vLLM)
mixin.supports_stream_options = False

mock_client = MagicMock()
mock_client.completions.create = AsyncMock(return_value=MagicMock())

# Mock OpenTelemetry span as recording (telemetry is active)
mock_span = MagicMock()
mock_span.is_recording.return_value = True

with mock_client_context(mixin, mock_client):
with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
await mixin.openai_completion(
OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello", stream=True)
)

call_kwargs = mock_client.completions.create.call_args[1]
# Should NOT inject stream_options even though telemetry is active
assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None