diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 2942920d4..02f82498b 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -166,14 +166,16 @@ async def maybe_await(value): async def sse_generator(event_gen_coroutine): - event_gen = await event_gen_coroutine + event_gen = None try: + event_gen = await event_gen_coroutine async for item in event_gen: yield create_sse_event(item) await asyncio.sleep(0.01) except asyncio.CancelledError: logger.info("Generator cancelled") - await event_gen.aclose() + if event_gen: + await event_gen.aclose() except Exception as e: logger.exception("Error in sse_generator") yield create_sse_event( diff --git a/tests/unit/server/test_sse.py b/tests/unit/server/test_sse.py index 4a76bdc9b..c78122294 100644 --- a/tests/unit/server/test_sse.py +++ b/tests/unit/server/test_sse.py @@ -47,9 +47,45 @@ async def test_sse_generator_client_disconnected(): sse_gen = sse_generator(async_event_gen()) assert sse_gen is not None - # Start reading the events, ensuring this doesn't raise an exception seen_events = [] async for event in sse_gen: seen_events.append(event) + + # We should see 1 event before the client disconnected assert len(seen_events) == 1 assert seen_events[0] == create_sse_event("Test event 1") + + +@pytest.mark.asyncio +async def test_sse_generator_client_disconnected_before_response_starts(): + # Disconnect before the response starts + async def async_event_gen(): + raise asyncio.CancelledError() + + sse_gen = sse_generator(async_event_gen()) + assert sse_gen is not None + + seen_events = [] + async for event in sse_gen: + seen_events.append(event) + + # No events should be seen since the client disconnected immediately + assert len(seen_events) == 0 + + +@pytest.mark.asyncio +async def test_sse_generator_error_before_response_starts(): + # Raise an error before the response starts + async def async_event_gen(): + raise Exception("Test error") + + sse_gen = sse_generator(async_event_gen()) + assert sse_gen is not None + + seen_events = [] + async for event in sse_gen: + seen_events.append(event) + + # We should have 1 error event + assert len(seen_events) == 1 + assert 'data: {"error":' in seen_events[0]