fix tgi to correctly pass llama model

This commit is contained in:
Dinesh Yeduguru 2024-11-24 21:12:57 -08:00
parent 3cace74458
commit 7e6a11d17b

View file

@ -89,8 +89,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = CompletionRequest( request = CompletionRequest(
model=model_id, model=model.provider_resource_id,
content=content, content=content,
sampling_params=sampling_params, sampling_params=sampling_params,
response_format=response_format, response_format=response_format,
@ -194,8 +195,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest( request = ChatCompletionRequest(
model=model_id, model=model.provider_resource_id,
messages=messages, messages=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
tools=tools or [], tools=tools or [],
@ -249,7 +251,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
def _get_params(self, request: ChatCompletionRequest) -> dict: def _get_params(self, request: ChatCompletionRequest) -> dict:
prompt, input_tokens = chat_completion_request_to_model_input_info( prompt, input_tokens = chat_completion_request_to_model_input_info(
request, self.formatter request, self.register_helper.get_llama_model(request.model), self.formatter
) )
return dict( return dict(
prompt=prompt, prompt=prompt,