mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)

commit c3c9edf981 (parent 9b3c041af0)

    further fixes according to investigation and PR comments

14 changed files with 212 additions and 30 deletions
docs/static/llama-stack-spec.yaml (vendored, 5 lines changed)

@@ -289,11 +289,14 @@ paths:
     post:
       responses:
         '200':
-          description: An OpenAICompletion.
+          description: An OpenAICompletion or an async iterator of OpenAICompletion chunks when streaming.
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/OpenAICompletion'
+           text/event-stream:
+             schema:
+               $ref: '#/components/schemas/OpenAICompletion'
        '400':
          description: Bad Request
          $ref: '#/components/responses/BadRequest400'
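
For reference, a minimal client-side sketch of what the two content types above correspond to. The base URL, route prefix, and model id are assumptions for illustration only; the JSON-versus-SSE distinction is the part taken from the spec change.

# Hedged sketch: assumes a Llama Stack server on localhost:8321 exposing an
# OpenAI-compatible completions route; URL, port, and model id are illustrative.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# stream=False -> the application/json branch: one OpenAICompletion object.
resp = client.completions.create(model="llama3.2:3b", prompt="Hello", stream=False)
print(resp.choices[0].text)

# stream=True -> the text/event-stream branch: an iterator of completion chunks.
for chunk in client.completions.create(model="llama3.2:3b", prompt="Hello", stream=True):
    print(chunk.choices[0].text, end="")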

docs/static/stainless-llama-stack-spec.yaml (vendored, 5 lines changed)

@@ -291,11 +291,14 @@ paths:
     post:
       responses:
         '200':
-          description: An OpenAICompletion.
+          description: An OpenAICompletion or an async iterator of OpenAICompletion chunks when streaming.
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/OpenAICompletion'
+           text/event-stream:
+             schema:
+               $ref: '#/components/schemas/OpenAICompletion'
        '400':
          description: Bad Request
          $ref: '#/components/responses/BadRequest400'

@@ -205,11 +205,7 @@ async def sse_generator(event_gen_coroutine):
     except asyncio.CancelledError:
         logger.info("Generator cancelled")
         if event_gen:
-            # Some generators (like OpenAI's AsyncStream) only have close()
-            if hasattr(event_gen, "aclose"):
-                await event_gen.aclose()
-            elif hasattr(event_gen, "close"):
-                await event_gen.close()
+            await event_gen.aclose()
     except Exception as e:
         logger.exception("Error in sse_generator")
         yield create_sse_event(
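
The hasattr() branching can be dropped because, after this commit, every streaming response reaching sse_generator has already been wrapped in an async generator (see the new stream_utils.py further down), and async generators always expose aclose(). A standalone illustration of the difference, not part of the commit:

# Standalone illustration: a close()-only object vs. an async-generator wrapper.
import asyncio
import inspect
from collections.abc import AsyncIterator


class CloseOnlyStream:
    """Mimics OpenAI's AsyncStream: defines close() but not aclose()."""

    def __init__(self, items):
        self._items = iter(items)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return next(self._items)
        except StopIteration as e:
            raise StopAsyncIteration from e

    async def close(self):
        pass


async def wrap(stream: AsyncIterator):
    # Wrapping in an async generator is what guarantees aclose() exists.
    async for item in stream:
        yield item


async def main():
    raw = CloseOnlyStream([1, 2, 3])
    wrapped = wrap(CloseOnlyStream([1, 2, 3]))
    print(hasattr(raw, "aclose"), inspect.isasyncgen(raw))          # False False
    print(hasattr(wrapped, "aclose"), inspect.isasyncgen(wrapped))  # True True
    await wrapped.aclose()  # safe to call unconditionally, as the simplified handler does


asyncio.run(main())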

@@ -246,7 +246,7 @@ class MetaReferenceInferenceImpl(
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         raise NotImplementedError("OpenAI completion not supported by meta reference provider")

     async def should_refresh_models(self) -> bool:

@@ -70,7 +70,7 @@ class BedrockInferenceAdapter(OpenAIMixin):
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         """Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint."""
         raise NotImplementedError(
             "Bedrock's OpenAI-compatible API does not support /v1/completions endpoint. "

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from collections.abc import Iterable
+from collections.abc import AsyncIterator, Iterable

 from databricks.sdk import WorkspaceClient

@@ -50,5 +50,5 @@ class DatabricksInferenceAdapter(OpenAIMixin):
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         raise NotImplementedError()

@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from collections.abc import AsyncIterator
+
 from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

@@ -36,7 +38,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         raise NotImplementedError()

     async def openai_embeddings(

@@ -9,6 +9,7 @@ from collections.abc import AsyncIterator
 from openai import AsyncOpenAI

 from llama_stack.core.request_headers import NeedsRequestProviderData
+from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
 from llama_stack_api import (
     Inference,
     Model,

@@ -107,12 +108,16 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         """Forward completion request to downstream using OpenAI client."""
         client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
         response = await client.completions.create(**request_params)
-        return response  # type: ignore
+
+        if params.stream:
+            return wrap_async_stream(response)
+
+        return response  # type: ignore[return-value]

     async def openai_chat_completion(
         self,

@@ -122,7 +127,11 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
         client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
         response = await client.chat.completions.create(**request_params)
-        return response  # type: ignore
+
+        if params.stream:
+            return wrap_async_stream(response)
+
+        return response  # type: ignore[return-value]

     async def openai_embeddings(
         self,
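
With the widened return type, callers of these provider methods get either a single completion object or an async iterator, depending on params.stream. A caller-side sketch; the adapter and params names are placeholders, not from the diff:

# Caller-side sketch; `adapter` and `params` are placeholders, not from the diff.
from collections.abc import AsyncIterator


async def consume_completion(adapter, params):
    result = await adapter.openai_completion(params)
    if isinstance(result, AsyncIterator):
        # Streaming: iterate the chunks as they arrive.
        async for chunk in result:
            print(chunk.choices[0].text, end="")
    else:
        # Non-streaming: a single, complete OpenAICompletion.
        print(result.choices[0].text)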

@@ -15,6 +15,7 @@ from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
 from llama_stack_api import (
     Model,
     ModelType,

@@ -178,7 +179,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         """
         Override parent method to add watsonx-specific parameters.
         """

@@ -211,7 +212,12 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
             timeout=self.config.timeout,
             project_id=self.config.project_id,
         )
-        return await litellm.atext_completion(**request_params)
+        result = await litellm.atext_completion(**request_params)
+
+        if params.stream:
+            return wrap_async_stream(result)  # type: ignore[arg-type]  # LiteLLM streaming types
+
+        return result  # type: ignore[return-value]  # external lib lacks type stubs

     async def openai_embeddings(
         self,

@@ -16,6 +16,7 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe
 from llama_stack.providers.utils.inference.openai_compat import (
     prepare_openai_completion_params,
 )
+from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
 from llama_stack_api import (
     InferenceProvider,
     OpenAIChatCompletion,

@@ -178,7 +179,7 @@ class LiteLLMOpenAIMixin(
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         if not self.model_store:
             raise ValueError("Model store is not initialized")

@@ -210,7 +211,12 @@ class LiteLLMOpenAIMixin(
             api_base=self.api_base,
         )
         # LiteLLM returns compatible type but mypy can't verify external library
-        return await litellm.atext_completion(**request_params)  # type: ignore[no-any-return]  # external lib lacks type stubs
+        result = await litellm.atext_completion(**request_params)
+
+        if params.stream:
+            return wrap_async_stream(result)  # type: ignore[arg-type]  # LiteLLM streaming types
+
+        return result  # type: ignore[return-value]  # external lib lacks type stubs

     async def openai_chat_completion(
         self,

@@ -262,7 +268,12 @@ class LiteLLMOpenAIMixin(
             api_base=self.api_base,
         )
         # LiteLLM returns compatible type but mypy can't verify external library
-        return await litellm.acompletion(**request_params)  # type: ignore[no-any-return]  # external lib lacks type stubs
+        result = await litellm.acompletion(**request_params)
+
+        if params.stream:
+            return wrap_async_stream(result)  # type: ignore[arg-type]  # LiteLLM streaming types
+
+        return result  # type: ignore[return-value]  # external lib lacks type stubs

     async def check_model_availability(self, model: str) -> bool:
         """
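
The same wrap-when-streaming tail now appears in the passthrough, watsonx, and LiteLLM paths above. A hypothetical consolidation, not part of this commit, could be a small helper along these lines:

# Hypothetical helper (not in this commit) capturing the repeated pattern above.
from collections.abc import AsyncIterator
from typing import Any

from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream


def stream_or_value(result: Any, stream: bool | None) -> Any:
    """Wrap streaming results so callers always get a real AsyncIterator."""
    if stream and isinstance(result, AsyncIterator):
        return wrap_async_stream(result)
    return result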

@@ -248,30 +248,28 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         return model_obj.provider_resource_id

     async def _maybe_overwrite_id(self, resp: Any, stream: bool | None) -> Any:
-        if not self.overwrite_completion_id:
-            return resp
-
-        new_id = f"cltsd-{uuid.uuid4()}"
         if stream:
+            new_id = f"cltsd-{uuid.uuid4()}" if self.overwrite_completion_id else None

             async def _gen():
                 async for chunk in resp:
-                    chunk.id = new_id
+                    if new_id:
+                        chunk.id = new_id
                     yield chunk

             return _gen()
         else:
-            resp.id = new_id
+            if self.overwrite_completion_id:
+                resp.id = f"cltsd-{uuid.uuid4()}"
             return resp

     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         """
         Direct OpenAI completion API call.
         """
-        # TODO: fix openai_completion to return type compatible with OpenAI's API response
         provider_model_id = await self._get_provider_model_id(params.model)
         self._validate_model_allowed(provider_model_id)
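
The refactor changes _maybe_overwrite_id in two ways: streaming responses are now always wrapped in an async generator, even when overwrite_completion_id is False, and when overwriting is enabled one id is minted per stream and stamped on every chunk. A behavioural sketch that mirrors the new logic outside the mixin; the Chunk type is a stand-in, not a llama-stack type:

# Behavioural sketch mirroring the new _maybe_overwrite_id logic; `Chunk` is a stand-in.
import asyncio
import uuid
from dataclasses import dataclass


@dataclass
class Chunk:
    id: str


async def fake_stream():
    yield Chunk("orig-1")
    yield Chunk("orig-2")


def maybe_overwrite_id(resp, stream, overwrite_completion_id):
    if stream:
        new_id = f"cltsd-{uuid.uuid4()}" if overwrite_completion_id else None

        async def _gen():
            async for chunk in resp:
                if new_id:
                    chunk.id = new_id
                yield chunk

        return _gen()
    if overwrite_completion_id:
        resp.id = f"cltsd-{uuid.uuid4()}"
    return resp


async def main():
    ids = [c.id async for c in maybe_overwrite_id(fake_stream(), True, True)]
    assert ids[0] == ids[1] and ids[0].startswith("cltsd-")  # one id for the whole stream
    ids = [c.id async for c in maybe_overwrite_id(fake_stream(), True, False)]
    assert ids == ["orig-1", "orig-2"]  # stream still wrapped, ids untouched


asyncio.run(main())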

src/llama_stack/providers/utils/inference/stream_utils.py (new file, 22 lines)

@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import logging
+from collections.abc import AsyncIterator
+
+logger = logging.getLogger(__name__)
+
+
+async def wrap_async_stream[T](stream: AsyncIterator[T]) -> AsyncIterator[T]:
+    """
+    Wrap an async stream to ensure it returns a proper AsyncIterator.
+    """
+    try:
+        async for item in stream:
+            yield item
+    except Exception as e:
+        logger.error(f"Error in wrapped async stream: {e}")
+        raise
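
A usage sketch for the new helper; the failing source below is illustrative, not from the commit. Items pass through unchanged, and any exception raised by the underlying stream is logged and re-raised:

# Usage sketch for wrap_async_stream; the failing source is illustrative only.
import asyncio

from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream


async def flaky_source():
    yield "chunk-1"
    raise RuntimeError("upstream went away")


async def main():
    stream = wrap_async_stream(flaky_source())
    try:
        async for item in stream:
            print(item)  # prints "chunk-1"
    except RuntimeError:
        print("error was logged by the wrapper and re-raised")


asyncio.run(main())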

@@ -1022,11 +1022,11 @@ class InferenceProvider(Protocol):
     async def openai_completion(
         self,
         params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
         """Create completion.

         Generate an OpenAI-compatible completion for the given prompt using the specified model.

-        :returns: An OpenAICompletion.
+        :returns: An OpenAICompletion or an async iterator of OpenAICompletion chunks when streaming.
         """
         ...

tests/unit/providers/utils/test_openai_mixin_streaming.py (new file, 132 lines)

@@ -0,0 +1,132 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""
+Regression tests for issue #3185: AsyncStream passed where AsyncIterator expected.
+
+The bug: OpenAI SDK's AsyncStream has close(), not aclose(), but Python's
+AsyncIterator protocol requires aclose(). The fix ensures _maybe_overwrite_id()
+always wraps streaming responses in an async generator.
+"""
+
+import inspect
+from collections.abc import AsyncIterator
+from unittest.mock import MagicMock
+
+import pytest
+
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
+
+
+class MockAsyncStream:
+    """Simulates OpenAI SDK's AsyncStream: has close() but NOT aclose()."""
+
+    def __init__(self, chunks):
+        self.chunks = chunks
+        self._iter = iter(chunks)
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            return next(self._iter)
+        except StopIteration as e:
+            raise StopAsyncIteration from e
+
+    async def close(self):
+        pass
+
+
+class MockChunk:
+    def __init__(self, chunk_id: str, content: str = "test"):
+        self.id = chunk_id
+        self.content = content
+
+
+class OpenAIMixinTestImpl(OpenAIMixin):
+    __provider_id__: str = "test-provider"
+
+    def get_api_key(self) -> str:
+        return "test-api-key"
+
+    def get_base_url(self) -> str:
+        return "http://test-base-url"
+
+
+@pytest.fixture
+def mixin():
+    config = RemoteInferenceProviderConfig()
+    m = OpenAIMixinTestImpl(config=config)
+    m.overwrite_completion_id = False
+    return m
+
+
+class TestIssue3185Regression:
+
+    @pytest.mark.asyncio
+    async def test_streaming_result_has_aclose(self, mixin):
+        mock_stream = MockAsyncStream([MockChunk("1")])
+
+        assert not hasattr(mock_stream, "aclose")
+
+        result = await mixin._maybe_overwrite_id(mock_stream, stream=True)
+
+        assert hasattr(result, "aclose"), "Result MUST have aclose() for AsyncIterator"
+        assert inspect.isasyncgen(result)
+        assert isinstance(result, AsyncIterator)
+
+    @pytest.mark.asyncio
+    async def test_streaming_yields_all_chunks(self, mixin):
+        chunks = [MockChunk("1", "a"), MockChunk("2", "b")]
+        mock_stream = MockAsyncStream(chunks)
+
+        result = await mixin._maybe_overwrite_id(mock_stream, stream=True)
+
+        received = [c async for c in result]
+        assert len(received) == 2
+        assert received[0].content == "a"
+        assert received[1].content == "b"
+
+    @pytest.mark.asyncio
+    async def test_non_streaming_returns_directly(self, mixin):
+        mock_response = MagicMock()
+        mock_response.id = "test-id"
+
+        result = await mixin._maybe_overwrite_id(mock_response, stream=False)
+
+        assert result is mock_response
+        assert not inspect.isasyncgen(result)
+
+
+class TestIdOverwriting:
+
+    @pytest.mark.asyncio
+    async def test_ids_overwritten_when_enabled(self):
+        config = RemoteInferenceProviderConfig()
+        mixin = OpenAIMixinTestImpl(config=config)
+        mixin.overwrite_completion_id = True
+
+        chunks = [MockChunk("orig-1"), MockChunk("orig-2")]
+        result = await mixin._maybe_overwrite_id(MockAsyncStream(chunks), stream=True)
+
+        received = [c async for c in result]
+        assert all(c.id.startswith("cltsd-") for c in received)
+        assert received[0].id == received[1].id  # Same ID for all chunks
+
+    @pytest.mark.asyncio
+    async def test_ids_preserved_when_disabled(self):
+        config = RemoteInferenceProviderConfig()
+        mixin = OpenAIMixinTestImpl(config=config)
+        mixin.overwrite_completion_id = False
+
+        chunks = [MockChunk("orig-1"), MockChunk("orig-2")]
+        result = await mixin._maybe_overwrite_id(MockAsyncStream(chunks), stream=True)
+
+        received = [c async for c in result]
+        assert received[0].id == "orig-1"
+        assert received[1].id == "orig-2"
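
Assuming a standard dev checkout with pytest and pytest-asyncio available, the new tests should run in isolation with something like pytest tests/unit/providers/utils/test_openai_mixin_streaming.py -v (path taken from the file header above). The TestIssue3185Regression cases exercise the path where overwrite_completion_id is disabled, which previously returned the provider stream unwrapped, and TestIdOverwriting covers both settings of the flag.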