This commit is contained in:
Xi Yan 2025-01-16 16:46:49 -08:00
parent e88faa91e2
commit 5c6e1e9d1e
3 changed files with 10 additions and 5 deletions

View file

@ -15,10 +15,10 @@ class TGIImplConfig(BaseModel):
url: str = Field( url: str = Field(
description="The URL for the TGI serving endpoint", description="The URL for the TGI serving endpoint",
) )
api_token: Optional[SecretStr] = Field( # api_token: Optional[SecretStr] = Field(
default=None, # default=None,
description="A bearer token if your TGI endpoint is protected.", # description="A bearer token if your TGI endpoint is protected.",
) # )
@classmethod @classmethod
def sample_run_config(cls, url: str = "${env.TGI_URL}", **kwargs): def sample_run_config(cls, url: str = "${env.TGI_URL}", **kwargs):

View file

@ -128,6 +128,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
fmt: ResponseFormat = None, fmt: ResponseFormat = None,
): ):
options = get_sampling_options(sampling_params) options = get_sampling_options(sampling_params)
if options["temperature"] == 0:
options["temperature"] = 0.1
# delete key "max_tokens" from options since its not supported by the API # delete key "max_tokens" from options since its not supported by the API
options.pop("max_tokens", None) options.pop("max_tokens", None)
if fmt: if fmt:
@ -230,6 +233,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
self, request: ChatCompletionRequest self, request: ChatCompletionRequest
) -> ChatCompletionResponse: ) -> ChatCompletionResponse:
params = await self._get_params(request) params = await self._get_params(request)
print("TGI params", params)
r = await self.client.text_generation(**params) r = await self.client.text_generation(**params)
choice = OpenAICompatCompletionChoice( choice = OpenAICompatCompletionChoice(
@ -289,7 +293,7 @@ class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None: async def initialize(self, config: TGIImplConfig) -> None:
log.info(f"Initializing TGI client with url={config.url}") log.info(f"Initializing TGI client with url={config.url}")
self.client = AsyncInferenceClient( self.client = AsyncInferenceClient(
model=config.url, token=config.api_token.get_secret_value() model=config.url,
) )
endpoint_info = await self.client.get_endpoint_info() endpoint_info = await self.client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"] self.max_tokens = endpoint_info["max_total_tokens"]

View file

@ -225,6 +225,7 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
tool_prompt_format=provider_tool_format, tool_prompt_format=provider_tool_format,
stream=False, stream=False,
) )
print(response)
# No content is returned for the system message since we expect the # No content is returned for the system message since we expect the
# response to be a tool call # response to be a tool call
assert response.completion_message.content == "" assert response.completion_message.content == ""