mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
fix: properly handle streaming client disconnects (#2000)
# What does this PR do? Previously, when a streaming client would disconnect before we were finished streaming the entire response, an error like the below would get raised from the `sse_generator` function in `llama_stack/distribution/server/server.py`: ``` AttributeError: 'coroutine' object has no attribute 'aclose'. Did you mean: 'close'? ``` This was because we were calling `aclose` on a coroutine instead of the awaited value from that coroutine. This change fixes that, so that we save off the awaited value and then can call `aclose` on it if we encounter an `asyncio.CancelledError`, like we see when a client disconnects before we're finished streaming. The other changes in here are to add a simple set of tests for the happy path of our SSE streaming and this client disconnect path. That unfortunately requires adding one more dependency into our unit test section of pyproject.toml since `server.py` requires loading some of the telemetry code for me to test this functionality. ## Test Plan I wrote the tests in `tests/unit/server/test_sse.py` first, verified the client disconnected test failed before my change, and that it passed afterwards. ``` python -m pytest -s -v tests/unit/server/test_sse.py ``` Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent
e0fa67c81c
commit
dc46725f56
4 changed files with 70 additions and 3 deletions
|
@ -162,9 +162,10 @@ async def maybe_await(value):
|
|||
return value
|
||||
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
async def sse_generator(event_gen_coroutine):
|
||||
event_gen = await event_gen_coroutine
|
||||
try:
|
||||
async for item in await event_gen:
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
|
|
|
@ -58,7 +58,16 @@ dev = [
|
|||
"ruamel.yaml", # needed for openapi generator
|
||||
]
|
||||
# These are the dependencies required for running unit tests.
|
||||
unit = ["sqlite-vec", "openai", "aiosqlite", "aiohttp", "pypdf", "chardet", "qdrant-client"]
|
||||
unit = [
|
||||
"sqlite-vec",
|
||||
"openai",
|
||||
"aiosqlite",
|
||||
"aiohttp",
|
||||
"pypdf",
|
||||
"chardet",
|
||||
"qdrant-client",
|
||||
"opentelemetry-exporter-otlp-proto-http"
|
||||
]
|
||||
# These are the core dependencies required for running integration tests. They are shared across all
|
||||
# providers. If a provider requires additional dependencies, please add them to your environment
|
||||
# separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra
|
||||
|
|
55
tests/unit/server/test_sse.py
Normal file
55
tests/unit/server/test_sse.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.distribution.server.server import create_sse_event, sse_generator
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_generator_basic():
|
||||
# An AsyncIterator wrapped in an Awaitable, just like our web methods
|
||||
async def async_event_gen():
|
||||
async def event_gen():
|
||||
yield "Test event 1"
|
||||
yield "Test event 2"
|
||||
|
||||
return event_gen()
|
||||
|
||||
sse_gen = sse_generator(async_event_gen())
|
||||
assert sse_gen is not None
|
||||
|
||||
# Test that the events are streamed correctly
|
||||
seen_events = []
|
||||
async for event in sse_gen:
|
||||
seen_events.append(event)
|
||||
assert len(seen_events) == 2
|
||||
assert seen_events[0] == create_sse_event("Test event 1")
|
||||
assert seen_events[1] == create_sse_event("Test event 2")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_generator_client_disconnected():
|
||||
# An AsyncIterator wrapped in an Awaitable, just like our web methods
|
||||
async def async_event_gen():
|
||||
async def event_gen():
|
||||
yield "Test event 1"
|
||||
# Simulate a client disconnect before emitting event 2
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
return event_gen()
|
||||
|
||||
sse_gen = sse_generator(async_event_gen())
|
||||
assert sse_gen is not None
|
||||
|
||||
# Start reading the events, ensuring this doesn't raise an exception
|
||||
seen_events = []
|
||||
async for event in sse_gen:
|
||||
seen_events.append(event)
|
||||
assert len(seen_events) == 1
|
||||
assert seen_events[0] == create_sse_event("Test event 1")
|
2
uv.lock
generated
2
uv.lock
generated
|
@ -1458,6 +1458,7 @@ unit = [
|
|||
{ name = "aiosqlite" },
|
||||
{ name = "chardet" },
|
||||
{ name = "openai" },
|
||||
{ name = "opentelemetry-exporter-otlp-proto-http" },
|
||||
{ name = "pypdf" },
|
||||
{ name = "qdrant-client" },
|
||||
{ name = "sqlite-vec" },
|
||||
|
@ -1491,6 +1492,7 @@ requires-dist = [
|
|||
{ name = "openai", marker = "extra == 'test'" },
|
||||
{ name = "openai", marker = "extra == 'unit'" },
|
||||
{ name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'test'" },
|
||||
{ name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'unit'" },
|
||||
{ name = "opentelemetry-sdk", marker = "extra == 'test'" },
|
||||
{ name = "pandas", marker = "extra == 'ui'" },
|
||||
{ name = "pillow" },
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue