TGI fixture (#519)

# What does this PR do?

* Add a test fixture for TGI (a sketch of such a fixture is given below)
* Fix the logic to pass the correct Llama model for chat completion

Fixes #514
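
The fixture file itself is not reproduced in the diff excerpt below, so here is a minimal sketch of what a TGI inference fixture of this shape could look like. The provider type string and config keys are assumptions inferred from the env variables in the test plan, not confirmed repo identifiers:

```python
# Hypothetical TGI fixture sketch -- names like "remote::tgi" and the
# config keys are assumptions, not necessarily the repo's real values.
import os

import pytest


def get_env_or_fail(key: str) -> str:
    """Read a required environment variable or fail fixture setup loudly."""
    value = os.environ.get(key)
    if value is None:
        raise ValueError(f"Environment variable {key} must be set")
    return value


@pytest.fixture(scope="session")
def inference_tgi() -> dict:
    """Point the inference provider at an already-running TGI server."""
    return {
        "provider_type": "remote::tgi",
        "config": {
            "url": get_env_or_fail("TGI_URL"),
            "api_token": os.environ.get("TGI_API_TOKEN"),
        },
    }
```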

## Test Plan

```bash
pytest -k "tgi" \
  llama_stack/providers/tests/inference/test_text_inference.py \
  --env TGI_URL=http://localhost:$INFERENCE_PORT \
  --env TGI_API_TOKEN=$HF_TOKEN
```
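
A running TGI server must be listening at `TGI_URL` for these tests to pass. A sketch of one way to start one locally with Docker (the image and the `--model-id` flag are standard TGI usage; the port and model ID here are placeholders):

```bash
# Serve a Llama model with TGI so the test suite has an endpoint.
# INFERENCE_PORT and HF_TOKEN match the variables in the pytest command
# above; the model ID is a placeholder -- substitute the model under test.
export INFERENCE_PORT=8080
docker run --rm --shm-size 1g \
  -p $INFERENCE_PORT:80 \
  -v $HOME/.cache/huggingface:/data \
  -e HUGGING_FACE_HUB_TOKEN=$HF_TOKEN \
  ghcr.io/huggingface/text-generation-inference:latest \
  --model-id meta-llama/Llama-3.1-8B-Instruct
```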

2 changed files with 23 additions and 3 deletions (only the `_HfAdapter` changes are excerpted below)

```diff
@@ -89,8 +89,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -194,8 +195,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -249,7 +251,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         prompt, input_tokens = chat_completion_request_to_model_input_info(
-            request, self.formatter
+            request, self.register_helper.get_llama_model(request.model), self.formatter
        )
         return dict(
             prompt=prompt,
```
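
The chat-completion fix hinges on two lookups visible in the diff: `model_store.get_model(model_id)` resolves the client-facing alias to a registered model whose `provider_resource_id` is what the TGI endpoint actually serves, and `register_helper.get_llama_model(...)` maps that provider ID back to the canonical Llama model so the prompt formatter picks the right template. A runnable sketch of that flow, with simplified stand-in types (the real model store and register helper in the repo are richer than these):

```python
# Simplified stand-ins illustrating the two-step model resolution from the
# diff above; the dataclass and fakes here are illustrations only.
import asyncio
from dataclasses import dataclass


@dataclass
class Model:
    identifier: str            # alias the client passes as model_id
    provider_resource_id: str  # ID the TGI endpoint actually serves


class FakeModelStore:
    def __init__(self, models: dict[str, Model]):
        self._models = models

    async def get_model(self, model_id: str) -> Model:
        # Resolve the client-facing alias to the registered model record.
        return self._models[model_id]


class FakeRegisterHelper:
    def __init__(self, provider_to_llama: dict[str, str]):
        self._provider_to_llama = provider_to_llama

    def get_llama_model(self, provider_resource_id: str) -> str:
        # Map the provider-side ID back to the canonical Llama descriptor
        # so the prompt formatter uses the right chat template.
        return self._provider_to_llama[provider_resource_id]


async def main() -> None:
    store = FakeModelStore(
        {"my-llama": Model("my-llama", "meta-llama/Llama-3.1-8B-Instruct")}
    )
    helper = FakeRegisterHelper(
        {"meta-llama/Llama-3.1-8B-Instruct": "Llama3.1-8B-Instruct"}
    )

    model = await store.get_model("my-llama")        # step 1: alias -> record
    request_model = model.provider_resource_id       # what goes in the request
    llama = helper.get_llama_model(request_model)    # step 2: for formatting
    print(request_model, llama)


asyncio.run(main())
```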