Tgi fixture (#519)

# What does this PR do?

* Add a test fixture for tgi
* Fixes the logic to correctly pass the llama model for chat completion

Fixes #514

## Test Plan

pytest -k "tgi"
llama_stack/providers/tests/inference/test_text_inference.py --env
TGI_URL=http://localhost:$INFERENCE_PORT --env TGI_API_TOKEN=$HF_TOKEN
This commit is contained in:
Dinesh Yeduguru 2024-11-25 13:17:02 -08:00 committed by GitHub
parent 60cb7f64af
commit de7af28756
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 23 additions and 3 deletions

View file

@ -89,8 +89,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model_id,
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
@ -194,8 +195,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model_id,
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@ -249,7 +251,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
def _get_params(self, request: ChatCompletionRequest) -> dict:
prompt, input_tokens = chat_completion_request_to_model_input_info(
request, self.formatter
request, self.register_helper.get_llama_model(request.model), self.formatter
)
return dict(
prompt=prompt,

View file

@ -20,6 +20,7 @@ from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
from llama_stack.providers.remote.inference.together import TogetherImplConfig
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
@ -156,6 +157,22 @@ def inference_nvidia() -> ProviderFixture:
)
@pytest.fixture(scope="session")
def inference_tgi() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="tgi",
provider_type="remote::tgi",
config=TGIImplConfig(
url=get_env_or_fail("TGI_URL"),
api_token=os.getenv("TGI_API_TOKEN", None),
).model_dump(),
)
],
)
def get_model_short_name(model_name: str) -> str:
"""Convert model name to a short test identifier.
@ -190,6 +207,7 @@ INFERENCE_FIXTURES = [
"remote",
"bedrock",
"nvidia",
"tgi",
]