diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py
index 4d67fb4f6..51cc586fe 100644
--- a/llama_stack/apis/inference/client.py
+++ b/llama_stack/apis/inference/client.py
@@ -100,6 +100,17 @@ async def run_main(host: str, port: int, stream: bool):
     async for log in EventLogger().log(iterator):
         log.print()
 
+    cprint(f"User>{message.content}", "green")
+    iterator = client.chat_completion(
+        ChatCompletionRequest(
+            model="Meta-Llama3.1-8B",
+            messages=[message],
+            stream=stream,
+        )
+    )
+    async for log in EventLogger().log(iterator):
+        log.print()
+
 
 def main(host: str, port: int, stream: bool = True):
     asyncio.run(run_main(host, port, stream))
diff --git a/llama_stack/providers/routers/inference/inference.py b/llama_stack/providers/routers/inference/inference.py
index dbf2f3952..48d4f6f69 100644
--- a/llama_stack/providers/routers/inference/inference.py
+++ b/llama_stack/providers/routers/inference/inference.py
@@ -26,7 +26,7 @@ class InferenceRouterImpl(Inference):
         models_api: Models,
     ) -> None:
         # map of model_id to provider impl
-        self.providers = {}
+        self.model2providers = {}
         self.models_api = models_api
 
     async def initialize(self) -> None:
@@ -42,6 +42,8 @@ class InferenceRouterImpl(Inference):
                 raise ValueError(
                     f"provider_id {model_spec.provider_id} is not available for inference. Please check run.yaml config spec to define a valid provider"
                 )
+
+            # look up and initialize provider implementations for each model
             impl = await instantiate_provider(
                 inference_providers[model_spec.provider_id],
                 deps=[],
@@ -50,9 +52,10 @@ class InferenceRouterImpl(Inference):
                     config=model_spec.provider_config,
                 ),
             )
-            cprint(f"impl={impl}", "blue")
-            # look up and initialize provider implementations for each model
             core_model_id = model_spec.llama_model_metadata.core_model_id
+            self.model2providers[core_model_id.value] = impl
+
+        cprint(self.model2providers, "blue")
 
     async def shutdown(self) -> None:
         pass
@@ -72,20 +75,24 @@ class InferenceRouterImpl(Inference):
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         print("router chat_completion")
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.progress,
-                delta="router chat completion",
+        if model not in self.model2providers:
+            raise ValueError(
+                f"Cannot find model {model} in running distribution. Please use register model first"
             )
-        )
-        # async for chunk in self.providers[model].chat_completion(
-        #     model=model,
-        #     messages=messages,
-        #     sampling_params=sampling_params,
-        #     tools=tools,
-        #     tool_choice=tool_choice,
-        #     tool_prompt_format=tool_prompt_format,
-        #     stream=stream,
-        #     logprobs=logprobs,
-        # ):
-        #     yield chunk
+        # yield ChatCompletionResponseStreamChunk(
+        #     event=ChatCompletionResponseEvent(
+        #         event_type=ChatCompletionResponseEventType.progress,
+        #         delta="router chat completion",
+        #     )
+        # )
+        async for chunk in self.model2providers[model].chat_completion(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools,
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        ):
+            yield chunk
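
Note for reviewers: the sketch below is a minimal, self-contained illustration of the dispatch pattern the router now follows (look up the provider registered for the requested model id, then delegate and re-yield the stream). It is not the llama_stack API; the names ToyRouter, FakeProvider, and "meta-reference" are hypothetical placeholders.

import asyncio
from typing import AsyncGenerator, Dict


class FakeProvider:
    """Hypothetical stand-in for a per-model inference provider impl."""

    def __init__(self, name: str) -> None:
        self.name = name

    async def chat_completion(self, model: str, message: str) -> AsyncGenerator[str, None]:
        # Pretend to stream a response in two chunks.
        for chunk in (f"[{self.name}] ", f"echo: {message}"):
            yield chunk


class ToyRouter:
    """Routes each request to the provider registered for its model id."""

    def __init__(self, model2providers: Dict[str, FakeProvider]) -> None:
        self.model2providers = model2providers

    async def chat_completion(self, model: str, message: str) -> AsyncGenerator[str, None]:
        # Mirror the router's new behavior: reject unknown models, otherwise delegate.
        if model not in self.model2providers:
            raise ValueError(f"Cannot find model {model}; register it first")
        async for chunk in self.model2providers[model].chat_completion(model, message):
            yield chunk


async def main() -> None:
    router = ToyRouter({"Meta-Llama3.1-8B": FakeProvider("meta-reference")})
    async for chunk in router.chat_completion("Meta-Llama3.1-8B", "hello"):
        print(chunk, end="")
    print()


if __name__ == "__main__":
    asyncio.run(main())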