diff --git a/src/llama_stack/core/routers/inference.py b/src/llama_stack/core/routers/inference.py
index 719624e86..c7a3d9a93 100644
--- a/src/llama_stack/core/routers/inference.py
+++ b/src/llama_stack/core/routers/inference.py
@@ -176,7 +176,7 @@ class InferenceRouter(Inference):
     async def openai_completion(
         self,
         params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[Any]:
         logger.debug(
             f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
         )
@@ -185,9 +185,12 @@ class InferenceRouter(Inference):
         params.model = provider_resource_id
 
         if params.stream:
-            return await provider.openai_completion(params)
-            # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
-            # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
+            response_stream = await provider.openai_completion(params)
+            return self.wrap_completion_stream_with_metrics(
+                response=response_stream,
+                fully_qualified_model_id=request_model_id,
+                provider_id=provider.__provider_id__,
+            )
 
         response = await provider.openai_completion(params)
         response.model = request_model_id
@@ -412,16 +415,17 @@ class InferenceRouter(Inference):
                     completion_text += "".join(choice_data["content_parts"])
 
                 # Add metrics to the chunk
-                if self.telemetry_enabled and hasattr(chunk, "usage") and chunk.usage:
-                    metrics = self._construct_metrics(
-                        prompt_tokens=chunk.usage.prompt_tokens,
-                        completion_tokens=chunk.usage.completion_tokens,
-                        total_tokens=chunk.usage.total_tokens,
-                        fully_qualified_model_id=fully_qualified_model_id,
-                        provider_id=provider_id,
-                    )
-                    for metric in metrics:
-                        enqueue_event(metric)
+                if self.telemetry_enabled:
+                    if hasattr(chunk, "usage") and chunk.usage:
+                        metrics = self._construct_metrics(
+                            prompt_tokens=chunk.usage.prompt_tokens,
+                            completion_tokens=chunk.usage.completion_tokens,
+                            total_tokens=chunk.usage.total_tokens,
+                            fully_qualified_model_id=fully_qualified_model_id,
+                            provider_id=provider_id,
+                        )
+                        for metric in metrics:
+                            enqueue_event(metric)
 
                 yield chunk
         finally:
@@ -471,3 +475,31 @@ class InferenceRouter(Inference):
             )
             logger.debug(f"InferenceRouter.completion_response: {final_response}")
             asyncio.create_task(self.store.store_chat_completion(final_response, messages))
+
+    async def wrap_completion_stream_with_metrics(
+        self,
+        response: AsyncIterator,
+        fully_qualified_model_id: str,
+        provider_id: str,
+    ) -> AsyncIterator:
+        """Stream OpenAI completion chunks and compute metrics on final chunk."""
+
+        async for chunk in response:
+            if hasattr(chunk, "model"):
+                chunk.model = fully_qualified_model_id
+
+            if getattr(chunk, "choices", None) and any(c.finish_reason for c in chunk.choices):
+                if self.telemetry_enabled:
+                    if getattr(chunk, "usage", None):
+                        usage = chunk.usage
+                        metrics = self._construct_metrics(
+                            prompt_tokens=usage.prompt_tokens,
+                            completion_tokens=usage.completion_tokens,
+                            total_tokens=usage.total_tokens,
+                            fully_qualified_model_id=fully_qualified_model_id,
+                            provider_id=provider_id,
+                        )
+                        for metric in metrics:
+                            enqueue_event(metric)
+
+            yield chunk
diff --git a/src/llama_stack/providers/remote/inference/bedrock/bedrock.py b/src/llama_stack/providers/remote/inference/bedrock/bedrock.py
index 451549db8..2a6eb25e0 100644
--- a/src/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/src/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -8,7 +8,6 @@ from collections.abc import AsyncIterator, Iterable
 
 from openai import AuthenticationError
 
-from llama_stack.core.telemetry.tracing import get_current_span
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack_api import (
@@ -82,14 +81,7 @@ class BedrockInferenceAdapter(OpenAIMixin):
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        """Override to enable streaming usage metrics and handle authentication errors."""
-        # Enable streaming usage metrics when telemetry is active
-        if params.stream and get_current_span() is not None:
-            if params.stream_options is None:
-                params.stream_options = {"include_usage": True}
-            elif "include_usage" not in params.stream_options:
-                params.stream_options = {**params.stream_options, "include_usage": True}
-
+        """Override to handle authentication errors and null responses."""
         try:
             logger.debug(f"Calling Bedrock OpenAI API with model={params.model}, stream={params.stream}")
             result = await super().openai_chat_completion(params=params)
diff --git a/src/llama_stack/providers/remote/inference/runpod/runpod.py b/src/llama_stack/providers/remote/inference/runpod/runpod.py
index 04ad12851..6a0dbd82d 100644
--- a/src/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/src/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -4,14 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncIterator
-
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack_api import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAIChatCompletionRequestWithExtraBody,
-)
 
 from .config import RunpodImplConfig
 
@@ -29,15 +22,3 @@ class RunpodInferenceAdapter(OpenAIMixin):
     def get_base_url(self) -> str:
         """Get base URL for OpenAI client."""
         return str(self.config.base_url)
-
-    async def openai_chat_completion(
-        self,
-        params: OpenAIChatCompletionRequestWithExtraBody,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        """Override to add RunPod-specific stream_options requirement."""
-        params = params.model_copy()
-
-        if params.stream and not params.stream_options:
-            params.stream_options = {"include_usage": True}
-
-        return await super().openai_chat_completion(params)
diff --git a/src/llama_stack/providers/remote/inference/watsonx/watsonx.py b/src/llama_stack/providers/remote/inference/watsonx/watsonx.py
index 5684f6c17..15ba244c1 100644
--- a/src/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/src/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -10,7 +10,6 @@ from typing import Any
 import litellm
 import requests
 
-from llama_stack.core.telemetry.tracing import get_current_span
 from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@@ -56,15 +55,6 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
         Override parent method to add timeout and inject usage object when missing.
         This works around a LiteLLM defect where usage block is sometimes dropped.
         """
-
-        # Add usage tracking for streaming when telemetry is active
-        stream_options = params.stream_options
-        if params.stream and get_current_span() is not None:
-            if stream_options is None:
-                stream_options = {"include_usage": True}
-            elif "include_usage" not in stream_options:
-                stream_options = {**stream_options, "include_usage": True}
-
         model_obj = await self.model_store.get_model(params.model)
 
         request_params = await prepare_openai_completion_params(
@@ -84,7 +74,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
             seed=params.seed,
             stop=params.stop,
             stream=params.stream,
-            stream_options=stream_options,
+            stream_options=params.stream_options,
             temperature=params.temperature,
             tool_choice=params.tool_choice,
             tools=params.tools,
diff --git a/src/llama_stack/providers/utils/inference/openai_mixin.py b/src/llama_stack/providers/utils/inference/openai_mixin.py
index 30511a341..5f44f2034 100644
--- a/src/llama_stack/providers/utils/inference/openai_mixin.py
+++ b/src/llama_stack/providers/utils/inference/openai_mixin.py
@@ -271,6 +271,16 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI completion API call.
         """
+        from llama_stack.core.telemetry.tracing import get_current_span
+
+        # inject if streaming AND telemetry active
+        if params.stream and get_current_span() is not None:
+            params = params.model_copy()
+            if params.stream_options is None:
+                params.stream_options = {"include_usage": True}
+            elif "include_usage" not in params.stream_options:
+                params.stream_options = {**params.stream_options, "include_usage": True}
+
         # TODO: fix openai_completion to return type compatible with OpenAI's API response
         provider_model_id = await self._get_provider_model_id(params.model)
         self._validate_model_allowed(provider_model_id)
@@ -308,6 +318,16 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI chat completion API call.
         """
+        from llama_stack.core.telemetry.tracing import get_current_span
+
+        # inject if streaming AND telemetry active
+        if params.stream and get_current_span() is not None:
+            params = params.model_copy()
+            if params.stream_options is None:
+                params.stream_options = {"include_usage": True}
+            elif "include_usage" not in params.stream_options:
+                params.stream_options = {**params.stream_options, "include_usage": True}
+
         provider_model_id = await self._get_provider_model_id(params.model)
         self._validate_model_allowed(provider_model_id)
 
diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py
index 02d44f2ba..534eb8a10 100644
--- a/tests/unit/providers/utils/inference/test_openai_mixin.py
+++ b/tests/unit/providers/utils/inference/test_openai_mixin.py
@@ -934,3 +934,146 @@ class TestOpenAIMixinAllowedModelsInference:
                 model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
             )
         )
+
+
+class TestOpenAIMixinStreamingMetrics:
+    """Test cases for streaming metrics injection in OpenAIMixin"""
+
+    async def test_openai_chat_completion_streaming_metrics_injection(self, mixin, mock_client_context):
+        """Test that stream_options={"include_usage": True} is injected when streaming and telemetry is enabled"""
+
+        params = OpenAIChatCompletionRequestWithExtraBody(
+            model="test-model",
+            messages=[{"role": "user", "content": "hello"}],
+            stream=True,
+            stream_options=None,
+        )
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
+                mock_get_span.return_value = MagicMock()
+
+                with patch(
+                    "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
+                ) as mock_prepare:
+                    mock_prepare.return_value = {"model": "test-model"}
+
+                    await mixin.openai_chat_completion(params)
+
+                    call_kwargs = mock_prepare.call_args.kwargs
+                    assert call_kwargs["stream_options"] == {"include_usage": True}
+
+                    assert params.stream_options is None
+
+    async def test_openai_chat_completion_streaming_no_telemetry(self, mixin, mock_client_context):
+        """Test that stream_options is NOT injected when telemetry is disabled"""
+
+        params = OpenAIChatCompletionRequestWithExtraBody(
+            model="test-model",
+            messages=[{"role": "user", "content": "hello"}],
+            stream=True,
+            stream_options=None,
+        )
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
+                mock_get_span.return_value = None
+
+                with patch(
+                    "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
+                ) as mock_prepare:
+                    mock_prepare.return_value = {"model": "test-model"}
+
+                    await mixin.openai_chat_completion(params)
+
+                    call_kwargs = mock_prepare.call_args.kwargs
+                    assert call_kwargs["stream_options"] is None
+
+    async def test_openai_completion_streaming_metrics_injection(self, mixin, mock_client_context):
+        """Test that stream_options={"include_usage": True} is injected for legacy completion"""
+
+        params = OpenAICompletionRequestWithExtraBody(
+            model="test-model",
+            prompt="hello",
+            stream=True,
+            stream_options=None,
+        )
+
+        mock_client = MagicMock()
+        mock_client.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
+                mock_get_span.return_value = MagicMock()
+
+                with patch(
+                    "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
+                ) as mock_prepare:
+                    mock_prepare.return_value = {"model": "test-model"}
+
+                    await mixin.openai_completion(params)
+
+                    call_kwargs = mock_prepare.call_args.kwargs
+                    assert call_kwargs["stream_options"] == {"include_usage": True}
+                    assert params.stream_options is None
+
+    async def test_preserves_existing_stream_options(self, mixin, mock_client_context):
+        """Test that existing stream_options are preserved and merged"""
+
+        params = OpenAIChatCompletionRequestWithExtraBody(
+            model="test-model",
+            messages=[{"role": "user", "content": "hello"}],
+            stream=True,
+            stream_options={"include_usage": False},
+        )
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
+                mock_get_span.return_value = MagicMock()
+
+                with patch(
+                    "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
+                ) as mock_prepare:
+                    mock_prepare.return_value = {"model": "test-model"}
+
+                    await mixin.openai_chat_completion(params)
+
+                    call_kwargs = mock_prepare.call_args.kwargs
+                    # It should stay False because it was present
+                    assert call_kwargs["stream_options"] == {"include_usage": False}
+
+    async def test_merges_existing_stream_options(self, mixin, mock_client_context):
+        """Test that existing stream_options are merged"""
+
+        params = OpenAIChatCompletionRequestWithExtraBody(
+            model="test-model",
+            messages=[{"role": "user", "content": "hello"}],
+            stream=True,
+            stream_options={"other_option": True},
+        )
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
+                mock_get_span.return_value = MagicMock()
+
+                with patch(
+                    "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
+                ) as mock_prepare:
+                    mock_prepare.return_value = {"model": "test-model"}
+
+                    await mixin.openai_chat_completion(params)
+
+                    call_kwargs = mock_prepare.call_args.kwargs
+                    assert call_kwargs["stream_options"] == {"other_option": True, "include_usage": True}
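
Note on the shared behavior in this patch: OpenAIMixin.openai_completion and openai_chat_completion now apply the same stream_options rule, which the new unit tests exercise, and the per-provider copies of that logic (bedrock, runpod, watsonx) are removed because the mixin centralizes it. The sketch below is illustrative only and is not llama-stack API: it restates the rule over plain dicts, where `inject_usage_option` is a name invented here and `telemetry_active` stands in for `get_current_span() is not None`.

```python
# Illustrative sketch (assumed names, not llama-stack API): the injection rule
# added to OpenAIMixin, reduced to plain dicts for clarity.

def inject_usage_option(stream: bool, stream_options: dict | None, telemetry_active: bool) -> dict | None:
    """Return stream_options with include_usage=True added when streaming with telemetry active.

    Caller-supplied options are preserved; an explicit include_usage value is never overridden.
    """
    if not (stream and telemetry_active):
        # non-streaming request or no active span: leave options untouched
        return stream_options
    if stream_options is None:
        return {"include_usage": True}
    if "include_usage" in stream_options:
        # respect an explicit caller choice, even include_usage=False
        return stream_options
    return {**stream_options, "include_usage": True}


# Mirrors the assertions in TestOpenAIMixinStreamingMetrics above.
assert inject_usage_option(True, None, True) == {"include_usage": True}
assert inject_usage_option(True, {"other_option": True}, True) == {"other_option": True, "include_usage": True}
assert inject_usage_option(True, {"include_usage": False}, True) == {"include_usage": False}
assert inject_usage_option(True, None, False) is None
```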