forked from phoenix-oss/llama-stack-mirror
Tgi fixture (#519)
# What does this PR do? * Add a test fixture for tgi * Fixes the logic to correctly pass the llama model for chat completion Fixes #514 ## Test Plan pytest -k "tgi" llama_stack/providers/tests/inference/test_text_inference.py --env TGI_URL=http://localhost:$INFERENCE_PORT --env TGI_API_TOKEN=$HF_TOKEN
This commit is contained in:
parent
60cb7f64af
commit
de7af28756
2 changed files with 23 additions and 3 deletions
|
@ -89,8 +89,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model_id,
|
||||
model=model.provider_resource_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
|
@ -194,8 +195,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model_id,
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
|
@ -249,7 +251,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
prompt, input_tokens = chat_completion_request_to_model_input_info(
|
||||
request, self.formatter
|
||||
request, self.register_helper.get_llama_model(request.model), self.formatter
|
||||
)
|
||||
return dict(
|
||||
prompt=prompt,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue