clean up router inference

Xi Yan 2024-09-19 21:53:27 -07:00
parent f6146f8e58
commit 5d3c02d0fb
2 changed files with 5 additions and 12 deletions

@@ -93,7 +93,7 @@ async def run_main(host: str, port: int, stream: bool):
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
         ChatCompletionRequest(
-            model="Meta-Llama3.1-8B",
+            model="Meta-Llama3.1-8B-Instruct",
             messages=[message],
             stream=stream,
         )
@@ -104,7 +104,7 @@ async def run_main(host: str, port: int, stream: bool):
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
         ChatCompletionRequest(
-            model="Meta-Llama3.1-8B-Instruct",
+            model="Meta-Llama3.1-8B",
             messages=[message],
             stream=stream,
         )
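
Taken together, the two hunks above swap which model name the example's non-streaming and streaming call sites use. For readers who want to try the same call shape end to end, here is a self-contained sketch; everything outside the diff's own lines is an assumption, and UserMessage, the request fields, and EchoClient are hypothetical stand-ins for the real types and client that the changed file imports from elsewhere.

import asyncio
from dataclasses import dataclass, field
from typing import AsyncGenerator, List

from termcolor import cprint  # same helper the example file uses


# Hypothetical stand-ins for the real message/request types named in the diff.
@dataclass
class UserMessage:
    content: str
    role: str = "user"


@dataclass
class ChatCompletionRequest:
    model: str
    messages: List[UserMessage] = field(default_factory=list)
    stream: bool = False


class EchoClient:
    # Stand-in for the example's real inference client; it echoes locally
    # instead of streaming chunks from a server.
    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        for word in f"(echo from {request.model})".split():
            yield word


async def run_main(host: str, port: int, stream: bool) -> None:
    client = EchoClient()  # the real example would connect to host:port
    message = UserMessage(content="hello world")
    cprint(f"User>{message.content}", "green")
    iterator = client.chat_completion(
        ChatCompletionRequest(
            model="Meta-Llama3.1-8B-Instruct",
            messages=[message],
            stream=stream,
        )
    )
    async for chunk in iterator:
        print(chunk)


if __name__ == "__main__":
    asyncio.run(run_main("localhost", 5000, stream=True))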

@@ -58,9 +58,8 @@ class InferenceRouterImpl(Inference):
         cprint(self.model2providers, "blue")

     async def shutdown(self) -> None:
-        pass
-        # for p in self.providers.values():
-        #     await p.shutdown()
+        for p in self.model2providers.values():
+            await p.shutdown()

     async def chat_completion(
         self,
@@ -74,17 +73,11 @@ class InferenceRouterImpl(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        print("router chat_completion")
         if model not in self.model2providers:
             raise ValueError(
                 f"Cannot find model {model} in running distribution. Please use register model first"
             )
-        # yield ChatCompletionResponseStreamChunk(
-        #     event=ChatCompletionResponseEvent(
-        #         event_type=ChatCompletionResponseEventType.progress,
-        #         delta="router chat completion",
-        #     )
-        # )
         async for chunk in self.model2providers[model].chat_completion(
             model=model,
             messages=messages,
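
The router cleanup boils down to two things: shutdown now fans out to every provider in the populated model2providers map instead of being a no-op, and chat_completion validates the model before delegating, with the debug print and the commented-out streaming stub removed. A minimal runnable sketch of that final shape is below; FakeProvider and the simplified signatures are hypothetical stand-ins, not the real llama-stack types.

import asyncio
from typing import AsyncGenerator, Dict, List


class FakeProvider:
    # Hypothetical provider; a real one would call an inference backend.
    def __init__(self, name: str):
        self.name = name

    async def chat_completion(
        self, model: str, messages: List[str], stream: bool = False
    ) -> AsyncGenerator:
        yield f"[{self.name}] chunk for {model}"

    async def shutdown(self) -> None:
        print(f"{self.name} shut down")


class InferenceRouter:
    # Minimal sketch of the routing pattern in the diff: dispatch on a
    # model-name -> provider map.
    def __init__(self, model2providers: Dict[str, FakeProvider]):
        self.model2providers = model2providers

    async def shutdown(self) -> None:
        # The fix above: iterate the map that is actually populated.
        for p in self.model2providers.values():
            await p.shutdown()

    async def chat_completion(
        self, model: str, messages: List[str], stream: bool = False
    ) -> AsyncGenerator:
        if model not in self.model2providers:
            raise ValueError(f"Cannot find model {model} in running distribution")
        async for chunk in self.model2providers[model].chat_completion(
            model=model, messages=messages, stream=stream
        ):
            yield chunk


async def main() -> None:
    router = InferenceRouter({"Meta-Llama3.1-8B-Instruct": FakeProvider("meta-reference")})
    async for chunk in router.chat_completion("Meta-Llama3.1-8B-Instruct", ["hello"]):
        print(chunk)
    await router.shutdown()


asyncio.run(main())

One caveat with iterating model2providers.values(): a provider registered under several model names has shutdown() awaited once per name, so a provider whose shutdown is not idempotent would need deduplication, e.g. iterating over set(self.model2providers.values()) instead.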