support GETs with streaming also

This commit is contained in:
Ashwin Bharambe 2024-12-07 14:15:31 -08:00
parent 1fba8f80c2
commit 2837dd63cb

View file

@ -96,7 +96,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
custom_provider_registry: Optional[ProviderRegistry] = None,
):
super().__init__()
self.async_client = LlamaStackAsLibraryAsyncClient(
self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_template_name, custom_provider_registry
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)
@ -105,8 +105,13 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
asyncio.run(self.async_client.initialize())
def get(self, *args, **kwargs):
assert not kwargs.get("stream"), "GET never called with stream=True"
return asyncio.run(self.async_client.get(*args, **kwargs))
if kwargs.get("stream"):
return stream_across_asyncio_run_boundary(
lambda: self.async_client.get(*args, **kwargs),
self.pool_executor,
)
else:
return asyncio.run(self.async_client.get(*args, **kwargs))
def post(self, *args, **kwargs):
if kwargs.get("stream"):
@ -118,7 +123,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
return asyncio.run(self.async_client.post(*args, **kwargs))
class LlamaStackAsLibraryAsyncClient(AsyncLlamaStackClient):
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
def __init__(
self,
config_path_or_template_name: str,