forked from phoenix-oss/llama-stack-mirror
Fix tgi adapter (#796)
# What does this PR do? - Fix TGI adapter ## Test Plan <img width="851" alt="image" src="https://github.com/user-attachments/assets/0084cbc6-6713-4079-b87b-0befd9aca0b0" /> - most inference working - agent test failure due to model outputs ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
This commit is contained in:
parent
73215460ba
commit
0fefd4390a
3 changed files with 8 additions and 6 deletions
|
@ -15,10 +15,6 @@ class TGIImplConfig(BaseModel):
|
||||||
url: str = Field(
|
url: str = Field(
|
||||||
description="The URL for the TGI serving endpoint",
|
description="The URL for the TGI serving endpoint",
|
||||||
)
|
)
|
||||||
api_token: Optional[SecretStr] = Field(
|
|
||||||
default=None,
|
|
||||||
description="A bearer token if your TGI endpoint is protected.",
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, url: str = "${env.TGI_URL}", **kwargs):
|
def sample_run_config(cls, url: str = "${env.TGI_URL}", **kwargs):
|
||||||
|
|
|
@ -128,6 +128,12 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||||
fmt: ResponseFormat = None,
|
fmt: ResponseFormat = None,
|
||||||
):
|
):
|
||||||
options = get_sampling_options(sampling_params)
|
options = get_sampling_options(sampling_params)
|
||||||
|
# TGI does not support temperature=0 when using greedy sampling
|
||||||
|
# We set it to 1e-3 instead, anything lower outputs garbage from TGI
|
||||||
|
# We can use top_p sampling strategy to specify lower temperature
|
||||||
|
if abs(options["temperature"]) < 1e-10:
|
||||||
|
options["temperature"] = 1e-3
|
||||||
|
|
||||||
# delete key "max_tokens" from options since its not supported by the API
|
# delete key "max_tokens" from options since its not supported by the API
|
||||||
options.pop("max_tokens", None)
|
options.pop("max_tokens", None)
|
||||||
if fmt:
|
if fmt:
|
||||||
|
@ -289,7 +295,7 @@ class TGIAdapter(_HfAdapter):
|
||||||
async def initialize(self, config: TGIImplConfig) -> None:
|
async def initialize(self, config: TGIImplConfig) -> None:
|
||||||
log.info(f"Initializing TGI client with url={config.url}")
|
log.info(f"Initializing TGI client with url={config.url}")
|
||||||
self.client = AsyncInferenceClient(
|
self.client = AsyncInferenceClient(
|
||||||
model=config.url, token=config.api_token.get_secret_value()
|
model=config.url,
|
||||||
)
|
)
|
||||||
endpoint_info = await self.client.get_endpoint_info()
|
endpoint_info = await self.client.get_endpoint_info()
|
||||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||||
|
|
|
@ -98,7 +98,7 @@ def agent_config(llama_stack_client):
|
||||||
instructions="You are a helpful assistant",
|
instructions="You are a helpful assistant",
|
||||||
sampling_params={
|
sampling_params={
|
||||||
"strategy": {
|
"strategy": {
|
||||||
"type": "greedy",
|
"type": "top_p",
|
||||||
"temperature": 1.0,
|
"temperature": 1.0,
|
||||||
"top_p": 0.9,
|
"top_p": 0.9,
|
||||||
},
|
},
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue