From 8eee5b9adc3e55e4f2befcbf19e10368e3629f5c Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Tue, 8 Oct 2024 17:03:31 -0700
Subject: [PATCH] Fix server conditional awaiting on coroutines

---
 llama_stack/distribution/server/server.py        | 13 ++++++++-----
 .../impls/meta_reference/inference/inference.py  |  1 +
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 5c1a7806d..dd499db6b 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -193,6 +193,12 @@ def is_streaming_request(func_name: str, request: Request, **kwargs):
     return kwargs.get("stream", False)
 
 
+async def maybe_await(value):
+    if inspect.iscoroutine(value):
+        return await value
+    return value
+
+
 async def sse_generator(event_gen):
     try:
         async for item in event_gen:
@@ -228,11 +234,8 @@ def create_dynamic_typed_route(func: Any, method: str):
                     sse_generator(func(**kwargs)), media_type="text/event-stream"
                 )
             else:
-                return (
-                    await func(**kwargs)
-                    if asyncio.iscoroutinefunction(func)
-                    else func(**kwargs)
-                )
+                value = func(**kwargs)
+                return await maybe_await(value)
         except Exception as e:
             traceback.print_exception(e)
             raise translate_exception(e) from e
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index 43a131647..bda5e54c1 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -34,6 +34,7 @@ class MetaReferenceInferenceImpl(Inference):
     # verify that the checkpoint actually is for this model lol
 
     async def initialize(self) -> None:
+        print(f"Loading model `{self.model.descriptor()}`")
        self.generator = LlamaModelParallelGenerator(self.config)
         self.generator.start()
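
Note (illustration, not part of the patch): the key change is checking the
returned *value* with inspect.iscoroutine rather than checking the *callable*
with asyncio.iscoroutinefunction. A plain function can still return a
coroutine object (for example, a sync wrapper around an async function),
which iscoroutinefunction misses. A minimal standalone sketch of the pattern,
using only the standard library; the handler names are hypothetical:

    import asyncio
    import inspect


    async def maybe_await(value):
        # Await only if the call actually produced a coroutine.
        if inspect.iscoroutine(value):
            return await value
        return value


    async def async_handler(x):
        return x * 2


    def sync_wrapper(x):
        # A sync callable that returns a coroutine object;
        # asyncio.iscoroutinefunction(sync_wrapper) is False,
        # so the old callable-based check would fail to await it.
        return async_handler(x)


    async def main():
        print(await maybe_await(async_handler(3)))  # 6
        print(await maybe_await(sync_wrapper(3)))   # 6
        print(await maybe_await(42))                # 42 (plain value passes through)


    asyncio.run(main())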