fix: properly handle streaming client disconnects (#2000)

# What does this PR do?

Previously, when a streaming client disconnected before we had finished
streaming the entire response, an error like the one below was raised
from the `sse_generator` function in
`llama_stack/distribution/server/server.py`:

```
AttributeError: 'coroutine' object has no attribute 'aclose'. Did you mean: 'close'?
```

This happened because we were calling `aclose` on a coroutine instead of
on the value obtained by awaiting that coroutine. This change fixes
that: we now save the awaited value so that we can call `aclose` on it
if we encounter an `asyncio.CancelledError`, as happens when a client
disconnects before we're finished streaming.

The other changes in here are to add a simple set of tests for the happy
path of our SSE streaming and this client disconnect path.

Unfortunately, that requires adding one more dependency to the unit
test section of `pyproject.toml`, because `server.py` loads some of the
telemetry code, which must be available for me to test this functionality.

## Test Plan

I wrote the tests in `tests/unit/server/test_sse.py` first, verified that
the client-disconnect test failed before my change, and that it passed
afterwards.

```
python -m pytest -s -v tests/unit/server/test_sse.py
```

Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
Ben Browning 2025-04-23 09:44:28 -04:00 committed by GitHub
parent e0fa67c81c
commit dc46725f56
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 70 additions and 3 deletions

View file

@ -162,9 +162,10 @@ async def maybe_await(value):
return value
async def sse_generator(event_gen):
async def sse_generator(event_gen_coroutine):
event_gen = await event_gen_coroutine
try:
async for item in await event_gen:
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:

View file

@ -58,7 +58,16 @@ dev = [
"ruamel.yaml", # needed for openapi generator
]
# These are the dependencies required for running unit tests.
unit = ["sqlite-vec", "openai", "aiosqlite", "aiohttp", "pypdf", "chardet", "qdrant-client"]
unit = [
"sqlite-vec",
"openai",
"aiosqlite",
"aiohttp",
"pypdf",
"chardet",
"qdrant-client",
"opentelemetry-exporter-otlp-proto-http"
]
# These are the core dependencies required for running integration tests. They are shared across all
# providers. If a provider requires additional dependencies, please add them to your environment
# separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra

View file

@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import pytest
from llama_stack.distribution.server.server import create_sse_event, sse_generator
@pytest.mark.asyncio
async def test_sse_generator_basic():
    """Happy path: each item from the wrapped iterator is emitted, in order, as an SSE event."""

    async def _events():
        for message in ("Test event 1", "Test event 2"):
            yield message

    async def _make_events():
        # An AsyncIterator wrapped in an Awaitable, just like our web methods
        return _events()

    stream = sse_generator(_make_events())
    assert stream is not None

    # Drain the stream and check that the events came through correctly
    received = [event async for event in stream]

    assert len(received) == 2
    assert received[0] == create_sse_event("Test event 1")
    assert received[1] == create_sse_event("Test event 2")
@pytest.mark.asyncio
async def test_sse_generator_client_disconnected():
    """Disconnect path: a CancelledError mid-stream ends the SSE stream without raising."""

    async def _interrupted_events():
        yield "Test event 1"
        # Simulate a client disconnect before emitting event 2
        raise asyncio.CancelledError()

    async def _make_events():
        # An AsyncIterator wrapped in an Awaitable, just like our web methods
        return _interrupted_events()

    stream = sse_generator(_make_events())
    assert stream is not None

    # Draining the stream must not raise, even though the source was cancelled
    received = [event async for event in stream]

    assert len(received) == 1
    assert received[0] == create_sse_event("Test event 1")

2
uv.lock generated
View file

@ -1458,6 +1458,7 @@ unit = [
{ name = "aiosqlite" },
{ name = "chardet" },
{ name = "openai" },
{ name = "opentelemetry-exporter-otlp-proto-http" },
{ name = "pypdf" },
{ name = "qdrant-client" },
{ name = "sqlite-vec" },
@ -1491,6 +1492,7 @@ requires-dist = [
{ name = "openai", marker = "extra == 'test'" },
{ name = "openai", marker = "extra == 'unit'" },
{ name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'test'" },
{ name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'unit'" },
{ name = "opentelemetry-sdk", marker = "extra == 'test'" },
{ name = "pandas", marker = "extra == 'ui'" },
{ name = "pillow" },