[bugfix] fix streaming GeneratorExit exception with LlamaStackAsLibraryClient (#760)

# What does this PR do?

#### Issue
- Using a Jupyter notebook with `LlamaStackAsLibraryClient` + streaming
raises the following exception:
```
Exception ignored in: <async_generator object HTTP11ConnectionByteStream.__aiter__ at 0x32a95a740>
Traceback (most recent call last):
  File "/opt/anaconda3/envs/fresh/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 404, in _aiter_
    yield part
RuntimeError: async generator ignored GeneratorExit
```

- Reproduce with the notebook at
https://github.com/meta-llama/llama-stack/blob/notebook-streaming-debug/inline.ipynb
(the sketch below illustrates the general failure mode)
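
For context, "async generator ignored GeneratorExit" is CPython's generic complaint when an async generator is closed (e.g. during garbage collection) while still suspended and its cleanup awaits something instead of exiting. A minimal sketch of that failure mode (illustrative only; `byte_stream` is a stand-in for httpcore's stream, not the actual code, and exact output can vary by Python version):

```python
import asyncio


async def byte_stream():
    try:
        while True:
            yield b"part"  # analogous to httpcore's `yield part`
    finally:
        # Async cleanup: on a synchronous close (e.g. during GC) this await
        # makes the generator yield again instead of exiting on GeneratorExit.
        await asyncio.sleep(0)


agen = byte_stream()
try:
    agen.__anext__().send(None)  # advance to the first yield, no event loop involved
except StopIteration:
    pass
del agen  # typically reports: RuntimeError: async generator ignored GeneratorExit
```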

#### Fix
- The issue likely comes from `stream_across_asyncio_run_boundary` closing
the connection too early when running inside a Jupyter environment
- This switches to an alternative way of converting the `AsyncStream` into
the `SyncStream` return type in the sync `LlamaStackAsLibraryClient`: it
drives `AsyncLlamaStackAsLibraryClient` (which calls the async impls under
the hood) on a dedicated event loop, as sketched below
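
The bridge pattern, as a standalone sketch (nothing here is from this repo; `fetch_stream` is a hypothetical stand-in for the async client's streaming request):

```python
import asyncio
from typing import AsyncIterator, Iterator


async def fetch_stream() -> AsyncIterator[str]:
    """Hypothetical stand-in for an async streaming request."""
    for part in ("hello", "world"):
        await asyncio.sleep(0)  # pretend network I/O
        yield part


def as_sync_stream() -> Iterator[str]:
    # Drive the async generator on a dedicated event loop. The loop stays
    # open until the stream is exhausted, so the underlying connection is
    # not torn down while chunks are still pending.
    loop = asyncio.new_event_loop()

    def gen() -> Iterator[str]:
        agen = fetch_stream()
        try:
            while True:
                yield loop.run_until_complete(agen.__anext__())
        except StopAsyncIteration:
            pass
        finally:
            loop.run_until_complete(agen.aclose())  # run async cleanup in-loop
            loop.close()

    return gen()


for chunk in as_sync_stream():
    print(chunk)
```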

#### Additional changes
- Moved the tracing logic into `AsyncLlamaStackAsLibraryClient.request` so
that streaming and non-streaming requests from `LlamaStackAsLibraryClient`
share the same code path (see the sketch below)
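
Roughly, the consolidated flow looks like this (a simplified, runnable sketch of the diff below; the telemetry helpers are stubbed out, and the endpoint string and response values are placeholders):

```python
import asyncio


# Stubs for the telemetry helpers, just so the sketch runs on its own.
async def start_trace(path: str, attributes: dict) -> None:
    print(f"start trace: {path} {attributes}")


async def end_trace() -> None:
    print("end trace")


async def request(url: str, stream: bool = False) -> str:
    # Both branches now share one trace bracket, instead of the sync
    # client's streaming and non-streaming paths each tracing on their own.
    await start_trace(url, {"__location__": "library_client"})
    if stream:
        response = "streamed response"  # placeholder for _call_streaming(...)
    else:
        response = "plain response"  # placeholder for _call_non_streaming(...)
    await end_trace()
    return response


print(asyncio.run(request("/example-endpoint", stream=True)))
```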

## Test Plan

- Tested with together, fireworks, and ollama, both streaming and
non-streaming, using the notebook at:
https://github.com/meta-llama/llama-stack/blob/notebook-streaming-debug/inline.ipynb
- Note: you need to run `pip install -e .` and restart the Jupyter kernel
for local code changes to take effect

<img width="826" alt="image"
src="https://github.com/user-attachments/assets/5f90985d-1aee-452c-a599-2157f5654fea"
/>




## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
Xi Yan committed 2025-01-14 10:58:46 -08:00 (committed by GitHub)
commit 194d12b304, parent 2c2969f331

llama_stack/distribution/library_client.py

```diff
@@ -9,12 +9,10 @@ import inspect
 import json
 import logging
 import os
-import queue
-import threading
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
-from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
+from typing import Any, get_args, get_origin, Optional, TypeVar
 
 import httpx
 import yaml
@@ -64,71 +62,6 @@ def in_notebook():
     return True
 
 
-def stream_across_asyncio_run_boundary(
-    async_gen_maker,
-    pool_executor: ThreadPoolExecutor,
-    path: Optional[str] = None,
-    provider_data: Optional[dict[str, Any]] = None,
-) -> Generator[T, None, None]:
-    result_queue = queue.Queue()
-    stop_event = threading.Event()
-
-    async def consumer():
-        # make sure we make the generator in the event loop context
-        gen = await async_gen_maker()
-        await start_trace(path, {"__location__": "library_client"})
-        if provider_data:
-            set_request_provider_data(
-                {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
-            )
-        try:
-            async for item in await gen:
-                result_queue.put(item)
-        except Exception as e:
-            print(f"Error in generator {e}")
-            result_queue.put(e)
-        except asyncio.CancelledError:
-            return
-        finally:
-            result_queue.put(StopIteration)
-            stop_event.set()
-            await end_trace()
-
-    def run_async():
-        # Run our own loop to avoid double async generator cleanup which is done
-        # by asyncio.run()
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-        try:
-            task = loop.create_task(consumer())
-            loop.run_until_complete(task)
-        finally:
-            # Handle pending tasks like a generator's athrow()
-            pending = asyncio.all_tasks(loop)
-            if pending:
-                loop.run_until_complete(
-                    asyncio.gather(*pending, return_exceptions=True)
-                )
-            loop.close()
-
-    future = pool_executor.submit(run_async)
-
-    try:
-        # yield results as they come in
-        while not stop_event.is_set() or not result_queue.empty():
-            try:
-                item = result_queue.get(timeout=0.1)
-                if item is StopIteration:
-                    break
-                if isinstance(item, Exception):
-                    raise item
-                yield item
-            except queue.Empty:
-                continue
-    finally:
-        future.result()
-
-
 def convert_pydantic_to_json_value(value: Any) -> Any:
     if isinstance(value, Enum):
         return value.value
@@ -184,7 +117,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
     ):
         super().__init__()
         self.async_client = AsyncLlamaStackAsLibraryClient(
-            config_path_or_template_name, custom_provider_registry
+            config_path_or_template_name, custom_provider_registry, provider_data
        )
         self.pool_executor = ThreadPoolExecutor(max_workers=4)
         self.skip_logger_removal = skip_logger_removal
@@ -210,39 +143,30 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             root_logger.removeHandler(handler)
             print(f"Removed handler {handler.__class__.__name__} from root logger")
 
-    def _get_path(
-        self,
-        cast_to: Any,
-        options: Any,
-        *,
-        stream=False,
-        stream_cls=None,
-    ):
-        return options.url
-
     def request(self, *args, **kwargs):
-        path = self._get_path(*args, **kwargs)
         if kwargs.get("stream"):
-            return stream_across_asyncio_run_boundary(
-                lambda: self.async_client.request(*args, **kwargs),
-                self.pool_executor,
-                path=path,
-                provider_data=self.provider_data,
-            )
-        else:
-
-            async def _traced_request():
-                if self.provider_data:
-                    set_request_provider_data(
-                        {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
-                    )
-                await start_trace(path, {"__location__": "library_client"})
-                try:
-                    return await self.async_client.request(*args, **kwargs)
-                finally:
-                    await end_trace()
-
-            return asyncio.run(_traced_request())
+            # NOTE: We are using AsyncLlamaStackClient under the hood
+            # A new event loop is needed to convert the AsyncStream
+            # from async client into SyncStream return type for streaming
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+            def sync_generator():
+                try:
+                    async_stream = loop.run_until_complete(
+                        self.async_client.request(*args, **kwargs)
+                    )
+                    while True:
+                        chunk = loop.run_until_complete(async_stream.__anext__())
+                        yield chunk
+                except StopAsyncIteration:
+                    pass
+                finally:
+                    loop.close()
+
+            return sync_generator()
+        else:
+            return asyncio.run(self.async_client.request(*args, **kwargs))
 
 
 class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
@@ -250,9 +174,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self,
         config_path_or_template_name: str,
         custom_provider_registry: Optional[ProviderRegistry] = None,
+        provider_data: Optional[dict[str, Any]] = None,
     ):
         super().__init__()
-
         # when using the library client, we should not log to console since many
         # of our logs are intended for server-side usage
         current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
@@ -273,6 +197,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self.config_path_or_template_name = config_path_or_template_name
         self.config = config
         self.custom_provider_registry = custom_provider_registry
+        self.provider_data = provider_data
 
     async def initialize(self):
         try:
@@ -329,17 +254,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if not self.endpoint_impls:
             raise ValueError("Client not initialized")
 
+        if self.provider_data:
+            set_request_provider_data(
+                {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
+            )
+        await start_trace(options.url, {"__location__": "library_client"})
         if stream:
-            return self._call_streaming(
+            response = await self._call_streaming(
                 cast_to=cast_to,
                 options=options,
                 stream_cls=stream_cls,
             )
         else:
-            return await self._call_non_streaming(
+            response = await self._call_non_streaming(
                 cast_to=cast_to,
                 options=options,
             )
+        await end_trace()
+        return response
 
     async def _call_non_streaming(
         self,
```