Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-16 17:02:36 +00:00)
fix(inference): AttributeError in streaming response cleanup (#4236)
This PR fixes issue #3185.

When clients cancel streaming requests, the server tries to clean up with:

```python
await event_gen.aclose()  # ❌ AsyncStream doesn't have aclose()!
```

But OpenAI's `AsyncStream` has never had a public `aclose()` method; the method it exposes is `close()` (which is async). The error message points at the problem directly:

```
AttributeError: 'AsyncStream' object has no attribute 'aclose'. Did you mean: 'close'?
                                                                              ^^^^^^^^
```

## Verification

* Reproduction script [`reproduce_issue_3185.sh`](https://gist.github.com/r-bit-rry/dea4f8fbb81c446f5db50ea7abd6379b) can be used to verify the fix.
* Manual checks and validation against the original OpenAI library code.
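For context, here is a minimal sketch of the wrapping approach the regression tests below exercise: re-yield the SDK stream from an async generator so the server's cleanup path always finds `aclose()`. This is an illustration only, not the repository's code; the actual `_maybe_overwrite_id()` in `openai_mixin.py` may differ in details, and the helper name here is hypothetical.

```python
# Sketch only: shows why wrapping in an async generator fixes the cleanup path.
# The real logic lives in OpenAIMixin._maybe_overwrite_id(); this helper is illustrative.
from collections.abc import AsyncIterator
from typing import Any


async def wrap_stream(stream: Any) -> AsyncIterator[Any]:
    """Re-yield chunks from an OpenAI AsyncStream inside an async generator."""
    try:
        async for chunk in stream:
            yield chunk
    finally:
        # Async generator objects always expose aclose(); when the caller invokes it,
        # delegate cleanup to the SDK's own (async) close() method.
        await stream.close()
```

With a wrapper like this, `await event_gen.aclose()` in the server's cleanup code works regardless of the underlying SDK stream type.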
This commit is contained in:
parent dfb9f6743a
commit c574db5f1d

14 changed files with 213 additions and 30 deletions
tests/unit/providers/utils/test_openai_mixin_streaming.py (new file, 125 lines, @@ -0,0 +1,125 @@)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Regression tests for issue #3185: AsyncStream passed where AsyncIterator expected.

The bug: OpenAI SDK's AsyncStream has close(), not aclose(), but Python's
AsyncIterator protocol requires aclose(). The fix ensures _maybe_overwrite_id()
always wraps streaming responses in an async generator.
"""

import inspect
from collections.abc import AsyncIterator
from unittest.mock import MagicMock

import pytest

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class MockAsyncStream:
    """Simulates OpenAI SDK's AsyncStream: has close() but NOT aclose()."""

    def __init__(self, chunks):
        self.chunks = chunks
        self._iter = iter(chunks)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return next(self._iter)
        except StopIteration as e:
            raise StopAsyncIteration from e

    async def close(self):
        pass


class MockChunk:
    def __init__(self, chunk_id: str, content: str = "test"):
        self.id = chunk_id
        self.content = content


class OpenAIMixinTestImpl(OpenAIMixin):
    __provider_id__: str = "test-provider"

    def get_api_key(self) -> str:
        return "test-api-key"

    def get_base_url(self) -> str:
        return "http://test-base-url"


@pytest.fixture
def mixin():
    config = RemoteInferenceProviderConfig()
    m = OpenAIMixinTestImpl(config=config)
    m.overwrite_completion_id = False
    return m


class TestIssue3185Regression:
    async def test_streaming_result_has_aclose(self, mixin):
        mock_stream = MockAsyncStream([MockChunk("1")])

        assert not hasattr(mock_stream, "aclose")

        result = await mixin._maybe_overwrite_id(mock_stream, stream=True)

        assert hasattr(result, "aclose"), "Result MUST have aclose() for AsyncIterator"
        assert inspect.isasyncgen(result)
        assert isinstance(result, AsyncIterator)

    async def test_streaming_yields_all_chunks(self, mixin):
        chunks = [MockChunk("1", "a"), MockChunk("2", "b")]
        mock_stream = MockAsyncStream(chunks)

        result = await mixin._maybe_overwrite_id(mock_stream, stream=True)

        received = [c async for c in result]
        assert len(received) == 2
        assert received[0].content == "a"
        assert received[1].content == "b"

    async def test_non_streaming_returns_directly(self, mixin):
        mock_response = MagicMock()
        mock_response.id = "test-id"

        result = await mixin._maybe_overwrite_id(mock_response, stream=False)

        assert result is mock_response
        assert not inspect.isasyncgen(result)


class TestIdOverwriting:
    async def test_ids_overwritten_when_enabled(self):
        config = RemoteInferenceProviderConfig()
        mixin = OpenAIMixinTestImpl(config=config)
        mixin.overwrite_completion_id = True

        chunks = [MockChunk("orig-1"), MockChunk("orig-2")]
        result = await mixin._maybe_overwrite_id(MockAsyncStream(chunks), stream=True)

        received = [c async for c in result]
        assert all(c.id.startswith("cltsd-") for c in received)
        assert received[0].id == received[1].id  # Same ID for all chunks

    async def test_ids_preserved_when_disabled(self):
        config = RemoteInferenceProviderConfig()
        mixin = OpenAIMixinTestImpl(config=config)
        mixin.overwrite_completion_id = False

        chunks = [MockChunk("orig-1"), MockChunk("orig-2")]
        result = await mixin._maybe_overwrite_id(MockAsyncStream(chunks), stream=True)

        received = [c async for c in result]
        assert received[0].id == "orig-1"
        assert received[1].id == "orig-2"
```
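As a side note, the `hasattr(result, "aclose")` assertion in these tests holds for any async generator object, which is exactly why returning one from `_maybe_overwrite_id()` resolves the original `AttributeError`. A minimal standalone illustration (the names below are hypothetical, not from the repository):

```python
import asyncio


async def passthrough():
    # Any async generator object automatically exposes aclose().
    yield "chunk"


async def main():
    gen = passthrough()
    assert hasattr(gen, "aclose")
    print(await gen.__anext__())  # -> "chunk"
    await gen.aclose()            # safe cleanup, no AttributeError


asyncio.run(main())
```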