From 7e6a11d17be345a2f23739aacd1cd9cf7fb4393c Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Sun, 24 Nov 2024 21:12:57 -0800
Subject: [PATCH] fix tgi to correctly pass llama model

The TGI adapter was passing the stack-level model_id straight through to
the provider. Resolve it via the model store first, so completion and
chat-completion requests carry the provider_resource_id that TGI serves,
and pass the resolved llama model to the prompt formatter in _get_params.

---
 llama_stack/providers/remote/inference/tgi/tgi.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 621188284..01981c62b 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -89,8 +89,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -194,8 +195,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -249,7 +251,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         prompt, input_tokens = chat_completion_request_to_model_input_info(
-            request, self.formatter
+            request, self.register_helper.get_llama_model(request.model), self.formatter
         )
         return dict(
             prompt=prompt,
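
Note on the resolution step the patch introduces: the adapter now looks the
model up in the model store and builds the request with the resolved
provider_resource_id rather than the stack-level alias. Below is a minimal
runnable sketch of that lookup; Model and InMemoryModelStore here are
hypothetical stand-ins for the real llama-stack classes, which carry more
fields and live behind the model registry.

    import asyncio
    from dataclasses import dataclass

    @dataclass
    class Model:
        identifier: str             # stack-level model_id (often an alias)
        provider_resource_id: str   # the id the provider (TGI) actually serves

    class InMemoryModelStore:
        def __init__(self) -> None:
            self._models: dict[str, Model] = {}

        def register(self, model: Model) -> None:
            self._models[model.identifier] = model

        async def get_model(self, model_id: str) -> Model:
            # async to mirror `await self.model_store.get_model(model_id)`
            try:
                return self._models[model_id]
            except KeyError:
                raise ValueError(f"model {model_id!r} is not registered")

    async def main() -> None:
        store = InMemoryModelStore()
        store.register(Model("llama3.1-8b", "meta-llama/Llama-3.1-8B-Instruct"))
        model = await store.get_model("llama3.1-8b")
        print(model.provider_resource_id)  # what the request now carries

    asyncio.run(main())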