diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 621188284..01981c62b 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -89,8 +89,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -194,8 +195,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -249,7 +251,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         prompt, input_tokens = chat_completion_request_to_model_input_info(
-            request, self.formatter
+            request, self.register_helper.get_llama_model(request.model), self.formatter
         )
         return dict(
             prompt=prompt,
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index 2007818e5..a427eef12 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -20,6 +20,7 @@ from llama_stack.providers.remote.inference.bedrock import BedrockConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
+from llama_stack.providers.remote.inference.tgi import TGIImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
 from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
 from llama_stack.providers.tests.resolver import construct_stack_for_test
@@ -156,6 +157,22 @@ def inference_nvidia() -> ProviderFixture:
     )
 
 
+@pytest.fixture(scope="session")
+def inference_tgi() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="tgi",
+                provider_type="remote::tgi",
+                config=TGIImplConfig(
+                    url=get_env_or_fail("TGI_URL"),
+                    api_token=os.getenv("TGI_API_TOKEN", None),
+                ).model_dump(),
+            )
+        ],
+    )
+
+
 def get_model_short_name(model_name: str) -> str:
     """Convert model name to a short test identifier.
 
@@ -190,6 +207,7 @@ INFERENCE_FIXTURES = [
     "remote",
     "bedrock",
     "nvidia",
+    "tgi",
 ]
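
For context on the first two hunks: the adapter no longer forwards the stack-level `model_id` to TGI verbatim. It resolves the model through the model store and places `provider_resource_id` (the name the TGI endpoint actually serves) on the request. Below is a minimal, illustrative sketch of that resolution step; the `RegisteredModel` and `ToyModelStore` classes are hypothetical stand-ins, not the real Llama Stack types.

```python
# Illustrative only -- hypothetical stand-ins showing why the adapter resolves
# model_id via the model store before building the TGI request.
import asyncio
from dataclasses import dataclass


@dataclass
class RegisteredModel:
    identifier: str              # stack-facing alias used by callers
    provider_resource_id: str    # name the TGI endpoint actually serves


class ToyModelStore:
    """Hypothetical in-memory stand-in for self.model_store."""

    def __init__(self, models: dict[str, RegisteredModel]) -> None:
        self._models = models

    async def get_model(self, model_id: str) -> RegisteredModel:
        return self._models[model_id]


async def resolve_request_model(store: ToyModelStore, model_id: str) -> str:
    # Mirrors the new adapter logic: look the model up, then use its
    # provider_resource_id rather than the raw model_id.
    model = await store.get_model(model_id)
    return model.provider_resource_id


store = ToyModelStore(
    {
        "llama3.1-8b-instruct": RegisteredModel(
            identifier="llama3.1-8b-instruct",
            provider_resource_id="meta-llama/Llama-3.1-8B-Instruct",
        )
    }
)
print(asyncio.run(resolve_request_model(store, "llama3.1-8b-instruct")))
# -> meta-llama/Llama-3.1-8B-Instruct
```

The new `inference_tgi` fixture follows the same pattern as the other remote providers: with `TGI_URL` set in the environment (and optionally `TGI_API_TOKEN`), the `tgi` entry added to `INFERENCE_FIXTURES` makes the TGI provider selectable for the shared inference test suite.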