diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index cebfabba5..fb2682668 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -128,6 +128,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient): self.skip_logger_removal = skip_logger_removal self.provider_data = provider_data + self.loop = asyncio.new_event_loop() + def initialize(self): if in_notebook(): import nest_asyncio @@ -136,7 +138,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient): if not self.skip_logger_removal: self._remove_root_logger_handlers() - return asyncio.run(self.async_client.initialize()) + return self.loop.run_until_complete(self.async_client.initialize()) def _remove_root_logger_handlers(self): """ @@ -149,10 +151,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient): logger.info(f"Removed handler {handler.__class__.__name__} from root logger") def request(self, *args, **kwargs): - # NOTE: We are using AsyncLlamaStackClient under the hood - # A new event loop is needed to convert the AsyncStream - # from async client into SyncStream return type for streaming - loop = asyncio.new_event_loop() + loop = self.loop asyncio.set_event_loop(loop) if kwargs.get("stream"): @@ -169,7 +168,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient): pending = asyncio.all_tasks(loop) if pending: loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) - loop.close() return sync_generator() else: @@ -179,7 +177,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient): pending = asyncio.all_tasks(loop) if pending: loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) - loop.close() return result