diff --git a/docs/cli_reference.md b/docs/cli_reference.md
index aca750224..7457aa45e 100644
--- a/docs/cli_reference.md
+++ b/docs/cli_reference.md
@@ -288,7 +288,7 @@ llama stack list-distributions
 |--------------------------------|---------------------------------------|-------------------------------------------------------------------------------------------|
 | local-plus-tgi-inference       | {                                     | Use TGI (local or with [Hugging Face Inference Endpoints](https://huggingface.co/          |
 |                                | "inference": "remote::tgi",           | inference-endpoints/dedicated)) for running LLM inference. When using HF Inference         |
-|                                | "safety": "meta-reference",           | Endpoints, please provide hf_namespace (username or organization name) and endpoint name.  |
+|                                | "safety": "meta-reference",           | Endpoints, you must provide the name of the endpoint.                                      |
 |                                | "agentic_system": "meta-reference",   |                                                                                             |
 |                                | "memory": "meta-reference-faiss"      |                                                                                             |
 |                                | }                                     |                                                                                             |
diff --git a/llama_toolchain/inference/adapters/tgi/config.py b/llama_toolchain/inference/adapters/tgi/config.py
index 93accd6e1..cb1ec8400 100644
--- a/llama_toolchain/inference/adapters/tgi/config.py
+++ b/llama_toolchain/inference/adapters/tgi/config.py
@@ -6,6 +6,7 @@
 
 from typing import Optional
 
+from huggingface_hub import HfApi
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 
@@ -20,17 +21,16 @@ class TGIImplConfig(BaseModel):
         default=None,
         description="The HF token for Hugging Face Inference Endpoints (will default to locally saved token if not provided)",
     )
-    hf_namespace: Optional[str] = Field(
-        default=None,
-        description="The username/organization name for the Hugging Face Inference Endpoint",
-    )
     hf_endpoint_name: Optional[str] = Field(
         default=None,
-        description="The name of the Hugging Face Inference Endpoint",
+        description="The name of the Hugging Face Inference Endpoint: can be either '{namespace}/{endpoint_name}' (namespace can be a username or organization name) or just '{endpoint_name}' if logged into the same account as the namespace",
     )
 
     def is_inference_endpoint(self) -> bool:
-        return self.hf_namespace is not None and self.hf_endpoint_name is not None
+        return self.hf_endpoint_name is not None
+
+    def get_namespace(self) -> str:
+        return HfApi().whoami()["name"]
 
     def is_local_tgi(self) -> bool:
         return self.url is not None and self.url.startswith("http://localhost")
diff --git a/llama_toolchain/inference/adapters/tgi/tgi.py b/llama_toolchain/inference/adapters/tgi/tgi.py
index ba188a539..f5b807b0f 100644
--- a/llama_toolchain/inference/adapters/tgi/tgi.py
+++ b/llama_toolchain/inference/adapters/tgi/tgi.py
@@ -235,7 +235,19 @@ class TGIAdapter(Inference):
 class InferenceEndpointAdapter(TGIAdapter):
     def __init__(self, config: TGIImplConfig) -> None:
         super().__init__(config)
-        self.config.url = f"https://api.endpoints.huggingface.cloud/v2/endpoint/{config.hf_namespace}/{config.hf_endpoint_name}"
+        self.config.url = self._construct_endpoint_url(config.hf_endpoint_name)
+
+    def _construct_endpoint_url(self, hf_endpoint_name: str) -> str:
+        assert hf_endpoint_name.count("/") <= 1, (
+            "Endpoint name must be in the format of 'namespace/endpoint_name' "
+            "or 'endpoint_name'"
+        )
+        if "/" not in hf_endpoint_name:
+            hf_namespace: str = self.config.get_namespace()
+            endpoint_path = f"{hf_namespace}/{hf_endpoint_name}"
+        else:
+            endpoint_path = hf_endpoint_name
+        return f"https://api.endpoints.huggingface.cloud/v2/endpoint/{endpoint_path}"
 
     @property
     def client(self) -> InferenceClient:
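
For reviewers, here is how the new name resolution behaves end to end. This is a standalone sketch of the same logic as `_construct_endpoint_url`, not the adapter itself; `resolve_endpoint_url`, `my-org`, and `my-tgi-endpoint` are hypothetical names, and the bare-name branch assumes a locally saved Hugging Face token so that `HfApi().whoami()` can return the account name:

```python
from huggingface_hub import HfApi

ENDPOINT_API_BASE = "https://api.endpoints.huggingface.cloud/v2/endpoint"


def resolve_endpoint_url(hf_endpoint_name: str) -> str:
    """Qualify a bare endpoint name with the caller's namespace, then build the URL."""
    assert hf_endpoint_name.count("/") <= 1, (
        "Endpoint name must be 'namespace/endpoint_name' or 'endpoint_name'"
    )
    if "/" not in hf_endpoint_name:
        # Same lookup the config's get_namespace() performs; needs a saved HF token.
        namespace = HfApi().whoami()["name"]
        hf_endpoint_name = f"{namespace}/{hf_endpoint_name}"
    return f"{ENDPOINT_API_BASE}/{hf_endpoint_name}"


# Fully qualified names are used as-is (no token lookup needed):
print(resolve_endpoint_url("my-org/my-tgi-endpoint"))
# -> https://api.endpoints.huggingface.cloud/v2/endpoint/my-org/my-tgi-endpoint
```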
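
And a quick sketch of how the slimmed-down config is exercised, assuming `TGIImplConfig.url` defaults to `None` (consistent with the `is_local_tgi` check in the patch); the endpoint name and URL are again hypothetical:

```python
from llama_toolchain.inference.adapters.tgi.config import TGIImplConfig

# An endpoint config no longer needs a separate hf_namespace field;
# either accepted name format flips is_inference_endpoint() to True.
cfg = TGIImplConfig(hf_endpoint_name="my-tgi-endpoint")
assert cfg.is_inference_endpoint()
# cfg.get_namespace() would ask the Hub who is logged in and return that
# account name (e.g. "my-org") to qualify the bare endpoint name.

# A local TGI deployment is still configured via url alone.
local = TGIImplConfig(url="http://localhost:8080")
assert local.is_local_tgi() and not local.is_inference_endpoint()
```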