Fix server conditional awaiting on coroutines

Ashwin Bharambe 2024-10-08 17:03:31 -07:00
parent 216e7eb4d5
commit 8eee5b9adc
2 changed files with 9 additions and 5 deletions


@@ -193,6 +193,12 @@ def is_streaming_request(func_name: str, request: Request, **kwargs):
     return kwargs.get("stream", False)
 
 
+async def maybe_await(value):
+    if inspect.iscoroutine(value):
+        return await value
+    return value
+
+
 async def sse_generator(event_gen):
     try:
         async for item in event_gen:
@@ -228,11 +234,8 @@ def create_dynamic_typed_route(func: Any, method: str):
                     sse_generator(func(**kwargs)), media_type="text/event-stream"
                 )
             else:
-                return (
-                    await func(**kwargs)
-                    if asyncio.iscoroutinefunction(func)
-                    else func(**kwargs)
-                )
+                value = func(**kwargs)
+                return await maybe_await(value)
         except Exception as e:
             traceback.print_exception(e)
             raise translate_exception(e) from e
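
The new maybe_await helper inspects the value returned by the call rather than the callable itself, which also covers plain functions that hand back a coroutine object. Below is a minimal standalone sketch of that pattern; the handler names are illustrative only and are not part of the commit.

import asyncio
import inspect


async def maybe_await(value):
    # Await only if the call produced a coroutine; pass plain values through.
    if inspect.iscoroutine(value):
        return await value
    return value


async def async_handler(x):
    return x * 2


def sync_wrapper(x):
    # A plain `def` that returns a coroutine object. asyncio.iscoroutinefunction()
    # reports False for this wrapper, so a check on the callable would hand the
    # coroutine back un-awaited; checking the returned value still catches it.
    return async_handler(x)


def sync_handler(x):
    return x + 1


async def main():
    print(await maybe_await(sync_wrapper(21)))  # 42
    print(await maybe_await(sync_handler(41)))  # 42


asyncio.run(main())
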


@@ -34,6 +34,7 @@ class MetaReferenceInferenceImpl(Inference):
     # verify that the checkpoint actually is for this model lol
 
     async def initialize(self) -> None:
+        print(f"Loading model `{self.model.descriptor()}`")
         self.generator = LlamaModelParallelGenerator(self.config)
         self.generator.start()