mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 10:42:39 +00:00
fix tgi
This commit is contained in:
parent
e88faa91e2
commit
5c6e1e9d1e
3 changed files with 10 additions and 5 deletions
|
@ -15,10 +15,10 @@ class TGIImplConfig(BaseModel):
|
||||||
url: str = Field(
|
url: str = Field(
|
||||||
description="The URL for the TGI serving endpoint",
|
description="The URL for the TGI serving endpoint",
|
||||||
)
|
)
|
||||||
api_token: Optional[SecretStr] = Field(
|
# api_token: Optional[SecretStr] = Field(
|
||||||
default=None,
|
# default=None,
|
||||||
description="A bearer token if your TGI endpoint is protected.",
|
# description="A bearer token if your TGI endpoint is protected.",
|
||||||
)
|
# )
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, url: str = "${env.TGI_URL}", **kwargs):
|
def sample_run_config(cls, url: str = "${env.TGI_URL}", **kwargs):
|
||||||
|
|
|
@ -128,6 +128,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||||
fmt: ResponseFormat = None,
|
fmt: ResponseFormat = None,
|
||||||
):
|
):
|
||||||
options = get_sampling_options(sampling_params)
|
options = get_sampling_options(sampling_params)
|
||||||
|
if options["temperature"] == 0:
|
||||||
|
options["temperature"] = 0.1
|
||||||
|
|
||||||
# delete key "max_tokens" from options since its not supported by the API
|
# delete key "max_tokens" from options since its not supported by the API
|
||||||
options.pop("max_tokens", None)
|
options.pop("max_tokens", None)
|
||||||
if fmt:
|
if fmt:
|
||||||
|
@ -230,6 +233,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||||
self, request: ChatCompletionRequest
|
self, request: ChatCompletionRequest
|
||||||
) -> ChatCompletionResponse:
|
) -> ChatCompletionResponse:
|
||||||
params = await self._get_params(request)
|
params = await self._get_params(request)
|
||||||
|
print("TGI params", params)
|
||||||
r = await self.client.text_generation(**params)
|
r = await self.client.text_generation(**params)
|
||||||
|
|
||||||
choice = OpenAICompatCompletionChoice(
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
@ -289,7 +293,7 @@ class TGIAdapter(_HfAdapter):
|
||||||
async def initialize(self, config: TGIImplConfig) -> None:
|
async def initialize(self, config: TGIImplConfig) -> None:
|
||||||
log.info(f"Initializing TGI client with url={config.url}")
|
log.info(f"Initializing TGI client with url={config.url}")
|
||||||
self.client = AsyncInferenceClient(
|
self.client = AsyncInferenceClient(
|
||||||
model=config.url, token=config.api_token.get_secret_value()
|
model=config.url,
|
||||||
)
|
)
|
||||||
endpoint_info = await self.client.get_endpoint_info()
|
endpoint_info = await self.client.get_endpoint_info()
|
||||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||||
|
|
|
@ -225,6 +225,7 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
|
||||||
tool_prompt_format=provider_tool_format,
|
tool_prompt_format=provider_tool_format,
|
||||||
stream=False,
|
stream=False,
|
||||||
)
|
)
|
||||||
|
print(response)
|
||||||
# No content is returned for the system message since we expect the
|
# No content is returned for the system message since we expect the
|
||||||
# response to be a tool call
|
# response to be a tool call
|
||||||
assert response.completion_message.content == ""
|
assert response.completion_message.content == ""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue