work with 2 different models

This commit is contained in:
Xi Yan 2024-09-19 21:40:37 -07:00
parent 7071c46422
commit 4b083eec03
2 changed files with 37 additions and 19 deletions

View file

@ -100,6 +100,17 @@ async def run_main(host: str, port: int, stream: bool):
async for log in EventLogger().log(iterator): async for log in EventLogger().log(iterator):
log.print() log.print()
cprint(f"User>{message.content}", "green")
iterator = client.chat_completion(
ChatCompletionRequest(
model="Meta-Llama3.1-8B",
messages=[message],
stream=stream,
)
)
async for log in EventLogger().log(iterator):
log.print()
def main(host: str, port: int, stream: bool = True): def main(host: str, port: int, stream: bool = True):
asyncio.run(run_main(host, port, stream)) asyncio.run(run_main(host, port, stream))

View file

@ -26,7 +26,7 @@ class InferenceRouterImpl(Inference):
models_api: Models, models_api: Models,
) -> None: ) -> None:
# map of model_id to provider impl # map of model_id to provider impl
self.providers = {} self.model2providers = {}
self.models_api = models_api self.models_api = models_api
async def initialize(self) -> None: async def initialize(self) -> None:
@ -42,6 +42,8 @@ class InferenceRouterImpl(Inference):
raise ValueError( raise ValueError(
f"provider_id {model_spec.provider_id} is not available for inference. Please check run.yaml config spec to define a valid provider" f"provider_id {model_spec.provider_id} is not available for inference. Please check run.yaml config spec to define a valid provider"
) )
# look up and initialize provider implementations for each model
impl = await instantiate_provider( impl = await instantiate_provider(
inference_providers[model_spec.provider_id], inference_providers[model_spec.provider_id],
deps=[], deps=[],
@ -50,9 +52,10 @@ class InferenceRouterImpl(Inference):
config=model_spec.provider_config, config=model_spec.provider_config,
), ),
) )
cprint(f"impl={impl}", "blue")
# look up and initialize provider implementations for each model
core_model_id = model_spec.llama_model_metadata.core_model_id core_model_id = model_spec.llama_model_metadata.core_model_id
self.model2providers[core_model_id.value] = impl
cprint(self.model2providers, "blue")
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
@ -72,20 +75,24 @@ class InferenceRouterImpl(Inference):
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
print("router chat_completion") print("router chat_completion")
yield ChatCompletionResponseStreamChunk( if model not in self.model2providers:
event=ChatCompletionResponseEvent( raise ValueError(
event_type=ChatCompletionResponseEventType.progress, f"Cannot find model {model} in running distribution. Please use register model first"
delta="router chat completion",
) )
) # yield ChatCompletionResponseStreamChunk(
# async for chunk in self.providers[model].chat_completion( # event=ChatCompletionResponseEvent(
# model=model, # event_type=ChatCompletionResponseEventType.progress,
# messages=messages, # delta="router chat completion",
# sampling_params=sampling_params, # )
# tools=tools, # )
# tool_choice=tool_choice, async for chunk in self.model2providers[model].chat_completion(
# tool_prompt_format=tool_prompt_format, model=model,
# stream=stream, messages=messages,
# logprobs=logprobs, sampling_params=sampling_params,
# ): tools=tools,
# yield chunk tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
):
yield chunk