From 5d3c02d0fb2262bbb5f581cc360589fcb51126c9 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 19 Sep 2024 21:53:27 -0700
Subject: [PATCH] clean up router inference

---
 llama_stack/apis/inference/client.py          |  4 ++--
 .../providers/routers/inference/inference.py  | 13 +++----------
 2 files changed, 5 insertions(+), 12 deletions(-)

diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py
index cdcca8b6b..7ebfa4e73 100644
--- a/llama_stack/apis/inference/client.py
+++ b/llama_stack/apis/inference/client.py
@@ -93,7 +93,7 @@ async def run_main(host: str, port: int, stream: bool):
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
         ChatCompletionRequest(
-            model="Meta-Llama3.1-8B",
+            model="Meta-Llama3.1-8B-Instruct",
             messages=[message],
             stream=stream,
         )
@@ -104,7 +104,7 @@ async def run_main(host: str, port: int, stream: bool):
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
         ChatCompletionRequest(
-            model="Meta-Llama3.1-8B-Instruct",
+            model="Meta-Llama3.1-8B",
             messages=[message],
             stream=stream,
         )
diff --git a/llama_stack/providers/routers/inference/inference.py b/llama_stack/providers/routers/inference/inference.py
index 48d4f6f69..be3d2e434 100644
--- a/llama_stack/providers/routers/inference/inference.py
+++ b/llama_stack/providers/routers/inference/inference.py
@@ -58,9 +58,8 @@ class InferenceRouterImpl(Inference):
         cprint(self.model2providers, "blue")
 
     async def shutdown(self) -> None:
-        pass
-        # for p in self.providers.values():
-        #     await p.shutdown()
+        for p in self.model2providers.values():
+            await p.shutdown()
 
     async def chat_completion(
         self,
@@ -74,17 +73,11 @@ class InferenceRouterImpl(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        print("router chat_completion")
         if model not in self.model2providers:
             raise ValueError(
                 f"Cannot find model {model} in running distribution. Please use register model first"
             )
-        # yield ChatCompletionResponseStreamChunk(
-        #     event=ChatCompletionResponseEvent(
-        #         event_type=ChatCompletionResponseEventType.progress,
-        #         delta="router chat completion",
-        #     )
-        # )
+
         async for chunk in self.model2providers[model].chat_completion(
             model=model,
             messages=messages,
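
For context, a minimal, self-contained sketch of the router pattern this patch converges on is given below. It is an illustrative approximation, not the llama_stack implementation: only the model2providers mapping, the shutdown() loop over providers, and the chat_completion() forwarding mirror the diff, while the Provider protocol, Message class, EchoProvider, and the demo entry point are hypothetical stand-ins added for this sketch.

# Sketch of the inference-router pattern from the patch above.
# Assumptions: Provider, Message, EchoProvider, and _demo are illustrative;
# only model2providers, shutdown(), and the chat_completion() forwarding
# follow the diff.
import asyncio
from typing import AsyncGenerator, Dict, List, Protocol


class Message:
    def __init__(self, content: str) -> None:
        self.content = content


class Provider(Protocol):
    # An async generator function: called without await, iterated with async for.
    def chat_completion(
        self, model: str, messages: List[Message], stream: bool = False
    ) -> AsyncGenerator:
        ...

    async def shutdown(self) -> None:
        ...


class InferenceRouter:
    def __init__(self, model2providers: Dict[str, Provider]) -> None:
        # Maps a model identifier to the provider that serves it.
        self.model2providers = model2providers

    async def shutdown(self) -> None:
        # Propagate shutdown to every registered provider, as the patch now does.
        for p in self.model2providers.values():
            await p.shutdown()

    async def chat_completion(
        self, model: str, messages: List[Message], stream: bool = False
    ) -> AsyncGenerator:
        # Fail fast on unknown models, then forward to the matching provider
        # and re-yield its stream chunks, mirroring the router in the diff.
        if model not in self.model2providers:
            raise ValueError(f"Cannot find model {model} in running distribution")
        async for chunk in self.model2providers[model].chat_completion(
            model=model, messages=messages, stream=stream
        ):
            yield chunk


class EchoProvider:
    # Toy provider: streams each message back as a single chunk.
    async def chat_completion(self, model, messages, stream=False):
        for m in messages:
            yield f"[{model}] {m.content}"

    async def shutdown(self) -> None:
        pass


async def _demo() -> None:
    router = InferenceRouter({"Meta-Llama3.1-8B-Instruct": EchoProvider()})
    async for chunk in router.chat_completion(
        model="Meta-Llama3.1-8B-Instruct", messages=[Message("hello")], stream=True
    ):
        print(chunk)
    await router.shutdown()


if __name__ == "__main__":
    asyncio.run(_demo())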