Further fixes based on investigation and PR review comments

r-bit-rry 2025-11-30 17:31:07 +02:00
parent 9b3c041af0
commit c3c9edf981
14 changed files with 212 additions and 30 deletions

View file

@@ -289,11 +289,14 @@ paths:
     post:
       responses:
         '200':
-          description: An OpenAICompletion.
+          description: An OpenAICompletion or an async iterator of OpenAICompletion chunks when streaming.
           content:
             application/json:
               schema:
                 $ref: '#/components/schemas/OpenAICompletion'
+            text/event-stream:
+              schema:
+                $ref: '#/components/schemas/OpenAICompletion'
         '400':
           description: Bad Request
           $ref: '#/components/responses/BadRequest400'

View file

@@ -291,11 +291,14 @@ paths:
     post:
       responses:
         '200':
-          description: An OpenAICompletion.
+          description: An OpenAICompletion or an async iterator of OpenAICompletion chunks when streaming.
           content:
             application/json:
               schema:
                 $ref: '#/components/schemas/OpenAICompletion'
+            text/event-stream:
+              schema:
+                $ref: '#/components/schemas/OpenAICompletion'
         '400':
           description: Bad Request
           $ref: '#/components/responses/BadRequest400'
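
Both spec hunks above describe the same contract: a single JSON OpenAICompletion for non-streaming calls, and a text/event-stream body of OpenAICompletion chunks when the request sets stream. A minimal client-side sketch of consuming the streaming case with the OpenAI SDK; the base URL, API key, and model name are placeholders, not part of this commit:

    import asyncio

    from openai import AsyncOpenAI


    async def main() -> None:
        # Placeholder endpoint/model; point these at a running OpenAI-compatible server.
        client = AsyncOpenAI(base_url="http://localhost:8321/v1", api_key="none")

        # With stream=True the server responds with text/event-stream and the SDK
        # exposes it as an async iterator of completion chunks.
        stream = await client.completions.create(model="my-model", prompt="Hello", stream=True)
        async for chunk in stream:
            print(chunk.choices[0].text, end="")


    asyncio.run(main())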

View file

@@ -205,11 +205,7 @@ async def sse_generator(event_gen_coroutine):
     except asyncio.CancelledError:
         logger.info("Generator cancelled")
         if event_gen:
-            # Some generators (like OpenAI's AsyncStream) only have close()
-            if hasattr(event_gen, "aclose"):
-                await event_gen.aclose()
-            elif hasattr(event_gen, "close"):
-                await event_gen.close()
+            await event_gen.aclose()
     except Exception as e:
         logger.exception("Error in sse_generator")
         yield create_sse_event(
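
The hasattr() branching above could be removed because, after this commit, streaming responses reaching sse_generator are async generators (wrapped via _maybe_overwrite_id or wrap_async_stream), and every async generator object exposes aclose(). A tiny illustration of that guarantee, not taken from the diff:

    import asyncio


    async def _numbers():
        yield 1
        yield 2


    async def main() -> None:
        gen = _numbers()
        # Async generator objects always implement aclose(), so callers can
        # close them unconditionally without hasattr() checks.
        assert hasattr(gen, "aclose")
        await gen.aclose()


    asyncio.run(main())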

View file

@@ -246,7 +246,7 @@ class MetaReferenceInferenceImpl(
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         raise NotImplementedError("OpenAI completion not supported by meta reference provider")
 
     async def should_refresh_models(self) -> bool:

View file

@@ -70,7 +70,7 @@ class BedrockInferenceAdapter(OpenAIMixin):
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         """Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint."""
         raise NotImplementedError(
             "Bedrock's OpenAI-compatible API does not support /v1/completions endpoint. "

View file

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import Iterable
+from collections.abc import AsyncIterator, Iterable
 
 from databricks.sdk import WorkspaceClient
 
@@ -50,5 +50,5 @@ class DatabricksInferenceAdapter(OpenAIMixin):
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         raise NotImplementedError()

View file

@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from collections.abc import AsyncIterator
+
 from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@@ -36,7 +38,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         raise NotImplementedError()
 
     async def openai_embeddings(

View file

@@ -9,6 +9,7 @@ from collections.abc import AsyncIterator
 from openai import AsyncOpenAI
 
 from llama_stack.core.request_headers import NeedsRequestProviderData
+from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
 from llama_stack_api import (
     Inference,
     Model,
@@ -107,12 +108,16 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         """Forward completion request to downstream using OpenAI client."""
         client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
         response = await client.completions.create(**request_params)
-        return response  # type: ignore
+
+        if params.stream:
+            return wrap_async_stream(response)
+
+        return response  # type: ignore[return-value]
 
     async def openai_chat_completion(
         self,
@@ -122,7 +127,11 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
         client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
         response = await client.chat.completions.create(**request_params)
-        return response  # type: ignore
+
+        if params.stream:
+            return wrap_async_stream(response)
+
+        return response  # type: ignore[return-value]
 
     async def openai_embeddings(
         self,
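
Both passthrough methods build the downstream request by dumping the validated request model and dropping unset fields before calling the downstream client. A small sketch of that step, using a hypothetical pydantic model in place of OpenAICompletionRequestWithExtraBody:

    from pydantic import BaseModel


    class CompletionParams(BaseModel):
        """Hypothetical stand-in for OpenAICompletionRequestWithExtraBody."""

        model: str
        prompt: str
        stream: bool | None = None
        temperature: float | None = None


    params = CompletionParams(model="my-model", prompt="Hello", stream=True)

    # exclude_none=True drops optional fields the caller never set, so only
    # explicitly provided arguments are forwarded to the downstream client.
    request_params = params.model_dump(exclude_none=True)
    print(request_params)  # {'model': 'my-model', 'prompt': 'Hello', 'stream': True}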

View file

@@ -15,6 +15,7 @@ from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
 from llama_stack_api import (
     Model,
     ModelType,
@@ -178,7 +179,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         """
         Override parent method to add watsonx-specific parameters.
         """
@@ -211,7 +212,12 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
             timeout=self.config.timeout,
             project_id=self.config.project_id,
         )
-        return await litellm.atext_completion(**request_params)
+        result = await litellm.atext_completion(**request_params)
+
+        if params.stream:
+            return wrap_async_stream(result)  # type: ignore[arg-type]  # LiteLLM streaming types
+
+        return result  # type: ignore[return-value]  # external lib lacks type stubs
 
     async def openai_embeddings(
         self,

View file

@@ -16,6 +16,7 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     prepare_openai_completion_params,
 )
+from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
 from llama_stack_api import (
     InferenceProvider,
     OpenAIChatCompletion,
@@ -178,7 +179,7 @@ class LiteLLMOpenAIMixin(
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         if not self.model_store:
             raise ValueError("Model store is not initialized")
 
@@ -210,7 +211,12 @@ class LiteLLMOpenAIMixin(
             api_base=self.api_base,
         )
         # LiteLLM returns compatible type but mypy can't verify external library
-        return await litellm.atext_completion(**request_params)  # type: ignore[no-any-return]  # external lib lacks type stubs
+        result = await litellm.atext_completion(**request_params)
+
+        if params.stream:
+            return wrap_async_stream(result)  # type: ignore[arg-type]  # LiteLLM streaming types
+
+        return result  # type: ignore[return-value]  # external lib lacks type stubs
 
     async def openai_chat_completion(
         self,
@@ -262,7 +268,12 @@ class LiteLLMOpenAIMixin(
             api_base=self.api_base,
         )
         # LiteLLM returns compatible type but mypy can't verify external library
-        return await litellm.acompletion(**request_params)  # type: ignore[no-any-return]  # external lib lacks type stubs
+        result = await litellm.acompletion(**request_params)
+
+        if params.stream:
+            return wrap_async_stream(result)  # type: ignore[arg-type]  # LiteLLM streaming types
+
+        return result  # type: ignore[return-value]  # external lib lacks type stubs
 
     async def check_model_availability(self, model: str) -> bool:
         """

View file

@@ -248,30 +248,28 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         return model_obj.provider_resource_id
 
     async def _maybe_overwrite_id(self, resp: Any, stream: bool | None) -> Any:
-        if not self.overwrite_completion_id:
-            return resp
-
-        new_id = f"cltsd-{uuid.uuid4()}"
         if stream:
+            new_id = f"cltsd-{uuid.uuid4()}" if self.overwrite_completion_id else None
 
             async def _gen():
                 async for chunk in resp:
-                    chunk.id = new_id
+                    if new_id:
+                        chunk.id = new_id
                     yield chunk
 
             return _gen()
         else:
-            resp.id = new_id
+            if self.overwrite_completion_id:
+                resp.id = f"cltsd-{uuid.uuid4()}"
             return resp
 
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         """
         Direct OpenAI completion API call.
         """
-        # TODO: fix openai_completion to return type compatible with OpenAI's API response
         provider_model_id = await self._get_provider_model_id(params.model)
         self._validate_model_allowed(provider_model_id)
 

View file

@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
from collections.abc import AsyncIterator

logger = logging.getLogger(__name__)


async def wrap_async_stream[T](stream: AsyncIterator[T]) -> AsyncIterator[T]:
    """
    Wrap an async stream to ensure it returns a proper AsyncIterator.
    """
    try:
        async for item in stream:
            yield item
    except Exception as e:
        logger.error(f"Error in wrapped async stream: {e}")
        raise
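
A usage sketch for the helper above: re-yielding an SDK stream turns it into a plain async generator with aclose(), regardless of the SDK's own close semantics. FakeStream is a hypothetical stand-in for an SDK stream object (compare MockAsyncStream in the regression tests further down):

    import asyncio

    from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream


    class FakeStream:
        """Hypothetical SDK-style stream: async-iterable, but only has close()."""

        def __init__(self, items):
            self._items = iter(items)

        def __aiter__(self):
            return self

        async def __anext__(self):
            try:
                return next(self._items)
            except StopIteration as e:
                raise StopAsyncIteration from e

        async def close(self):
            pass


    async def main() -> None:
        wrapped = wrap_async_stream(FakeStream(["a", "b"]))
        assert hasattr(wrapped, "aclose")  # now a genuine async generator
        async for item in wrapped:
            print(item)


    asyncio.run(main())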

View file

@@ -1022,11 +1022,11 @@ class InferenceProvider(Protocol):
     async def openai_completion(
         self,
         params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         """Create completion.
 
         Generate an OpenAI-compatible completion for the given prompt using the specified model.
 
-        :returns: An OpenAICompletion.
+        :returns: An OpenAICompletion or an async iterator of OpenAICompletion chunks when streaming.
         """
         ...
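
With the widened return type, callers of the protocol branch on params.stream to decide how to consume the result. A hedged caller-side sketch; run_completion and its arguments are illustrative, not part of the API:

    async def run_completion(provider, params) -> str:
        # provider: any InferenceProvider implementation
        # params: an OpenAICompletionRequestWithExtraBody built by the caller
        result = await provider.openai_completion(params)

        if params.stream:
            # Streaming: an async iterator of OpenAICompletion chunks.
            parts = [chunk.choices[0].text async for chunk in result]
            return "".join(parts)

        # Non-streaming: a single OpenAICompletion.
        return result.choices[0].text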

View file

@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Regression tests for issue #3185: AsyncStream passed where AsyncIterator expected.

The bug: OpenAI SDK's AsyncStream has close(), not aclose(), but Python's
AsyncIterator protocol requires aclose(). The fix ensures _maybe_overwrite_id()
always wraps streaming responses in an async generator.
"""

import inspect
from collections.abc import AsyncIterator
from unittest.mock import MagicMock

import pytest

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class MockAsyncStream:
    """Simulates OpenAI SDK's AsyncStream: has close() but NOT aclose()."""

    def __init__(self, chunks):
        self.chunks = chunks
        self._iter = iter(chunks)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return next(self._iter)
        except StopIteration as e:
            raise StopAsyncIteration from e

    async def close(self):
        pass


class MockChunk:
    def __init__(self, chunk_id: str, content: str = "test"):
        self.id = chunk_id
        self.content = content


class OpenAIMixinTestImpl(OpenAIMixin):
    __provider_id__: str = "test-provider"

    def get_api_key(self) -> str:
        return "test-api-key"

    def get_base_url(self) -> str:
        return "http://test-base-url"


@pytest.fixture
def mixin():
    config = RemoteInferenceProviderConfig()
    m = OpenAIMixinTestImpl(config=config)
    m.overwrite_completion_id = False
    return m


class TestIssue3185Regression:
    @pytest.mark.asyncio
    async def test_streaming_result_has_aclose(self, mixin):
        mock_stream = MockAsyncStream([MockChunk("1")])
        assert not hasattr(mock_stream, "aclose")

        result = await mixin._maybe_overwrite_id(mock_stream, stream=True)

        assert hasattr(result, "aclose"), "Result MUST have aclose() for AsyncIterator"
        assert inspect.isasyncgen(result)
        assert isinstance(result, AsyncIterator)

    @pytest.mark.asyncio
    async def test_streaming_yields_all_chunks(self, mixin):
        chunks = [MockChunk("1", "a"), MockChunk("2", "b")]
        mock_stream = MockAsyncStream(chunks)

        result = await mixin._maybe_overwrite_id(mock_stream, stream=True)
        received = [c async for c in result]

        assert len(received) == 2
        assert received[0].content == "a"
        assert received[1].content == "b"

    @pytest.mark.asyncio
    async def test_non_streaming_returns_directly(self, mixin):
        mock_response = MagicMock()
        mock_response.id = "test-id"

        result = await mixin._maybe_overwrite_id(mock_response, stream=False)

        assert result is mock_response
        assert not inspect.isasyncgen(result)


class TestIdOverwriting:
    @pytest.mark.asyncio
    async def test_ids_overwritten_when_enabled(self):
        config = RemoteInferenceProviderConfig()
        mixin = OpenAIMixinTestImpl(config=config)
        mixin.overwrite_completion_id = True

        chunks = [MockChunk("orig-1"), MockChunk("orig-2")]
        result = await mixin._maybe_overwrite_id(MockAsyncStream(chunks), stream=True)
        received = [c async for c in result]

        assert all(c.id.startswith("cltsd-") for c in received)
        assert received[0].id == received[1].id  # Same ID for all chunks

    @pytest.mark.asyncio
    async def test_ids_preserved_when_disabled(self):
        config = RemoteInferenceProviderConfig()
        mixin = OpenAIMixinTestImpl(config=config)
        mixin.overwrite_completion_id = False

        chunks = [MockChunk("orig-1"), MockChunk("orig-2")]
        result = await mixin._maybe_overwrite_id(MockAsyncStream(chunks), stream=True)
        received = [c async for c in result]

        assert received[0].id == "orig-1"
        assert received[1].id == "orig-2"