diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 6c5e2506c..50cf44ec9 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -162,9 +162,10 @@ async def maybe_await(value): return value -async def sse_generator(event_gen): +async def sse_generator(event_gen_coroutine): + event_gen = await event_gen_coroutine try: - async for item in await event_gen: + async for item in event_gen: yield create_sse_event(item) await asyncio.sleep(0.01) except asyncio.CancelledError: diff --git a/pyproject.toml b/pyproject.toml index 47d845c30..209367c4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,16 @@ dev = [ "ruamel.yaml", # needed for openapi generator ] # These are the dependencies required for running unit tests. -unit = ["sqlite-vec", "openai", "aiosqlite", "aiohttp", "pypdf", "chardet", "qdrant-client"] +unit = [ + "sqlite-vec", + "openai", + "aiosqlite", + "aiohttp", + "pypdf", + "chardet", + "qdrant-client", + "opentelemetry-exporter-otlp-proto-http" +] # These are the core dependencies required for running integration tests. They are shared across all # providers. If a provider requires additional dependencies, please add them to your environment # separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra diff --git a/tests/unit/server/test_sse.py b/tests/unit/server/test_sse.py new file mode 100644 index 000000000..4a76bdc9b --- /dev/null +++ b/tests/unit/server/test_sse.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio + +import pytest + +from llama_stack.distribution.server.server import create_sse_event, sse_generator + + +@pytest.mark.asyncio +async def test_sse_generator_basic(): + # An AsyncIterator wrapped in an Awaitable, just like our web methods + async def async_event_gen(): + async def event_gen(): + yield "Test event 1" + yield "Test event 2" + + return event_gen() + + sse_gen = sse_generator(async_event_gen()) + assert sse_gen is not None + + # Test that the events are streamed correctly + seen_events = [] + async for event in sse_gen: + seen_events.append(event) + assert len(seen_events) == 2 + assert seen_events[0] == create_sse_event("Test event 1") + assert seen_events[1] == create_sse_event("Test event 2") + + +@pytest.mark.asyncio +async def test_sse_generator_client_disconnected(): + # An AsyncIterator wrapped in an Awaitable, just like our web methods + async def async_event_gen(): + async def event_gen(): + yield "Test event 1" + # Simulate a client disconnect before emitting event 2 + raise asyncio.CancelledError() + + return event_gen() + + sse_gen = sse_generator(async_event_gen()) + assert sse_gen is not None + + # Start reading the events, ensuring this doesn't raise an exception + seen_events = [] + async for event in sse_gen: + seen_events.append(event) + assert len(seen_events) == 1 + assert seen_events[0] == create_sse_event("Test event 1") diff --git a/uv.lock b/uv.lock index cd82a016c..e6368f131 100644 --- a/uv.lock +++ b/uv.lock @@ -1458,6 +1458,7 @@ unit = [ { name = "aiosqlite" }, { name = "chardet" }, { name = "openai" }, + { name = "opentelemetry-exporter-otlp-proto-http" }, { name = "pypdf" }, { name = "qdrant-client" }, { name = "sqlite-vec" }, @@ -1491,6 +1492,7 @@ requires-dist = [ { name = "openai", marker = "extra == 'test'" }, { name = "openai", marker = "extra == 'unit'" }, { name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'test'" }, + { name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'unit'" }, { name = "opentelemetry-sdk", marker = "extra == 'test'" }, { name = "pandas", marker = "extra == 'ui'" }, { name = "pillow" },