more correct event loop handling for stream

This commit is contained in:
Xi Yan 2025-01-14 10:49:51 -08:00
parent a69c1b8bb1
commit ebbb57beae

View file

@ -144,8 +144,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
print(f"Removed handler {handler.__class__.__name__} from root logger") print(f"Removed handler {handler.__class__.__name__} from root logger")
def request(self, *args, **kwargs): def request(self, *args, **kwargs):
async_response = asyncio.run(self.async_client.request(*args, **kwargs))
if kwargs.get("stream"): if kwargs.get("stream"):
# NOTE: We are using AsyncLlamaStackClient under the hood # NOTE: We are using AsyncLlamaStackClient under the hood
# A new event loop is needed to convert the AsyncStream # A new event loop is needed to convert the AsyncStream
@ -155,8 +153,11 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
def sync_generator(): def sync_generator():
try: try:
async_stream = loop.run_until_complete(
self.async_client.request(*args, **kwargs)
)
while True: while True:
chunk = loop.run_until_complete(async_response.__anext__()) chunk = loop.run_until_complete(async_stream.__anext__())
yield chunk yield chunk
except StopAsyncIteration: except StopAsyncIteration:
pass pass
@ -165,7 +166,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
return sync_generator() return sync_generator()
else: else:
return async_response return asyncio.run(self.async_client.request(*args, **kwargs))
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):