diff --git a/llama_stack/distribution/templates/local-hf-endpoint-build.yaml b/llama_stack/distribution/templates/local-hf-endpoint-build.yaml
new file mode 100644
index 000000000..e5c4ae8cc
--- /dev/null
+++ b/llama_stack/distribution/templates/local-hf-endpoint-build.yaml
@@ -0,0 +1,10 @@
+name: local-hf-endpoint
+distribution_spec:
+  description: "Like local, but use Hugging Face Inference Endpoints for running LLM inference.\nSee https://hf.co/docs/api-endpoints."
+  providers:
+    inference: remote::hf::endpoint
+    memory: meta-reference
+    safety: meta-reference
+    agents: meta-reference
+    telemetry: meta-reference
+image_type: conda
diff --git a/llama_stack/distribution/templates/local-hf-serverless-build.yaml b/llama_stack/distribution/templates/local-hf-serverless-build.yaml
new file mode 100644
index 000000000..752390b40
--- /dev/null
+++ b/llama_stack/distribution/templates/local-hf-serverless-build.yaml
@@ -0,0 +1,10 @@
+name: local-hf-serverless
+distribution_spec:
+  description: "Like local, but use Hugging Face Inference API (serverless) for running LLM inference.\nSee https://hf.co/docs/api-inference."
+  providers:
+    inference: remote::hf::serverless
+    memory: meta-reference
+    safety: meta-reference
+    agents: meta-reference
+    telemetry: meta-reference
+image_type: conda
diff --git a/llama_stack/distribution/templates/local-tgi-build.yaml b/llama_stack/distribution/templates/local-tgi-build.yaml
index e764aef8c..d4752539d 100644
--- a/llama_stack/distribution/templates/local-tgi-build.yaml
+++ b/llama_stack/distribution/templates/local-tgi-build.yaml
@@ -1,6 +1,6 @@
 name: local-tgi
 distribution_spec:
-  description: Use TGI (local or with Hugging Face Inference Endpoints for running LLM inference. When using HF Inference Endpoints, you must provide the name of the endpoint).
+  description: Like local, but use a TGI server for running LLM inference.
   providers:
     inference: remote::tgi
     memory: meta-reference
diff --git a/llama_stack/providers/adapters/inference/tgi/__init__.py b/llama_stack/providers/adapters/inference/tgi/__init__.py
index 743807836..451650323 100644
--- a/llama_stack/providers/adapters/inference/tgi/__init__.py
+++ b/llama_stack/providers/adapters/inference/tgi/__init__.py
@@ -4,21 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .config import TGIImplConfig
-from .tgi import InferenceEndpointAdapter, TGIAdapter
+from typing import Union
+
+from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
+from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
 
 
-async def get_adapter_impl(config: TGIImplConfig, _deps):
-    assert isinstance(config, TGIImplConfig), f"Unexpected config type: {type(config)}"
-
-    if config.url is not None:
-        impl = TGIAdapter(config)
-    elif config.is_inference_endpoint():
-        impl = InferenceEndpointAdapter(config)
+async def get_adapter_impl(
+    config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
+    _deps,
+):
+    if isinstance(config, TGIImplConfig):
+        impl = TGIAdapter()
+    elif isinstance(config, InferenceAPIImplConfig):
+        impl = InferenceAPIAdapter()
+    elif isinstance(config, InferenceEndpointImplConfig):
+        impl = InferenceEndpointAdapter()
     else:
         raise ValueError(
-            "Invalid configuration. Specify either an URL or HF Inference Endpoint details (namespace and endpoint name)."
+            f"Invalid configuration. Expected 'TGIImplConfig', 'InferenceAPIImplConfig' or 'InferenceEndpointImplConfig'. Got {type(config)}."
         )
 
-    await impl.initialize()
+    await impl.initialize(config)
     return impl
diff --git a/llama_stack/providers/adapters/inference/tgi/config.py b/llama_stack/providers/adapters/inference/tgi/config.py
index a0135dfdd..233205066 100644
--- a/llama_stack/providers/adapters/inference/tgi/config.py
+++ b/llama_stack/providers/adapters/inference/tgi/config.py
@@ -12,18 +12,32 @@ from pydantic import BaseModel, Field
 
 @json_schema_type
 class TGIImplConfig(BaseModel):
-    url: Optional[str] = Field(
-        default=None,
-        description="The URL for the local TGI endpoint (e.g., http://localhost:8080)",
+    url: str = Field(
+        description="The URL for the TGI endpoint (e.g. 'http://localhost:8080')",
     )
     api_token: Optional[str] = Field(
         default=None,
-        description="The HF token for Hugging Face Inference Endpoints (will default to locally saved token if not provided)",
-    )
-    hf_endpoint_name: Optional[str] = Field(
-        default=None,
-        description="The name of the Hugging Face Inference Endpoint : can be either in the format of '{namespace}/{endpoint_name}' (namespace can be the username or organization name) or just '{endpoint_name}' if logged into the same account as the namespace",
+        description="A bearer token if your TGI endpoint is protected.",
     )
 
-    def is_inference_endpoint(self) -> bool:
-        return self.hf_endpoint_name is not None
+
+@json_schema_type
+class InferenceEndpointImplConfig(BaseModel):
+    endpoint_name: str = Field(
+        description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.",
+    )
+    api_token: Optional[str] = Field(
+        default=None,
+        description="Your Hugging Face user access token (will default to locally saved token if not provided)",
+    )
+
+
+@json_schema_type
+class InferenceAPIImplConfig(BaseModel):
+    model_id: str = Field(
+        description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
+    )
+    api_token: Optional[str] = Field(
+        default=None,
+        description="Your Hugging Face user access token (will default to locally saved token if not provided)",
+    )
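Note for reviewers: the three config classes above map one-to-one onto the ways a model can be served. As a quick illustration (a minimal sketch based only on the fields defined in this diff; the URL, model ID, and endpoint name are the placeholder values from the field descriptions), they would be instantiated like this:

from llama_stack.providers.adapters.inference.tgi import (
    InferenceAPIImplConfig,
    InferenceEndpointImplConfig,
    TGIImplConfig,
)

# Self-hosted TGI server: only the URL is required; api_token is for protected endpoints.
tgi_config = TGIImplConfig(url="http://localhost:8080")

# Serverless Inference API: identified by a model ID on the Hugging Face Hub.
serverless_config = InferenceAPIImplConfig(
    model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
    api_token=None,  # falls back to the locally saved HF token
)

# Dedicated Inference Endpoint: identified by '{namespace}/{endpoint_name}'.
endpoint_config = InferenceEndpointImplConfig(
    endpoint_name="my-cool-org/meta-llama-3-1-8b-instruct-rce",
    api_token=None,  # falls back to the locally saved HF token
)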
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index 4919ff86a..66f57442f 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -5,54 +5,33 @@
 # the root directory of this source tree.
 
-from typing import Any, AsyncGenerator, Dict
+import logging
+from typing import AsyncGenerator
 
-import requests
-
-from huggingface_hub import HfApi, InferenceClient
+from huggingface_hub import AsyncInferenceClient, HfApi
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
+
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
 
-from .config import TGIImplConfig
+from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
+
+logger = logging.getLogger(__name__)
 
 
-class TGIAdapter(Inference):
-    def __init__(self, config: TGIImplConfig) -> None:
-        self.config = config
+class _HfAdapter(Inference):
+    client: AsyncInferenceClient
+    max_tokens: int
+    model_id: str
+
+    def __init__(self) -> None:
         self.tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(self.tokenizer)
 
-    @property
-    def client(self) -> InferenceClient:
-        return InferenceClient(model=self.config.url, token=self.config.api_token)
-
-    def _get_endpoint_info(self) -> Dict[str, Any]:
-        return {
-            **self.client.get_endpoint_info(),
-            "inference_url": self.config.url,
-        }
-
-    async def initialize(self) -> None:
-        try:
-            info = self._get_endpoint_info()
-            if "model_id" not in info:
-                raise RuntimeError("Missing model_id in model info")
-            if "max_total_tokens" not in info:
-                raise RuntimeError("Missing max_total_tokens in model info")
-            self.max_tokens = info["max_total_tokens"]
-
-            self.inference_url = info["inference_url"]
-        except Exception as e:
-            import traceback
-
-            traceback.print_exc()
-            raise RuntimeError(f"Error initializing TGIAdapter: {e}") from e
-
     async def shutdown(self) -> None:
         pass
@@ -111,7 +90,7 @@ class TGIAdapter(Inference):
         options = self.get_chat_options(request)
 
         if not request.stream:
-            response = self.client.text_generation(
+            response = await self.client.text_generation(
                 prompt=prompt,
                 stream=False,
                 details=True,
@@ -147,7 +126,7 @@ class TGIAdapter(Inference):
         stop_reason = None
         tokens = []
 
-        for response in self.client.text_generation(
+        async for response in await self.client.text_generation(
            prompt=prompt,
            stream=True,
            details=True,
@@ -239,46 +218,36 @@ class TGIAdapter(Inference):
         )
 
 
-class InferenceEndpointAdapter(TGIAdapter):
-    def __init__(self, config: TGIImplConfig) -> None:
-        super().__init__(config)
-        self.config.url = self._construct_endpoint_url()
+class TGIAdapter(_HfAdapter):
+    async def initialize(self, config: TGIImplConfig) -> None:
+        self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
+        endpoint_info = await self.client.get_endpoint_info()
+        self.max_tokens = endpoint_info["max_total_tokens"]
+        self.model_id = endpoint_info["model_id"]
 
-    def _construct_endpoint_url(self) -> str:
-        hf_endpoint_name = self.config.hf_endpoint_name
-        assert hf_endpoint_name.count("/") <= 1, (
-            "Endpoint name must be in the format of 'namespace/endpoint_name' "
-            "or 'endpoint_name'"
+
+class InferenceAPIAdapter(_HfAdapter):
+    async def initialize(self, config: InferenceAPIImplConfig) -> None:
+        self.client = AsyncInferenceClient(
+            model=config.model_id, token=config.api_token
         )
-        if "/" not in hf_endpoint_name:
-            hf_namespace: str = self.get_namespace()
-            endpoint_path = f"{hf_namespace}/{hf_endpoint_name}"
-        else:
-            endpoint_path = hf_endpoint_name
-        return f"https://api.endpoints.huggingface.cloud/v2/endpoint/{endpoint_path}"
+        endpoint_info = await self.client.get_endpoint_info()
+        self.max_tokens = endpoint_info["max_total_tokens"]
+        self.model_id = endpoint_info["model_id"]
 
-    def get_namespace(self) -> str:
-        return HfApi().whoami()["name"]
 
-    @property
-    def client(self) -> InferenceClient:
-        return InferenceClient(model=self.inference_url, token=self.config.api_token)
+class InferenceEndpointAdapter(_HfAdapter):
+    async def initialize(self, config: InferenceEndpointImplConfig) -> None:
+        # Get the inference endpoint details
+        api = HfApi(token=config.api_token)
+        endpoint = api.get_inference_endpoint(config.endpoint_name)
 
-    def _get_endpoint_info(self) -> Dict[str, Any]:
-        headers = {
-            "accept": "application/json",
-            "authorization": f"Bearer {self.config.api_token}",
-        }
-        response = requests.get(self.config.url, headers=headers)
-        response.raise_for_status()
-        endpoint_info = response.json()
-        return {
-            "inference_url": endpoint_info["status"]["url"],
-            "model_id": endpoint_info["model"]["repository"],
-            "max_total_tokens": int(
-                endpoint_info["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
-            ),
-        }
+        # Wait for the endpoint to be ready (if not already)
+        endpoint.wait(timeout=60)
 
-    async def initialize(self) -> None:
-        await super().initialize()
+        # Initialize the adapter
+        self.client = endpoint.async_client
+        self.model_id = endpoint.repository
+        self.max_tokens = int(
+            endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
+        )
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index db0d95527..31b3e2c2d 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -48,11 +48,29 @@ def available_providers() -> List[ProviderSpec]:
             api=Api.inference,
             adapter=AdapterSpec(
                 adapter_id="tgi",
-                pip_packages=["huggingface_hub"],
+                pip_packages=["huggingface_hub", "aiohttp"],
                 module="llama_stack.providers.adapters.inference.tgi",
                 config_class="llama_stack.providers.adapters.inference.tgi.TGIImplConfig",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_id="hf::serverless",
+                pip_packages=["huggingface_hub", "aiohttp"],
+                module="llama_stack.providers.adapters.inference.tgi",
+                config_class="llama_stack.providers.adapters.inference.tgi.InferenceAPIImplConfig",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_id="hf::endpoint",
+                pip_packages=["huggingface_hub", "aiohttp"],
+                module="llama_stack.providers.adapters.inference.tgi",
+                config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig",
+            ),
+        ),
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
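Finally, a rough usage sketch of the new dispatch in get_adapter_impl (not part of the diff; it assumes a TGI server is actually reachable at the placeholder URL, since initialize() now fetches model_id and max_total_tokens from the endpoint):

import asyncio

from llama_stack.providers.adapters.inference.tgi import (
    InferenceAPIImplConfig,
    TGIAdapter,
    TGIImplConfig,
    get_adapter_impl,
)


async def main() -> None:
    # A TGIImplConfig resolves to TGIAdapter; initialize() populates model_id and
    # max_tokens from the endpoint info reported by AsyncInferenceClient.
    impl = await get_adapter_impl(TGIImplConfig(url="http://localhost:8080"), {})
    assert isinstance(impl, TGIAdapter)
    print(impl.model_id, impl.max_tokens)

    # An InferenceAPIImplConfig resolves to InferenceAPIAdapter instead, backed by
    # the serverless Inference API and the locally saved HF token by default.
    await get_adapter_impl(
        InferenceAPIImplConfig(model_id="meta-llama/Meta-Llama-3.1-70B-Instruct"), {}
    )


asyncio.run(main())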