From 2837dd63cbc34cf18a96fcbf650d779556a15aeb Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 7 Dec 2024 14:15:31 -0800 Subject: [PATCH] support GETs with streaming also --- llama_stack/distribution/library_client.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 81d7eae48..4de06ae08 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -96,7 +96,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient): custom_provider_registry: Optional[ProviderRegistry] = None, ): super().__init__() - self.async_client = LlamaStackAsLibraryAsyncClient( + self.async_client = AsyncLlamaStackAsLibraryClient( config_path_or_template_name, custom_provider_registry ) self.pool_executor = ThreadPoolExecutor(max_workers=4) @@ -105,8 +105,13 @@ class LlamaStackAsLibraryClient(LlamaStackClient): asyncio.run(self.async_client.initialize()) def get(self, *args, **kwargs): - assert not kwargs.get("stream"), "GET never called with stream=True" - return asyncio.run(self.async_client.get(*args, **kwargs)) + if kwargs.get("stream"): + return stream_across_asyncio_run_boundary( + lambda: self.async_client.get(*args, **kwargs), + self.pool_executor, + ) + else: + return asyncio.run(self.async_client.get(*args, **kwargs)) def post(self, *args, **kwargs): if kwargs.get("stream"): @@ -118,7 +123,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient): return asyncio.run(self.async_client.post(*args, **kwargs)) -class LlamaStackAsLibraryAsyncClient(AsyncLlamaStackClient): +class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): def __init__( self, config_path_or_template_name: str,