Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-12-17 12:52:36 +00:00
fix tgi to correctly pass llama model

parent 3cace74458
commit 7e6a11d17b

1 changed file with 5 additions and 3 deletions
@@ -89,8 +89,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
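The completion path now resolves the user-facing model_id through the adapter's model store and builds the CompletionRequest with the provider-side identifier. A minimal sketch of that resolution follows; the Model and ModelStore shapes here are assumptions for illustration (only the provider_resource_id field name is taken from the diff), not the exact llama-stack types.

# Hypothetical sketch of the lookup behind `await self.model_store.get_model(model_id)`.
# Everything except the `provider_resource_id` field name is assumed.
from dataclasses import dataclass, field


@dataclass
class Model:
    identifier: str              # user-facing alias registered with the stack
    provider_resource_id: str    # the model id the TGI endpoint actually serves


@dataclass
class ModelStore:
    models: dict[str, Model] = field(default_factory=dict)

    async def get_model(self, model_id: str) -> Model:
        # Fail on unknown aliases instead of silently forwarding them to the provider.
        if model_id not in self.models:
            raise ValueError(f"model '{model_id}' is not registered")
        return self.models[model_id]

With this in place, TGI receives the name it was launched with rather than the stack-level alias.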
@@ -194,8 +195,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
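chat_completion gets the same treatment: the alias is resolved once and the ChatCompletionRequest carries the provider resource id. A short usage sketch, reusing the assumed Model/ModelStore types from the sketch above; the alias and provider id values are made up for illustration.

# Hypothetical registration/lookup example building on the sketch above.
import asyncio


async def main() -> None:
    store = ModelStore()
    store.models["llama-8b"] = Model(
        identifier="llama-8b",                                          # alias used by clients
        provider_resource_id="meta-llama/Meta-Llama-3.1-8B-Instruct",   # what TGI serves
    )
    model = await store.get_model("llama-8b")
    # This is the value the adapter now places in ChatCompletionRequest.model.
    print(model.provider_resource_id)


asyncio.run(main())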
@@ -249,7 +251,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         prompt, input_tokens = chat_completion_request_to_model_input_info(
-            request, self.formatter
+            request, self.register_helper.get_llama_model(request.model), self.formatter
         )
         return dict(
             prompt=prompt,
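_get_params now passes the resolved llama model into the prompt-formatting helper, so the prompt is built for the model family actually being served. A rough sketch of the kind of mapping register_helper.get_llama_model(request.model) is assumed to perform; the table entries are illustrative, not the adapter's real registry.

# Hypothetical provider-id -> llama-model mapping, standing in for
# register_helper.get_llama_model(). Entries are examples only.
_PROVIDER_TO_LLAMA: dict[str, str] = {
    "meta-llama/Meta-Llama-3.1-8B-Instruct": "Llama3.1-8B-Instruct",
    "meta-llama/Meta-Llama-3.1-70B-Instruct": "Llama3.1-70B-Instruct",
}


def get_llama_model(provider_model_id: str) -> str:
    # Fall back to the provider id itself when no mapping is known, so prompt
    # formatting still has something to key off.
    return _PROVIDER_TO_LLAMA.get(provider_model_id, provider_model_id)

Since request.model now holds the provider resource id (from the earlier hunks), a lookup like this is what lets chat_completion_request_to_model_input_info format the prompt for the correct model family.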