mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 01:48:05 +00:00
further fixes according to investigation and PR comments
This commit is contained in:
parent
9b3c041af0
commit
c3c9edf981
14 changed files with 212 additions and 30 deletions
5
docs/static/llama-stack-spec.yaml
vendored
5
docs/static/llama-stack-spec.yaml
vendored
|
|
@ -289,11 +289,14 @@ paths:
|
|||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: An OpenAICompletion.
|
||||
description: An OpenAICompletion or an async iterator of OpenAICompletion chunks when streaming.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/OpenAICompletion'
|
||||
text/event-stream:
|
||||
schema:
|
||||
$ref: '#/components/schemas/OpenAICompletion'
|
||||
'400':
|
||||
description: Bad Request
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
|
|
|
|||
5
docs/static/stainless-llama-stack-spec.yaml
vendored
5
docs/static/stainless-llama-stack-spec.yaml
vendored
|
|
@ -291,11 +291,14 @@ paths:
|
|||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: An OpenAICompletion.
|
||||
description: An OpenAICompletion or an async iterator of OpenAICompletion chunks when streaming.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/OpenAICompletion'
|
||||
text/event-stream:
|
||||
schema:
|
||||
$ref: '#/components/schemas/OpenAICompletion'
|
||||
'400':
|
||||
description: Bad Request
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
|
|
|
|||
|
|
@ -205,11 +205,7 @@ async def sse_generator(event_gen_coroutine):
|
|||
except asyncio.CancelledError:
|
||||
logger.info("Generator cancelled")
|
||||
if event_gen:
|
||||
# Some generators (like OpenAI's AsyncStream) only have close()
|
||||
if hasattr(event_gen, "aclose"):
|
||||
await event_gen.aclose()
|
||||
elif hasattr(event_gen, "close"):
|
||||
await event_gen.close()
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
logger.exception("Error in sse_generator")
|
||||
yield create_sse_event(
|
||||
|
|
|
|||
|
|
@ -246,7 +246,7 @@ class MetaReferenceInferenceImpl(
|
|||
async def openai_completion(
|
||||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
|
||||
raise NotImplementedError("OpenAI completion not supported by meta reference provider")
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ class BedrockInferenceAdapter(OpenAIMixin):
|
|||
async def openai_completion(
|
||||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
|
||||
"""Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint."""
|
||||
raise NotImplementedError(
|
||||
"Bedrock's OpenAI-compatible API does not support /v1/completions endpoint. "
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import AsyncIterator, Iterable
|
||||
|
||||
from databricks.sdk import WorkspaceClient
|
||||
|
||||
|
|
@ -50,5 +50,5 @@ class DatabricksInferenceAdapter(OpenAIMixin):
|
|||
async def openai_completion(
|
||||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
|
||||
raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
|
@ -36,7 +38,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
|
|||
async def openai_completion(
|
||||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from collections.abc import AsyncIterator
|
|||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
|
||||
from llama_stack_api import (
|
||||
Inference,
|
||||
Model,
|
||||
|
|
@ -107,12 +108,16 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
|
|||
async def openai_completion(
|
||||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
|
||||
"""Forward completion request to downstream using OpenAI client."""
|
||||
client = self._get_openai_client()
|
||||
request_params = params.model_dump(exclude_none=True)
|
||||
response = await client.completions.create(**request_params)
|
||||
return response # type: ignore
|
||||
|
||||
if params.stream:
|
||||
return wrap_async_stream(response)
|
||||
|
||||
return response # type: ignore[return-value]
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
|
@ -122,7 +127,11 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
|
|||
client = self._get_openai_client()
|
||||
request_params = params.model_dump(exclude_none=True)
|
||||
response = await client.chat.completions.create(**request_params)
|
||||
return response # type: ignore
|
||||
|
||||
if params.stream:
|
||||
return wrap_async_stream(response)
|
||||
|
||||
return response # type: ignore[return-value]
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from llama_stack.log import get_logger
|
|||
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
|
||||
from llama_stack_api import (
|
||||
Model,
|
||||
ModelType,
|
||||
|
|
@ -178,7 +179,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
async def openai_completion(
|
||||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
|
||||
"""
|
||||
Override parent method to add watsonx-specific parameters.
|
||||
"""
|
||||
|
|
@ -211,7 +212,12 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
timeout=self.config.timeout,
|
||||
project_id=self.config.project_id,
|
||||
)
|
||||
return await litellm.atext_completion(**request_params)
|
||||
result = await litellm.atext_completion(**request_params)
|
||||
|
||||
if params.stream:
|
||||
return wrap_async_stream(result) # type: ignore[arg-type] # LiteLLM streaming types
|
||||
|
||||
return result # type: ignore[return-value] # external lib lacks type stubs
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe
|
|||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
prepare_openai_completion_params,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
|
||||
from llama_stack_api import (
|
||||
InferenceProvider,
|
||||
OpenAIChatCompletion,
|
||||
|
|
@ -178,7 +179,7 @@ class LiteLLMOpenAIMixin(
|
|||
async def openai_completion(
|
||||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store is not initialized")
|
||||
|
||||
|
|
@ -210,7 +211,12 @@ class LiteLLMOpenAIMixin(
|
|||
api_base=self.api_base,
|
||||
)
|
||||
# LiteLLM returns compatible type but mypy can't verify external library
|
||||
return await litellm.atext_completion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs
|
||||
result = await litellm.atext_completion(**request_params)
|
||||
|
||||
if params.stream:
|
||||
return wrap_async_stream(result) # type: ignore[arg-type] # LiteLLM streaming types
|
||||
|
||||
return result # type: ignore[return-value] # external lib lacks type stubs
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
|
@ -262,7 +268,12 @@ class LiteLLMOpenAIMixin(
|
|||
api_base=self.api_base,
|
||||
)
|
||||
# LiteLLM returns compatible type but mypy can't verify external library
|
||||
return await litellm.acompletion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs
|
||||
result = await litellm.acompletion(**request_params)
|
||||
|
||||
if params.stream:
|
||||
return wrap_async_stream(result) # type: ignore[arg-type] # LiteLLM streaming types
|
||||
|
||||
return result # type: ignore[return-value] # external lib lacks type stubs
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -248,30 +248,28 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
return model_obj.provider_resource_id
|
||||
|
||||
async def _maybe_overwrite_id(self, resp: Any, stream: bool | None) -> Any:
|
||||
if not self.overwrite_completion_id:
|
||||
return resp
|
||||
|
||||
new_id = f"cltsd-{uuid.uuid4()}"
|
||||
if stream:
|
||||
new_id = f"cltsd-{uuid.uuid4()}" if self.overwrite_completion_id else None
|
||||
|
||||
async def _gen():
|
||||
async for chunk in resp:
|
||||
chunk.id = new_id
|
||||
if new_id:
|
||||
chunk.id = new_id
|
||||
yield chunk
|
||||
|
||||
return _gen()
|
||||
else:
|
||||
resp.id = new_id
|
||||
if self.overwrite_completion_id:
|
||||
resp.id = f"cltsd-{uuid.uuid4()}"
|
||||
return resp
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
|
||||
"""
|
||||
Direct OpenAI completion API call.
|
||||
"""
|
||||
# TODO: fix openai_completion to return type compatible with OpenAI's API response
|
||||
provider_model_id = await self._get_provider_model_id(params.model)
|
||||
self._validate_model_allowed(provider_model_id)
|
||||
|
||||
|
|
|
|||
22
src/llama_stack/providers/utils/inference/stream_utils.py
Normal file
22
src/llama_stack/providers/utils/inference/stream_utils.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def wrap_async_stream[T](stream: AsyncIterator[T]) -> AsyncIterator[T]:
|
||||
"""
|
||||
Wrap an async stream to ensure it returns a proper AsyncIterator.
|
||||
"""
|
||||
try:
|
||||
async for item in stream:
|
||||
yield item
|
||||
except Exception as e:
|
||||
logger.error(f"Error in wrapped async stream: {e}")
|
||||
raise
|
||||
|
|
@ -1022,11 +1022,11 @@ class InferenceProvider(Protocol):
|
|||
async def openai_completion(
|
||||
self,
|
||||
params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
|
||||
) -> OpenAICompletion:
|
||||
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
|
||||
"""Create completion.
|
||||
|
||||
Generate an OpenAI-compatible completion for the given prompt using the specified model.
|
||||
:returns: An OpenAICompletion.
|
||||
:returns: An OpenAICompletion or an async iterator of OpenAICompletion chunks when streaming.
|
||||
"""
|
||||
...
|
||||
|
||||
|
|
|
|||
132
tests/unit/providers/utils/test_openai_mixin_streaming.py
Normal file
132
tests/unit/providers/utils/test_openai_mixin_streaming.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Regression tests for issue #3185: AsyncStream passed where AsyncIterator expected.
|
||||
|
||||
The bug: OpenAI SDK's AsyncStream has close(), not aclose(), but Python's
|
||||
AsyncIterator protocol requires aclose(). The fix ensures _maybe_overwrite_id()
|
||||
always wraps streaming responses in an async generator.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from collections.abc import AsyncIterator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
||||
class MockAsyncStream:
|
||||
"""Simulates OpenAI SDK's AsyncStream: has close() but NOT aclose()."""
|
||||
|
||||
def __init__(self, chunks):
|
||||
self.chunks = chunks
|
||||
self._iter = iter(chunks)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
return next(self._iter)
|
||||
except StopIteration as e:
|
||||
raise StopAsyncIteration from e
|
||||
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
|
||||
class MockChunk:
|
||||
def __init__(self, chunk_id: str, content: str = "test"):
|
||||
self.id = chunk_id
|
||||
self.content = content
|
||||
|
||||
|
||||
class OpenAIMixinTestImpl(OpenAIMixin):
|
||||
__provider_id__: str = "test-provider"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return "test-api-key"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return "http://test-base-url"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixin():
|
||||
config = RemoteInferenceProviderConfig()
|
||||
m = OpenAIMixinTestImpl(config=config)
|
||||
m.overwrite_completion_id = False
|
||||
return m
|
||||
|
||||
|
||||
class TestIssue3185Regression:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_result_has_aclose(self, mixin):
|
||||
mock_stream = MockAsyncStream([MockChunk("1")])
|
||||
|
||||
assert not hasattr(mock_stream, "aclose")
|
||||
|
||||
result = await mixin._maybe_overwrite_id(mock_stream, stream=True)
|
||||
|
||||
assert hasattr(result, "aclose"), "Result MUST have aclose() for AsyncIterator"
|
||||
assert inspect.isasyncgen(result)
|
||||
assert isinstance(result, AsyncIterator)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_yields_all_chunks(self, mixin):
|
||||
chunks = [MockChunk("1", "a"), MockChunk("2", "b")]
|
||||
mock_stream = MockAsyncStream(chunks)
|
||||
|
||||
result = await mixin._maybe_overwrite_id(mock_stream, stream=True)
|
||||
|
||||
received = [c async for c in result]
|
||||
assert len(received) == 2
|
||||
assert received[0].content == "a"
|
||||
assert received[1].content == "b"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_returns_directly(self, mixin):
|
||||
mock_response = MagicMock()
|
||||
mock_response.id = "test-id"
|
||||
|
||||
result = await mixin._maybe_overwrite_id(mock_response, stream=False)
|
||||
|
||||
assert result is mock_response
|
||||
assert not inspect.isasyncgen(result)
|
||||
|
||||
|
||||
class TestIdOverwriting:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ids_overwritten_when_enabled(self):
|
||||
config = RemoteInferenceProviderConfig()
|
||||
mixin = OpenAIMixinTestImpl(config=config)
|
||||
mixin.overwrite_completion_id = True
|
||||
|
||||
chunks = [MockChunk("orig-1"), MockChunk("orig-2")]
|
||||
result = await mixin._maybe_overwrite_id(MockAsyncStream(chunks), stream=True)
|
||||
|
||||
received = [c async for c in result]
|
||||
assert all(c.id.startswith("cltsd-") for c in received)
|
||||
assert received[0].id == received[1].id # Same ID for all chunks
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ids_preserved_when_disabled(self):
|
||||
config = RemoteInferenceProviderConfig()
|
||||
mixin = OpenAIMixinTestImpl(config=config)
|
||||
mixin.overwrite_completion_id = False
|
||||
|
||||
chunks = [MockChunk("orig-1"), MockChunk("orig-2")]
|
||||
result = await mixin._maybe_overwrite_id(MockAsyncStream(chunks), stream=True)
|
||||
|
||||
received = [c async for c in result]
|
||||
assert received[0].id == "orig-1"
|
||||
assert received[1].id == "orig-2"
|
||||
Loading…
Add table
Add a link
Reference in a new issue