forked from phoenix-oss/llama-stack-mirror
fix: Additional streaming error handling (#2007)
# What does this PR do? This expands the `test_sse` test suite and fixes some edge-case bugs in our SSE error handling to ensure streaming clients always get a proper error response. First, we handle the case where a client disconnects before we actually start streaming the response back. Previously we only handled the case where a client disconnected while we were streaming the response; a client disconnecting before we had streamed any response back did not trigger our logic to cleanly handle that disconnect. Second, we handle the case where an error is thrown from the server before the actual async generator gets created from the provider. This happens in scenarios like the newly merged OpenAI API input validation, where we eagerly raise validation errors before returning the async generator object that streams the responses back. ## Test Plan Tested via: ``` python -m pytest -s -v tests/unit/server/test_sse.py ``` Both new test cases failed before this change and pass with it. The test cases are based on my experiments with actual clients that do bad things, like randomly disconnecting or sending invalid input in streaming mode; I hit these two cases, where our error handling was misbehaving. Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent
c8797f1125
commit
0b6cd45950
2 changed files with 41 additions and 3 deletions
|
@ -166,14 +166,16 @@ async def maybe_await(value):
|
|||
|
||||
|
||||
async def sse_generator(event_gen_coroutine):
|
||||
event_gen = await event_gen_coroutine
|
||||
event_gen = None
|
||||
try:
|
||||
event_gen = await event_gen_coroutine
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
if event_gen:
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
logger.exception("Error in sse_generator")
|
||||
yield create_sse_event(
|
||||
|
|
|
@ -47,9 +47,45 @@ async def test_sse_generator_client_disconnected():
|
|||
sse_gen = sse_generator(async_event_gen())
|
||||
assert sse_gen is not None
|
||||
|
||||
# Start reading the events, ensuring this doesn't raise an exception
|
||||
seen_events = []
|
||||
async for event in sse_gen:
|
||||
seen_events.append(event)
|
||||
|
||||
# We should see 1 event before the client disconnected
|
||||
assert len(seen_events) == 1
|
||||
assert seen_events[0] == create_sse_event("Test event 1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sse_generator_client_disconnected_before_response_starts():
    """Cancellation raised before any event is produced yields no SSE events.

    The coroutine handed to ``sse_generator`` raises ``asyncio.CancelledError``
    as soon as it is awaited, standing in for a client that disconnects before
    the response starts. The test fails if draining the stream raises or if
    any event is emitted.
    """

    async def async_event_gen():
        # Disconnect before the response starts
        raise asyncio.CancelledError()

    sse_gen = sse_generator(async_event_gen())
    assert sse_gen is not None

    # Drain the whole stream in one pass.
    collected = [event async for event in sse_gen]

    # No events should be seen since the client disconnected immediately
    assert len(collected) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sse_generator_error_before_response_starts():
    """An exception raised before streaming begins becomes one SSE error event.

    The coroutine handed to ``sse_generator`` raises a plain ``Exception`` as
    soon as it is awaited. The test expects exactly one event back, and that
    event must carry an ``error`` payload.
    """

    async def async_event_gen():
        # Raise an error before the response starts
        raise Exception("Test error")

    sse_gen = sse_generator(async_event_gen())
    assert sse_gen is not None

    # Drain the whole stream in one pass.
    collected = [event async for event in sse_gen]

    # We should have 1 error event
    assert len(collected) == 1
    assert 'data: {"error":' in collected[0]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue