Fix server conditional awaiting on coroutines

This commit is contained in:
Ashwin Bharambe 2024-10-08 17:03:31 -07:00
parent 216e7eb4d5
commit 8eee5b9adc
2 changed files with 9 additions and 5 deletions

View file

@ -193,6 +193,12 @@ def is_streaming_request(func_name: str, request: Request, **kwargs):
return kwargs.get("stream", False)
async def maybe_await(value):
if inspect.iscoroutine(value):
return await value
return value
async def sse_generator(event_gen):
try:
async for item in event_gen:
@ -228,11 +234,8 @@ def create_dynamic_typed_route(func: Any, method: str):
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
else:
return (
await func(**kwargs)
if asyncio.iscoroutinefunction(func)
else func(**kwargs)
)
value = func(**kwargs)
return await maybe_await(value)
except Exception as e:
traceback.print_exception(e)
raise translate_exception(e) from e

View file

@ -34,6 +34,7 @@ class MetaReferenceInferenceImpl(Inference):
# verify that the checkpoint actually is for this model lol
async def initialize(self) -> None:
print(f"Loading model `{self.model.descriptor()}`")
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()