From 5c6e1e9d1e67115a6dbacf3d1ef58cf1d1a4b7f6 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 16 Jan 2025 16:46:49 -0800
Subject: [PATCH] fix tgi

---
 llama_stack/providers/remote/inference/tgi/config.py | 8 ++++----
 llama_stack/providers/remote/inference/tgi/tgi.py    | 6 +++++-
 tests/client-sdk/inference/test_inference.py         | 1 +
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/llama_stack/providers/remote/inference/tgi/config.py b/llama_stack/providers/remote/inference/tgi/config.py
index f05005b25..020ca8bf8 100644
--- a/llama_stack/providers/remote/inference/tgi/config.py
+++ b/llama_stack/providers/remote/inference/tgi/config.py
@@ -15,10 +15,10 @@ class TGIImplConfig(BaseModel):
     url: str = Field(
         description="The URL for the TGI serving endpoint",
     )
-    api_token: Optional[SecretStr] = Field(
-        default=None,
-        description="A bearer token if your TGI endpoint is protected.",
-    )
+    # api_token: Optional[SecretStr] = Field(
+    #     default=None,
+    #     description="A bearer token if your TGI endpoint is protected.",
+    # )
 
     @classmethod
     def sample_run_config(cls, url: str = "${env.TGI_URL}", **kwargs):
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 985fd3606..363bc4299 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -128,6 +128,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         fmt: ResponseFormat = None,
     ):
         options = get_sampling_options(sampling_params)
+        if options["temperature"] == 0:
+            options["temperature"] = 0.1
+
         # delete key "max_tokens" from options since its not supported by the API
         options.pop("max_tokens", None)
         if fmt:
@@ -230,6 +233,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         params = await self._get_params(request)
+        print("TGI params", params)
         r = await self.client.text_generation(**params)
 
         choice = OpenAICompatCompletionChoice(
@@ -289,7 +293,7 @@ class TGIAdapter(_HfAdapter):
     async def initialize(self, config: TGIImplConfig) -> None:
         log.info(f"Initializing TGI client with url={config.url}")
         self.client = AsyncInferenceClient(
-            model=config.url, token=config.api_token.get_secret_value()
+            model=config.url,
         )
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py
index 671a37926..175e5f1f2 100644
--- a/tests/client-sdk/inference/test_inference.py
+++ b/tests/client-sdk/inference/test_inference.py
@@ -225,6 +225,7 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
         tool_prompt_format=provider_tool_format,
         stream=False,
     )
+    print(response)
     # No content is returned for the system message since we expect the
     # response to be a tool call
     assert response.completion_message.content == ""
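
Note: a minimal standalone sketch of the option handling this patch adds to _HfAdapter._get_params, for readers skimming the diff. The get_sampling_options_stub helper is a hypothetical stand-in for llama_stack's get_sampling_options (assumed here to return a plain dict of sampling options); only the temperature guard and the max_tokens removal mirror the patch itself.

from typing import Any, Dict, Optional


def get_sampling_options_stub(
    temperature: float = 0.0,
    top_p: float = 0.95,
    max_tokens: Optional[int] = None,
) -> Dict[str, Any]:
    # Hypothetical stand-in for llama_stack's get_sampling_options().
    return {"temperature": temperature, "top_p": top_p, "max_tokens": max_tokens}


def prepare_tgi_options(sampling_params: Dict[str, Any]) -> Dict[str, Any]:
    """Mirror the option massaging the patch adds to _get_params."""
    options = get_sampling_options_stub(**sampling_params)
    # The patch bumps a temperature of 0 to 0.1, presumably because TGI
    # expects a strictly positive temperature when sampling.
    if options["temperature"] == 0:
        options["temperature"] = 0.1
    # Drop "max_tokens" since, per the diff's comment, the TGI API does not support it.
    options.pop("max_tokens", None)
    return options


if __name__ == "__main__":
    print(prepare_tgi_options({"temperature": 0.0}))                     # temperature becomes 0.1
    print(prepare_tgi_options({"temperature": 0.7, "max_tokens": 128}))  # max_tokens is dropped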