diff --git a/llama_toolchain/inference/adapters/tgi/tgi.py b/llama_toolchain/inference/adapters/tgi/tgi.py index f5b807b0f..73a72e390 100644 --- a/llama_toolchain/inference/adapters/tgi/tgi.py +++ b/llama_toolchain/inference/adapters/tgi/tgi.py @@ -235,9 +235,10 @@ class TGIAdapter(Inference): class InferenceEndpointAdapter(TGIAdapter): def __init__(self, config: TGIImplConfig) -> None: super().__init__(config) - self.config.url = self._construct_endpoint_url(config.hf_endpoint_name) + self.config.url = self._construct_endpoint_url() - def _construct_endpoint_url(self, hf_endpoint_name: str) -> str: + def _construct_endpoint_url(self) -> str: + hf_endpoint_name = self.config.hf_endpoint_name assert hf_endpoint_name.count("/") <= 1, ( "Endpoint name must be in the format of 'namespace/endpoint_name' " "or 'endpoint_name'"