clean up router inference

Xi Yan 2024-09-19 21:53:27 -07:00
parent f6146f8e58
commit 5d3c02d0fb
2 changed files with 5 additions and 12 deletions


@@ -93,7 +93,7 @@ async def run_main(host: str, port: int, stream: bool):
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
         ChatCompletionRequest(
-            model="Meta-Llama3.1-8B",
+            model="Meta-Llama3.1-8B-Instruct",
             messages=[message],
             stream=stream,
         )
@@ -104,7 +104,7 @@ async def run_main(host: str, port: int, stream: bool):
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
         ChatCompletionRequest(
-            model="Meta-Llama3.1-8B-Instruct",
+            model="Meta-Llama3.1-8B",
             messages=[message],
             stream=stream,
         )
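
Note: after this change the two example calls send different model identifiers through the same client, so the routing decision is made server-side by the model-to-provider table rather than by the example script. A minimal, self-contained sketch of the request shape exercised above, using stand-in types (the real ChatCompletionRequest and message classes live in the repository and are only mirrored here):

    # Stand-in types only; they mirror the fields used in the diff above
    # and are NOT the repository's actual definitions.
    from dataclasses import dataclass
    from typing import List

    @dataclass
    class UserMessage:
        content: str
        role: str = "user"

    @dataclass
    class ChatCompletionRequest:
        model: str
        messages: List[UserMessage]
        stream: bool = False

    # The example issues one request per model; only the model field differs.
    requests = [
        ChatCompletionRequest(model=m, messages=[UserMessage("hello")], stream=True)
        for m in ("Meta-Llama3.1-8B-Instruct", "Meta-Llama3.1-8B")
    ]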


@@ -58,9 +58,8 @@ class InferenceRouterImpl(Inference):
         cprint(self.model2providers, "blue")

     async def shutdown(self) -> None:
-        pass
-        # for p in self.providers.values():
-        #     await p.shutdown()
+        for p in self.model2providers.values():
+            await p.shutdown()

     async def chat_completion(
         self,
@@ -74,17 +73,11 @@ class InferenceRouterImpl(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        print("router chat_completion")
         if model not in self.model2providers:
             raise ValueError(
                 f"Cannot find model {model} in running distribution. Please use register model first"
             )
-        # yield ChatCompletionResponseStreamChunk(
-        #     event=ChatCompletionResponseEvent(
-        #         event_type=ChatCompletionResponseEventType.progress,
-        #         delta="router chat completion",
-        #     )
-        # )
         async for chunk in self.model2providers[model].chat_completion(
             model=model,
             messages=messages,
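
For context, the pattern the cleaned-up router follows is plain dictionary dispatch: look the requested model up in model2providers, fail fast if it is unknown, otherwise forward the call and re-yield the provider's stream; shutdown simply walks the same table. A simplified, self-contained sketch of that pattern (the provider protocol and argument types here are illustrative stand-ins, not the repository's actual interfaces):

    # Simplified routing sketch; types are stand-ins, not llama-stack's own.
    from typing import Any, AsyncGenerator, Dict, List, Protocol

    class InferenceProvider(Protocol):
        def chat_completion(self, model: str, messages: List[Any]) -> AsyncGenerator:
            ...

        async def shutdown(self) -> None:
            ...

    class ModelRouter:
        def __init__(self, model2providers: Dict[str, InferenceProvider]) -> None:
            self.model2providers = model2providers

        async def shutdown(self) -> None:
            # Mirror of the loop added above: shut down every registered provider.
            for p in self.model2providers.values():
                await p.shutdown()

        async def chat_completion(
            self, model: str, messages: List[Any]
        ) -> AsyncGenerator:
            # Fail fast on unknown models, then delegate and re-yield the stream.
            if model not in self.model2providers:
                raise ValueError(f"Cannot find model {model} in running distribution")
            async for chunk in self.model2providers[model].chat_completion(
                model=model, messages=messages
            ):
                yield chunk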