From dc46725f56d6a404f24793c1f7242c6fcdea8e5b Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Wed, 23 Apr 2025 09:44:28 -0400 Subject: [PATCH] fix: properly handle streaming client disconnects (#2000) # What does this PR do? Previously, when a streaming client would disconnect before we were finished streaming the entire response, an error like the below would get raised from the `sse_generator` function in `llama_stack/distribution/server/server.py`: ``` AttributeError: 'coroutine' object has no attribute 'aclose'. Did you mean: 'close'? ``` This was because we were calling `aclose` on a coroutine instead of the awaited value from that coroutine. This change fixes that, so that we save off the awaited value and then can call `aclose` on it if we encounter an `asyncio.CancelledError`, like we see when a client disconnects before we're finished streaming. The other changes in here are to add a simple set of tests for the happy path of our SSE streaming and this client disconnect path. That unfortunately requires adding one more dependency into our unit test section of pyproject.toml since `server.py` requires loading some of the telemetry code for me to test this functionality. ## Test Plan I wrote the tests in `tests/unit/server/test_sse.py` first, verified the client disconnected test failed before my change, and that it passed afterwards. ``` python -m pytest -s -v tests/unit/server/test_sse.py ``` Signed-off-by: Ben Browning --- llama_stack/distribution/server/server.py | 5 ++- pyproject.toml | 11 ++++- tests/unit/server/test_sse.py | 55 +++++++++++++++++++++++ uv.lock | 2 + 4 files changed, 70 insertions(+), 3 deletions(-) create mode 100644 tests/unit/server/test_sse.py diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 6c5e2506c..50cf44ec9 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -162,9 +162,10 @@ async def maybe_await(value): return value -async def sse_generator(event_gen): +async def sse_generator(event_gen_coroutine): + event_gen = await event_gen_coroutine try: - async for item in await event_gen: + async for item in event_gen: yield create_sse_event(item) await asyncio.sleep(0.01) except asyncio.CancelledError: diff --git a/pyproject.toml b/pyproject.toml index 47d845c30..209367c4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,16 @@ dev = [ "ruamel.yaml", # needed for openapi generator ] # These are the dependencies required for running unit tests. -unit = ["sqlite-vec", "openai", "aiosqlite", "aiohttp", "pypdf", "chardet", "qdrant-client"] +unit = [ + "sqlite-vec", + "openai", + "aiosqlite", + "aiohttp", + "pypdf", + "chardet", + "qdrant-client", + "opentelemetry-exporter-otlp-proto-http" +] # These are the core dependencies required for running integration tests. They are shared across all # providers. If a provider requires additional dependencies, please add them to your environment # separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra diff --git a/tests/unit/server/test_sse.py b/tests/unit/server/test_sse.py new file mode 100644 index 000000000..4a76bdc9b --- /dev/null +++ b/tests/unit/server/test_sse.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio + +import pytest + +from llama_stack.distribution.server.server import create_sse_event, sse_generator + + +@pytest.mark.asyncio +async def test_sse_generator_basic(): + # An AsyncIterator wrapped in an Awaitable, just like our web methods + async def async_event_gen(): + async def event_gen(): + yield "Test event 1" + yield "Test event 2" + + return event_gen() + + sse_gen = sse_generator(async_event_gen()) + assert sse_gen is not None + + # Test that the events are streamed correctly + seen_events = [] + async for event in sse_gen: + seen_events.append(event) + assert len(seen_events) == 2 + assert seen_events[0] == create_sse_event("Test event 1") + assert seen_events[1] == create_sse_event("Test event 2") + + +@pytest.mark.asyncio +async def test_sse_generator_client_disconnected(): + # An AsyncIterator wrapped in an Awaitable, just like our web methods + async def async_event_gen(): + async def event_gen(): + yield "Test event 1" + # Simulate a client disconnect before emitting event 2 + raise asyncio.CancelledError() + + return event_gen() + + sse_gen = sse_generator(async_event_gen()) + assert sse_gen is not None + + # Start reading the events, ensuring this doesn't raise an exception + seen_events = [] + async for event in sse_gen: + seen_events.append(event) + assert len(seen_events) == 1 + assert seen_events[0] == create_sse_event("Test event 1") diff --git a/uv.lock b/uv.lock index cd82a016c..e6368f131 100644 --- a/uv.lock +++ b/uv.lock @@ -1458,6 +1458,7 @@ unit = [ { name = "aiosqlite" }, { name = "chardet" }, { name = "openai" }, + { name = "opentelemetry-exporter-otlp-proto-http" }, { name = "pypdf" }, { name = "qdrant-client" }, { name = "sqlite-vec" }, @@ -1491,6 +1492,7 @@ requires-dist = [ { name = "openai", marker = "extra == 'test'" }, { name = "openai", marker = "extra == 'unit'" }, { name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'test'" }, + { name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'unit'" }, { name = "opentelemetry-sdk", marker = "extra == 'test'" }, { name = "pandas", marker = "extra == 'ui'" }, { name = "pillow" },