From 57f8f6d3af370e77d779897cae02a1b525bcbd5c Mon Sep 17 00:00:00 2001
From: r-bit-rry
Date: Wed, 26 Nov 2025 10:55:28 +0200
Subject: [PATCH 1/3] fix(server.py): check attr sse_generator returned object

---
 src/llama_stack/core/server/server.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/llama_stack/core/server/server.py b/src/llama_stack/core/server/server.py
index 0d3513980..4c391d7f1 100644
--- a/src/llama_stack/core/server/server.py
+++ b/src/llama_stack/core/server/server.py
@@ -205,7 +205,11 @@ async def sse_generator(event_gen_coroutine):
     except asyncio.CancelledError:
         logger.info("Generator cancelled")
         if event_gen:
-            await event_gen.aclose()
+            # Some generators (like OpenAI's AsyncStream) only have close()
+            if hasattr(event_gen, "aclose"):
+                await event_gen.aclose()
+            elif hasattr(event_gen, "close"):
+                await event_gen.close()
     except Exception as e:
         logger.exception("Error in sse_generator")
         yield create_sse_event(

From c3c9edf9814b7ac185181184b926b45520c3eca7 Mon Sep 17 00:00:00 2001
From: r-bit-rry
Date: Sun, 30 Nov 2025 17:31:07 +0200
Subject: [PATCH 2/3] further fixes according to investigation and PR comments

---
 docs/static/llama-stack-spec.yaml             |   5 +-
 docs/static/stainless-llama-stack-spec.yaml   |   5 +-
 src/llama_stack/core/server/server.py         |   6 +-
 .../inference/meta_reference/inference.py     |   2 +-
 .../remote/inference/bedrock/bedrock.py       |   2 +-
 .../remote/inference/databricks/databricks.py |   4 +-
 .../inference/llama_openai_compat/llama.py    |   4 +-
 .../inference/passthrough/passthrough.py      |  15 +-
 .../remote/inference/watsonx/watsonx.py       |  10 +-
 .../utils/inference/litellm_openai_mixin.py   |  17 ++-
 .../providers/utils/inference/openai_mixin.py |  14 +-
 .../providers/utils/inference/stream_utils.py |  22 +++
 src/llama_stack_api/inference.py              |   4 +-
 .../utils/test_openai_mixin_streaming.py      | 132 ++++++++++++++++++
 14 files changed, 212 insertions(+), 30 deletions(-)
 create mode 100644 src/llama_stack/providers/utils/inference/stream_utils.py
 create mode 100644 tests/unit/providers/utils/test_openai_mixin_streaming.py

diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml
index 9f7b2ed64..5ba6b826a 100644
--- a/docs/static/llama-stack-spec.yaml
+++ b/docs/static/llama-stack-spec.yaml
@@ -289,11 +289,14 @@ paths:
     post:
       responses:
         '200':
-          description: An OpenAICompletion.
+          description: An OpenAICompletion or an async iterator of OpenAICompletion chunks when streaming.
           content:
             application/json:
               schema:
                 $ref: '#/components/schemas/OpenAICompletion'
+            text/event-stream:
+              schema:
+                $ref: '#/components/schemas/OpenAICompletion'
         '400':
           description: Bad Request
           $ref: '#/components/responses/BadRequest400'
diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml
index 9269b7e39..7feffbf31 100644
--- a/docs/static/stainless-llama-stack-spec.yaml
+++ b/docs/static/stainless-llama-stack-spec.yaml
@@ -291,11 +291,14 @@ paths:
     post:
      responses:
        '200':
-          description: An OpenAICompletion.
+          description: An OpenAICompletion or an async iterator of OpenAICompletion chunks when streaming.
           content:
             application/json:
               schema:
                 $ref: '#/components/schemas/OpenAICompletion'
+            text/event-stream:
+              schema:
+                $ref: '#/components/schemas/OpenAICompletion'
         '400':
           description: Bad Request
           $ref: '#/components/responses/BadRequest400'
diff --git a/src/llama_stack/core/server/server.py b/src/llama_stack/core/server/server.py
index 4c391d7f1..0d3513980 100644
--- a/src/llama_stack/core/server/server.py
+++ b/src/llama_stack/core/server/server.py
@@ -205,11 +205,7 @@ async def sse_generator(event_gen_coroutine):
     except asyncio.CancelledError:
         logger.info("Generator cancelled")
         if event_gen:
-            # Some generators (like OpenAI's AsyncStream) only have close()
-            if hasattr(event_gen, "aclose"):
-                await event_gen.aclose()
-            elif hasattr(event_gen, "close"):
-                await event_gen.close()
+            await event_gen.aclose()
     except Exception as e:
         logger.exception("Error in sse_generator")
         yield create_sse_event(
diff --git a/src/llama_stack/providers/inline/inference/meta_reference/inference.py b/src/llama_stack/providers/inline/inference/meta_reference/inference.py
index 42d1299ab..58c2d66af 100644
--- a/src/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/src/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -246,7 +246,7 @@ class MetaReferenceInferenceImpl(
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         raise NotImplementedError("OpenAI completion not supported by meta reference provider")

     async def should_refresh_models(self) -> bool:
diff --git a/src/llama_stack/providers/remote/inference/bedrock/bedrock.py b/src/llama_stack/providers/remote/inference/bedrock/bedrock.py
index 451549db8..49b9496e1 100644
--- a/src/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/src/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -70,7 +70,7 @@ class BedrockInferenceAdapter(OpenAIMixin):
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         """Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint."""
         raise NotImplementedError(
             "Bedrock's OpenAI-compatible API does not support /v1/completions endpoint. "
diff --git a/src/llama_stack/providers/remote/inference/databricks/databricks.py b/src/llama_stack/providers/remote/inference/databricks/databricks.py
index f2f8832f6..a465a536e 100644
--- a/src/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/src/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from collections.abc import Iterable
+from collections.abc import AsyncIterator, Iterable

 from databricks.sdk import WorkspaceClient

@@ -50,5 +50,5 @@ class DatabricksInferenceAdapter(OpenAIMixin):
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         raise NotImplementedError()
diff --git a/src/llama_stack/providers/remote/inference/llama_openai_compat/llama.py b/src/llama_stack/providers/remote/inference/llama_openai_compat/llama.py
index f29aebf36..d09a91cc1 100644
--- a/src/llama_stack/providers/remote/inference/llama_openai_compat/llama.py
+++ b/src/llama_stack/providers/remote/inference/llama_openai_compat/llama.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from collections.abc import AsyncIterator
+
 from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@@ -36,7 +38,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         raise NotImplementedError()

     async def openai_embeddings(
diff --git a/src/llama_stack/providers/remote/inference/passthrough/passthrough.py b/src/llama_stack/providers/remote/inference/passthrough/passthrough.py
index b0e2e74ad..028dea613 100644
--- a/src/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/src/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -9,6 +9,7 @@ from collections.abc import AsyncIterator
 from openai import AsyncOpenAI

 from llama_stack.core.request_headers import NeedsRequestProviderData
+from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
 from llama_stack_api import (
     Inference,
     Model,
@@ -107,12 +108,16 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         """Forward completion request to downstream using OpenAI client."""
         client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
         response = await client.completions.create(**request_params)
-        return response  # type: ignore
+
+        if params.stream:
+            return wrap_async_stream(response)
+
+        return response  # type: ignore[return-value]

     async def openai_chat_completion(
         self,
@@ -122,7 +127,11 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
         client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
         response = await client.chat.completions.create(**request_params)
-        return response  # type: ignore
+
+        if params.stream:
+            return wrap_async_stream(response)
+
+        return response  # type: ignore[return-value]

     async def openai_embeddings(
         self,
diff --git a/src/llama_stack/providers/remote/inference/watsonx/watsonx.py b/src/llama_stack/providers/remote/inference/watsonx/watsonx.py
index 5684f6c17..761af7648 100644
--- a/src/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/src/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -15,6 +15,7 @@ from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
 from llama_stack_api import (
     Model,
     ModelType,
@@ -178,7 +179,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         """
         Override parent method to add watsonx-specific parameters.
         """
@@ -211,7 +212,12 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
             timeout=self.config.timeout,
             project_id=self.config.project_id,
         )
-        return await litellm.atext_completion(**request_params)
+        result = await litellm.atext_completion(**request_params)
+
+        if params.stream:
+            return wrap_async_stream(result)  # type: ignore[arg-type]  # LiteLLM streaming types
+
+        return result  # type: ignore[return-value]  # external lib lacks type stubs

     async def openai_embeddings(
         self,
diff --git a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index c462d1aad..78cd27db4 100644
--- a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -16,6 +16,7 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe
 from llama_stack.providers.utils.inference.openai_compat import (
     prepare_openai_completion_params,
 )
+from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
 from llama_stack_api import (
     InferenceProvider,
     OpenAIChatCompletion,
@@ -178,7 +179,7 @@ class LiteLLMOpenAIMixin(
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         if not self.model_store:
             raise ValueError("Model store is not initialized")

@@ -210,7 +211,12 @@ class LiteLLMOpenAIMixin(
             api_base=self.api_base,
         )
         # LiteLLM returns compatible type but mypy can't verify external library
-        return await litellm.atext_completion(**request_params)  # type: ignore[no-any-return]  # external lib lacks type stubs
+        result = await litellm.atext_completion(**request_params)
+
+        if params.stream:
+            return wrap_async_stream(result)  # type: ignore[arg-type]  # LiteLLM streaming types
+
+        return result  # type: ignore[return-value]  # external lib lacks type stubs

     async def openai_chat_completion(
         self,
@@ -262,7 +268,12 @@ class LiteLLMOpenAIMixin(
             api_base=self.api_base,
         )
         # LiteLLM returns compatible type but mypy can't verify external library
-        return await litellm.acompletion(**request_params)  # type: ignore[no-any-return]  # external lib lacks type stubs
+        result = await litellm.acompletion(**request_params)
+
+        if params.stream:
+            return wrap_async_stream(result)  # type: ignore[arg-type]  # LiteLLM streaming types
+
+        return result  # type: ignore[return-value]  # external lib lacks type stubs

     async def check_model_availability(self, model: str) -> bool:
         """
diff --git a/src/llama_stack/providers/utils/inference/openai_mixin.py b/src/llama_stack/providers/utils/inference/openai_mixin.py
index 30511a341..499379214 100644
--- a/src/llama_stack/providers/utils/inference/openai_mixin.py
+++ b/src/llama_stack/providers/utils/inference/openai_mixin.py
@@ -248,30 +248,28 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         return model_obj.provider_resource_id

     async def _maybe_overwrite_id(self, resp: Any, stream: bool | None) -> Any:
-        if not self.overwrite_completion_id:
-            return resp
-
-        new_id = f"cltsd-{uuid.uuid4()}"
         if stream:
+            new_id = f"cltsd-{uuid.uuid4()}" if self.overwrite_completion_id else None

             async def _gen():
                 async for chunk in resp:
-                    chunk.id = new_id
+                    if new_id:
+                        chunk.id = new_id
                     yield chunk

             return _gen()
         else:
-            resp.id = new_id
+            if self.overwrite_completion_id:
+                resp.id = f"cltsd-{uuid.uuid4()}"
             return resp

     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         """
         Direct OpenAI completion API call.
         """
-        # TODO: fix openai_completion to return type compatible with OpenAI's API response
         provider_model_id = await self._get_provider_model_id(params.model)
         self._validate_model_allowed(provider_model_id)

diff --git a/src/llama_stack/providers/utils/inference/stream_utils.py b/src/llama_stack/providers/utils/inference/stream_utils.py
new file mode 100644
index 000000000..3a9292893
--- /dev/null
+++ b/src/llama_stack/providers/utils/inference/stream_utils.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import logging
+from collections.abc import AsyncIterator
+
+logger = logging.getLogger(__name__)
+
+
+async def wrap_async_stream[T](stream: AsyncIterator[T]) -> AsyncIterator[T]:
+    """
+    Wrap an async stream to ensure it returns a proper AsyncIterator.
+    """
+    try:
+        async for item in stream:
+            yield item
+    except Exception as e:
+        logger.error(f"Error in wrapped async stream: {e}")
+        raise
diff --git a/src/llama_stack_api/inference.py b/src/llama_stack_api/inference.py
index b42de95be..8daff3cd9 100644
--- a/src/llama_stack_api/inference.py
+++ b/src/llama_stack_api/inference.py
@@ -1022,11 +1022,11 @@ class InferenceProvider(Protocol):
     async def openai_completion(
         self,
         params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         """Create completion.

         Generate an OpenAI-compatible completion for the given prompt using the specified model.

-        :returns: An OpenAICompletion.
+        :returns: An OpenAICompletion or an async iterator of OpenAICompletion chunks when streaming.
         """
         ...
diff --git a/tests/unit/providers/utils/test_openai_mixin_streaming.py b/tests/unit/providers/utils/test_openai_mixin_streaming.py
new file mode 100644
index 000000000..e2a6d7a92
--- /dev/null
+++ b/tests/unit/providers/utils/test_openai_mixin_streaming.py
@@ -0,0 +1,132 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""
+Regression tests for issue #3185: AsyncStream passed where AsyncIterator expected.
+
+The bug: OpenAI SDK's AsyncStream exposes close(), not aclose(), while the
+server's SSE cleanup path calls aclose(). The fix ensures _maybe_overwrite_id()
+always wraps streaming responses in an async generator, which provides aclose().
+""" + +import inspect +from collections.abc import AsyncIterator +from unittest.mock import MagicMock + +import pytest + +from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin + + +class MockAsyncStream: + """Simulates OpenAI SDK's AsyncStream: has close() but NOT aclose().""" + + def __init__(self, chunks): + self.chunks = chunks + self._iter = iter(chunks) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._iter) + except StopIteration as e: + raise StopAsyncIteration from e + + async def close(self): + pass + + +class MockChunk: + def __init__(self, chunk_id: str, content: str = "test"): + self.id = chunk_id + self.content = content + + +class OpenAIMixinTestImpl(OpenAIMixin): + __provider_id__: str = "test-provider" + + def get_api_key(self) -> str: + return "test-api-key" + + def get_base_url(self) -> str: + return "http://test-base-url" + + +@pytest.fixture +def mixin(): + config = RemoteInferenceProviderConfig() + m = OpenAIMixinTestImpl(config=config) + m.overwrite_completion_id = False + return m + + +class TestIssue3185Regression: + + @pytest.mark.asyncio + async def test_streaming_result_has_aclose(self, mixin): + mock_stream = MockAsyncStream([MockChunk("1")]) + + assert not hasattr(mock_stream, "aclose") + + result = await mixin._maybe_overwrite_id(mock_stream, stream=True) + + assert hasattr(result, "aclose"), "Result MUST have aclose() for AsyncIterator" + assert inspect.isasyncgen(result) + assert isinstance(result, AsyncIterator) + + @pytest.mark.asyncio + async def test_streaming_yields_all_chunks(self, mixin): + chunks = [MockChunk("1", "a"), MockChunk("2", "b")] + mock_stream = MockAsyncStream(chunks) + + result = await mixin._maybe_overwrite_id(mock_stream, stream=True) + + received = [c async for c in result] + assert len(received) == 2 + assert received[0].content == "a" + assert received[1].content == "b" + + @pytest.mark.asyncio + async def test_non_streaming_returns_directly(self, mixin): + mock_response = MagicMock() + mock_response.id = "test-id" + + result = await mixin._maybe_overwrite_id(mock_response, stream=False) + + assert result is mock_response + assert not inspect.isasyncgen(result) + + +class TestIdOverwriting: + + @pytest.mark.asyncio + async def test_ids_overwritten_when_enabled(self): + config = RemoteInferenceProviderConfig() + mixin = OpenAIMixinTestImpl(config=config) + mixin.overwrite_completion_id = True + + chunks = [MockChunk("orig-1"), MockChunk("orig-2")] + result = await mixin._maybe_overwrite_id(MockAsyncStream(chunks), stream=True) + + received = [c async for c in result] + assert all(c.id.startswith("cltsd-") for c in received) + assert received[0].id == received[1].id # Same ID for all chunks + + @pytest.mark.asyncio + async def test_ids_preserved_when_disabled(self): + config = RemoteInferenceProviderConfig() + mixin = OpenAIMixinTestImpl(config=config) + mixin.overwrite_completion_id = False + + chunks = [MockChunk("orig-1"), MockChunk("orig-2")] + result = await mixin._maybe_overwrite_id(MockAsyncStream(chunks), stream=True) + + received = [c async for c in result] + assert received[0].id == "orig-1" + assert received[1].id == "orig-2" From dc4c7eaed7de77aed79d72cfe6caf5cee22943e7 Mon Sep 17 00:00:00 2001 From: r-bit-rry Date: Sun, 30 Nov 2025 17:45:47 +0200 Subject: [PATCH 3/3] fix for the logger --- src/llama_stack/providers/utils/inference/stream_utils.py | 7 ++++--- 1 
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/llama_stack/providers/utils/inference/stream_utils.py b/src/llama_stack/providers/utils/inference/stream_utils.py
index 3a9292893..a15e29d08 100644
--- a/src/llama_stack/providers/utils/inference/stream_utils.py
+++ b/src/llama_stack/providers/utils/inference/stream_utils.py
@@ -4,10 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 from collections.abc import AsyncIterator

-logger = logging.getLogger(__name__)
+from llama_stack.log import get_logger
+
+log = get_logger(name=__name__, category="providers::utils")


 async def wrap_async_stream[T](stream: AsyncIterator[T]) -> AsyncIterator[T]:
@@ -18,5 +19,5 @@
         async for item in stream:
             yield item
     except Exception as e:
-        logger.error(f"Error in wrapped async stream: {e}")
+        log.error(f"Error in wrapped async stream: {e}")
         raise
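
Note (outside the patch series): the whole fix rests on one property of async generators: re-yielding a stream through an async generator gives the caller aclose(), even when the underlying object only exposes close(). The sketch below is a minimal, standalone illustration of that property, assuming Python 3.12+ for the [T] syntax; CloseOnlyStream and main are made-up names, and only wrap_async_stream mirrors the helper added in patch 2/3.

# Minimal sketch (not part of the patches): why wrapping a close()-only
# stream in an async generator restores aclose(). Names are illustrative,
# not taken from the llama-stack codebase.
import asyncio
from collections.abc import AsyncIterator


class CloseOnlyStream:
    """Mimics OpenAI's AsyncStream: async-iterable, but only exposes close()."""

    def __init__(self, items):
        self._items = iter(items)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return next(self._items)
        except StopIteration as exc:
            raise StopAsyncIteration from exc

    async def close(self):
        pass


async def wrap_async_stream[T](stream: AsyncIterator[T]) -> AsyncIterator[T]:
    # Same shape as the helper introduced in patch 2/3: the re-yielding
    # async generator is what provides aclose() to the caller.
    async for item in stream:
        yield item


async def main() -> None:
    raw = CloseOnlyStream([1, 2, 3])
    assert not hasattr(raw, "aclose")  # the surface of the original bug

    wrapped = wrap_async_stream(raw)
    assert hasattr(wrapped, "aclose")  # what the SSE cleanup path needs
    print([x async for x in wrapped])  # [1, 2, 3]
    await wrapped.aclose()  # safe: the generator is already exhausted


asyncio.run(main())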