forked from phoenix-oss/llama-stack-mirror
[bugfix] fix streaming GeneratorExit exception with LlamaStackAsLibraryClient (#760)
# What does this PR do?

#### Issue
- Using a Jupyter notebook with `LlamaStackAsLibraryClient` + streaming raises an exception:
```
Exception ignored in: <async_generator object HTTP11ConnectionByteStream.__aiter__ at 0x32a95a740>
Traceback (most recent call last):
  File "/opt/anaconda3/envs/fresh/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 404, in __aiter__
    yield part
RuntimeError: async generator ignored GeneratorExit
```
- Reproduce with https://github.com/meta-llama/llama-stack/blob/notebook-streaming-debug/inline.ipynb

#### Fix
- The issue likely comes from `stream_across_asyncio_run_boundary` closing the connection too soon when running in a Jupyter environment.
- This PR uses an alternative way to convert the `AsyncStream` into a `SyncStream` return type: the sync `LlamaStackAsLibraryClient` drives `AsyncLlamaStackAsLibraryClient`, which calls the async impls under the hood.

#### Additional changes
- Moved the tracing logic into `AsyncLlamaStackAsLibraryClient.request` so that streaming and non-streaming requests for `LlamaStackAsLibraryClient` share the same code path.

## Test Plan
- Tested with together, fireworks, and ollama, streaming and non-streaming, using the notebook at https://github.com/meta-llama/llama-stack/blob/notebook-streaming-debug/inline.ipynb
- Note: you need to restart the kernel and run `pip install -e .` in the Jupyter interpreter for local code changes to take effect.

<img width="826" alt="image" src="https://github.com/user-attachments/assets/5f90985d-1aee-452c-a599-2157f5654fea" />

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
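To make the "Fix" section concrete: the new approach is the standard pattern of driving an async iterator from synchronous code on a single dedicated event loop. A minimal, self-contained sketch of that pattern (the names `demo_stream` and `to_sync_iterator` are illustrative, not part of the codebase):

```python
import asyncio
from typing import AsyncIterator, Iterator, TypeVar

T = TypeVar("T")


async def demo_stream() -> AsyncIterator[int]:
    # Stand-in for the AsyncStream returned by the async client.
    for i in range(3):
        await asyncio.sleep(0)
        yield i


def to_sync_iterator(async_iter: AsyncIterator[T]) -> Iterator[T]:
    # Advance the async iterator on one dedicated loop, so the generator is
    # created, resumed, and finalized on the same loop instead of being
    # abandoned across an asyncio.run() boundary.
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                yield loop.run_until_complete(async_iter.__anext__())
            except StopAsyncIteration:
                break
    finally:
        loop.close()


for chunk in to_sync_iterator(demo_stream()):
    print(chunk)  # prints 0, 1, 2
```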
parent 2c2969f331
commit 194d12b304

1 changed file with 31 additions and 99 deletions
```diff
@@ -9,12 +9,10 @@ import inspect
 import json
 import logging
 import os
-import queue
-import threading
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
-from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
+from typing import Any, get_args, get_origin, Optional, TypeVar

 import httpx
 import yaml
```
```diff
@@ -64,71 +62,6 @@ def in_notebook():
     return True


-def stream_across_asyncio_run_boundary(
-    async_gen_maker,
-    pool_executor: ThreadPoolExecutor,
-    path: Optional[str] = None,
-    provider_data: Optional[dict[str, Any]] = None,
-) -> Generator[T, None, None]:
-    result_queue = queue.Queue()
-    stop_event = threading.Event()
-
-    async def consumer():
-        # make sure we make the generator in the event loop context
-        gen = await async_gen_maker()
-        await start_trace(path, {"__location__": "library_client"})
-        if provider_data:
-            set_request_provider_data(
-                {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
-            )
-        try:
-            async for item in await gen:
-                result_queue.put(item)
-        except Exception as e:
-            print(f"Error in generator {e}")
-            result_queue.put(e)
-        except asyncio.CancelledError:
-            return
-        finally:
-            result_queue.put(StopIteration)
-            stop_event.set()
-            await end_trace()
-
-    def run_async():
-        # Run our own loop to avoid double async generator cleanup which is done
-        # by asyncio.run()
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-        try:
-            task = loop.create_task(consumer())
-            loop.run_until_complete(task)
-        finally:
-            # Handle pending tasks like a generator's athrow()
-            pending = asyncio.all_tasks(loop)
-            if pending:
-                loop.run_until_complete(
-                    asyncio.gather(*pending, return_exceptions=True)
-                )
-            loop.close()
-
-    future = pool_executor.submit(run_async)
-
-    try:
-        # yield results as they come in
-        while not stop_event.is_set() or not result_queue.empty():
-            try:
-                item = result_queue.get(timeout=0.1)
-                if item is StopIteration:
-                    break
-                if isinstance(item, Exception):
-                    raise item
-                yield item
-            except queue.Empty:
-                continue
-    finally:
-        future.result()
-
-
 def convert_pydantic_to_json_value(value: Any) -> Any:
     if isinstance(value, Enum):
         return value.value
```
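The helper deleted above pumped the async generator from a worker thread's private event loop into a `queue.Queue`. Per the PR description, this likely let the underlying HTTP byte stream be torn down before the generator finished its cleanup. A hedged, illustrative reproduction of the general hazard (not the exact httpcore trace; the message emitted at garbage collection varies by Python version between "Event loop is closed" and "async generator ignored GeneratorExit"):

```python
import asyncio
import gc


async def byte_stream():
    try:
        yield b"chunk-1"
        yield b"chunk-2"
    finally:
        # Models httpcore closing a connection: needing to await (or yield)
        # during cleanup is what "async generator ignored GeneratorExit"
        # complains about.
        await asyncio.sleep(0)


loop = asyncio.new_event_loop()
gen = byte_stream()
print(loop.run_until_complete(gen.__anext__()))  # consume one chunk...
loop.close()                                     # ...then drop the loop early
del gen
gc.collect()  # cleanup can no longer run on the closed loop; CPython warns
```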
```diff
@@ -184,7 +117,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
     ):
         super().__init__()
         self.async_client = AsyncLlamaStackAsLibraryClient(
-            config_path_or_template_name, custom_provider_registry
+            config_path_or_template_name, custom_provider_registry, provider_data
         )
         self.pool_executor = ThreadPoolExecutor(max_workers=4)
         self.skip_logger_removal = skip_logger_removal
```
```diff
@@ -210,39 +143,30 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
                 root_logger.removeHandler(handler)
                 print(f"Removed handler {handler.__class__.__name__} from root logger")

-    def _get_path(
-        self,
-        cast_to: Any,
-        options: Any,
-        *,
-        stream=False,
-        stream_cls=None,
-    ):
-        return options.url
-
     def request(self, *args, **kwargs):
-        path = self._get_path(*args, **kwargs)
         if kwargs.get("stream"):
-            return stream_across_asyncio_run_boundary(
-                lambda: self.async_client.request(*args, **kwargs),
-                self.pool_executor,
-                path=path,
-                provider_data=self.provider_data,
-            )
-        else:
-
-            async def _traced_request():
-                if self.provider_data:
-                    set_request_provider_data(
-                        {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
-                    )
-                await start_trace(path, {"__location__": "library_client"})
-                try:
-                    return await self.async_client.request(*args, **kwargs)
-                finally:
-                    await end_trace()
-
-            return asyncio.run(_traced_request())
+            # NOTE: We are using AsyncLlamaStackClient under the hood
+            # A new event loop is needed to convert the AsyncStream
+            # from async client into SyncStream return type for streaming
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+            def sync_generator():
+                try:
+                    async_stream = loop.run_until_complete(
+                        self.async_client.request(*args, **kwargs)
+                    )
+                    while True:
+                        chunk = loop.run_until_complete(async_stream.__anext__())
+                        yield chunk
+                except StopAsyncIteration:
+                    pass
+                finally:
+                    loop.close()
+
+            return sync_generator()
+        else:
+            return asyncio.run(self.async_client.request(*args, **kwargs))


 class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
```
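With this change, streaming from the sync client remains a plain `for` loop; the conversion happens inside `request`. A hedged usage sketch (the template name, model id, and message are placeholders, and the `chat_completion` call follows the client SDK shape at the time of this PR):

```python
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("ollama")  # illustrative template name
client.initialize()

response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in response:  # chunks are produced by sync_generator() above
    print(chunk)
```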
```diff
@@ -250,9 +174,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self,
         config_path_or_template_name: str,
         custom_provider_registry: Optional[ProviderRegistry] = None,
+        provider_data: Optional[dict[str, Any]] = None,
     ):
         super().__init__()

         # when using the library client, we should not log to console since many
         # of our logs are intended for server-side usage
         current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
```
```diff
@@ -273,6 +197,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self.config_path_or_template_name = config_path_or_template_name
         self.config = config
         self.custom_provider_registry = custom_provider_registry
+        self.provider_data = provider_data

     async def initialize(self):
         try:
```
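With `provider_data` stored on the async client, it is injected per request (see the next hunk) instead of inside the old streaming helper. An illustrative construction; the key names inside `provider_data` depend on the provider and are hypothetical here:

```python
# Hypothetical example of passing provider credentials through provider_data
client = LlamaStackAsLibraryClient(
    "together",                                 # illustrative template name
    provider_data={"together_api_key": "..."},  # key name depends on provider
)
```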
```diff
@@ -329,17 +254,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if not self.endpoint_impls:
             raise ValueError("Client not initialized")

+        if self.provider_data:
+            set_request_provider_data(
+                {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
+            )
+        await start_trace(options.url, {"__location__": "library_client"})
         if stream:
-            return self._call_streaming(
+            response = await self._call_streaming(
                 cast_to=cast_to,
                 options=options,
                 stream_cls=stream_cls,
             )
         else:
-            return await self._call_non_streaming(
+            response = await self._call_non_streaming(
                 cast_to=cast_to,
                 options=options,
             )
+        await end_trace()
+        return response

     async def _call_non_streaming(
         self,
```