diff --git a/llama_stack/providers/remote/inference/tgi/config.py b/llama_stack/providers/remote/inference/tgi/config.py
index f05005b25..4f690dec6 100644
--- a/llama_stack/providers/remote/inference/tgi/config.py
+++ b/llama_stack/providers/remote/inference/tgi/config.py
@@ -15,10 +15,6 @@ class TGIImplConfig(BaseModel):
     url: str = Field(
         description="The URL for the TGI serving endpoint",
     )
-    api_token: Optional[SecretStr] = Field(
-        default=None,
-        description="A bearer token if your TGI endpoint is protected.",
-    )
 
     @classmethod
     def sample_run_config(cls, url: str = "${env.TGI_URL}", **kwargs):
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 985fd3606..7f8c9d8ab 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -128,6 +128,12 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         fmt: ResponseFormat = None,
     ):
         options = get_sampling_options(sampling_params)
+        # TGI does not support temperature=0 when using greedy sampling
+        # We set it to 1e-3 instead, anything lower outputs garbage from TGI
+        # We can use top_p sampling strategy to specify lower temperature
+        if abs(options["temperature"]) < 1e-10:
+            options["temperature"] = 1e-3
+
         # delete key "max_tokens" from options since its not supported by the API
         options.pop("max_tokens", None)
         if fmt:
@@ -289,7 +295,7 @@ class TGIAdapter(_HfAdapter):
     async def initialize(self, config: TGIImplConfig) -> None:
         log.info(f"Initializing TGI client with url={config.url}")
         self.client = AsyncInferenceClient(
-            model=config.url, token=config.api_token.get_secret_value()
+            model=config.url,
         )
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 485779064..f9b55b5cd 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -98,7 +98,7 @@ def agent_config(llama_stack_client):
        instructions="You are a helpful assistant",
         sampling_params={
             "strategy": {
-                "type": "greedy",
+                "type": "top_p",
                 "temperature": 1.0,
                 "top_p": 0.9,
             },
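
Note on the clamp added in tgi.py above: TGI misbehaves when asked for temperature=0, so near-zero values are bumped to 1e-3 before the request is sent. A minimal standalone sketch of that behavior follows; the only assumption is that `options` is a plain dict of the shape produced by `get_sampling_options` with a float `temperature` key (the dict literal here is hypothetical, for illustration only):

```python
# Hypothetical sampling options, shaped like the output of get_sampling_options().
options = {"temperature": 0.0, "top_p": 0.9}

# Near-zero temperatures are clamped to 1e-3; the 1e-10 epsilon comparison
# (rather than == 0) tolerates floating-point noise in the incoming value.
if abs(options["temperature"]) < 1e-10:
    options["temperature"] = 1e-3

print(options)  # {'temperature': 0.001, 'top_p': 0.9}
```

This is also why the test config switches from the "greedy" strategy to "top_p": with greedy sampling's implied temperature of 0 unusable against TGI, a top_p strategy with an explicit temperature exercises the endpoint without hitting the clamp.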