work with 2 different models

Author: Xi Yan
Date: 2024-09-19 21:40:37 -07:00
Parent: 7071c46422
Commit: 4b083eec03
2 changed files with 37 additions and 19 deletions


@@ -100,6 +100,17 @@ async def run_main(host: str, port: int, stream: bool):
    async for log in EventLogger().log(iterator):
        log.print()

    cprint(f"User>{message.content}", "green")
    iterator = client.chat_completion(
        ChatCompletionRequest(
            model="Meta-Llama3.1-8B",
            messages=[message],
            stream=stream,
        )
    )
    async for log in EventLogger().log(iterator):
        log.print()


def main(host: str, port: int, stream: bool = True):
    asyncio.run(run_main(host, port, stream))
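For context, a minimal sketch of the pattern this example file now exercises: issuing chat completions against two different models from the same client loop. The helper name and the second model id are placeholders (only Meta-Llama3.1-8B appears in this diff); ChatCompletionRequest, EventLogger and the client object are assumed to come from the same imports as the example above.

# Sketch only: run_two_models and "second-model-id" are illustrative, not part of this commit.
from termcolor import cprint

async def run_two_models(client, message, stream: bool = True):
    # ChatCompletionRequest / EventLogger: same imports as the example above (not shown in this diff)
    for model in ("Meta-Llama3.1-8B", "second-model-id"):
        cprint(f"User>{message.content}", "green")
        iterator = client.chat_completion(
            ChatCompletionRequest(
                model=model,
                messages=[message],
                stream=stream,
            )
        )
        async for log in EventLogger().log(iterator):
            log.print()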


@@ -26,7 +26,7 @@ class InferenceRouterImpl(Inference):
        models_api: Models,
    ) -> None:
        # map of model_id to provider impl
        self.providers = {}
        self.model2providers = {}

        self.models_api = models_api

    async def initialize(self) -> None:
@@ -42,6 +42,8 @@ class InferenceRouterImpl(Inference):
                raise ValueError(
                    f"provider_id {model_spec.provider_id} is not available for inference. Please check run.yaml config spec to define a valid provider"
                )

            # look up and initialize provider implementations for each model
            impl = await instantiate_provider(
                inference_providers[model_spec.provider_id],
                deps=[],
@@ -50,9 +52,10 @@ class InferenceRouterImpl(Inference):
                    config=model_spec.provider_config,
                ),
            )
            cprint(f"impl={impl}", "blue")

            # look up and initialize provider implementations for each model
            core_model_id = model_spec.llama_model_metadata.core_model_id
            self.model2providers[core_model_id.value] = impl

        cprint(self.model2providers, "blue")

    async def shutdown(self) -> None:
        pass
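For illustration, a condensed sketch of the routing table that initialize() builds above: each registered model's core_model_id is mapped to the provider implementation that will serve it. Field and function names mirror the diff; the surrounding class, the provider-config wrapper argument, and the build_model_routing_table name are assumptions.

# Sketch of the shape of the model -> provider map; not the literal implementation.
async def build_model_routing_table(model_specs, inference_providers) -> dict:
    model2providers = {}
    for model_spec in model_specs:
        if model_spec.provider_id not in inference_providers:
            raise ValueError(
                f"provider_id {model_spec.provider_id} is not available for inference"
            )
        impl = await instantiate_provider(
            inference_providers[model_spec.provider_id],
            deps=[],
        )  # the provider-config wrapper argument is elided; see the hunk above
        core_model_id = model_spec.llama_model_metadata.core_model_id
        model2providers[core_model_id.value] = impl
    return model2providers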
@@ -72,20 +75,24 @@ class InferenceRouterImpl(Inference):
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        print("router chat_completion")
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.progress,
                delta="router chat completion",
        if model not in self.model2providers:
            raise ValueError(
                f"Cannot find model {model} in running distribution. Please use register model first"
            )
            )
        # async for chunk in self.providers[model].chat_completion(
        #     model=model,
        #     messages=messages,
        #     sampling_params=sampling_params,
        #     tools=tools,
        #     tool_choice=tool_choice,
        #     tool_prompt_format=tool_prompt_format,
        #     stream=stream,
        #     logprobs=logprobs,
        # ):
        #     yield chunk
        # yield ChatCompletionResponseStreamChunk(
        #     event=ChatCompletionResponseEvent(
        #         event_type=ChatCompletionResponseEventType.progress,
        #         delta="router chat completion",
        #     )
        # )
        async for chunk in self.model2providers[model].chat_completion(
            model=model,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools,
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        ):
            yield chunk
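A short usage sketch of the dispatch behaviour this hunk introduces, assuming a router whose model2providers map was populated during initialize() and that the remaining chat_completion parameters default as in the signature above; the demo name, UserMessage type, and model id are placeholders.

# Sketch: the router validates the requested model, then streams chunks from the
# provider implementation registered for it.
async def demo(router):
    try:
        async for chunk in router.chat_completion(
            model="Meta-Llama3.1-8B",  # must match a registered core_model_id
            messages=[UserMessage(content="hello")],  # assumed message type
            stream=True,
        ):
            print(chunk)
    except ValueError as err:
        # raised when the model was never registered with this router
        print(err)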