From 615ed4bfbcd3b1218fa4558d8161d3b7dd6e15f9 Mon Sep 17 00:00:00 2001 From: Lucain Date: Wed, 25 Sep 2024 23:08:31 +0200 Subject: [PATCH 1/4] Make TGI adapter compatible with HF Inference API (#97) --- .../templates/local-hf-endpoint-build.yaml | 10 ++ .../templates/local-hf-serverless-build.yaml | 10 ++ .../templates/local-tgi-build.yaml | 2 +- .../adapters/inference/tgi/__init__.py | 27 ++-- .../adapters/inference/tgi/config.py | 34 ++++-- .../providers/adapters/inference/tgi/tgi.py | 115 +++++++----------- llama_stack/providers/registry/inference.py | 20 ++- 7 files changed, 122 insertions(+), 96 deletions(-) create mode 100644 llama_stack/distribution/templates/local-hf-endpoint-build.yaml create mode 100644 llama_stack/distribution/templates/local-hf-serverless-build.yaml diff --git a/llama_stack/distribution/templates/local-hf-endpoint-build.yaml b/llama_stack/distribution/templates/local-hf-endpoint-build.yaml new file mode 100644 index 000000000..e5c4ae8cc --- /dev/null +++ b/llama_stack/distribution/templates/local-hf-endpoint-build.yaml @@ -0,0 +1,10 @@ +name: local-hf-endpoint +distribution_spec: + description: "Like local, but use Hugging Face Inference Endpoints for running LLM inference.\nSee https://hf.co/docs/api-endpoints." + providers: + inference: remote::hf::endpoint + memory: meta-reference + safety: meta-reference + agents: meta-reference + telemetry: meta-reference +image_type: conda diff --git a/llama_stack/distribution/templates/local-hf-serverless-build.yaml b/llama_stack/distribution/templates/local-hf-serverless-build.yaml new file mode 100644 index 000000000..752390b40 --- /dev/null +++ b/llama_stack/distribution/templates/local-hf-serverless-build.yaml @@ -0,0 +1,10 @@ +name: local-hf-serverless +distribution_spec: + description: "Like local, but use Hugging Face Inference API (serverless) for running LLM inference.\nSee https://hf.co/docs/api-inference." + providers: + inference: remote::hf::serverless + memory: meta-reference + safety: meta-reference + agents: meta-reference + telemetry: meta-reference +image_type: conda diff --git a/llama_stack/distribution/templates/local-tgi-build.yaml b/llama_stack/distribution/templates/local-tgi-build.yaml index e764aef8c..d4752539d 100644 --- a/llama_stack/distribution/templates/local-tgi-build.yaml +++ b/llama_stack/distribution/templates/local-tgi-build.yaml @@ -1,6 +1,6 @@ name: local-tgi distribution_spec: - description: Use TGI (local or with Hugging Face Inference Endpoints for running LLM inference. When using HF Inference Endpoints, you must provide the name of the endpoint). + description: Like local, but use a TGI server for running LLM inference. providers: inference: remote::tgi memory: meta-reference diff --git a/llama_stack/providers/adapters/inference/tgi/__init__.py b/llama_stack/providers/adapters/inference/tgi/__init__.py index 743807836..451650323 100644 --- a/llama_stack/providers/adapters/inference/tgi/__init__.py +++ b/llama_stack/providers/adapters/inference/tgi/__init__.py @@ -4,21 +4,26 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .config import TGIImplConfig -from .tgi import InferenceEndpointAdapter, TGIAdapter +from typing import Union + +from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig +from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter -async def get_adapter_impl(config: TGIImplConfig, _deps): - assert isinstance(config, TGIImplConfig), f"Unexpected config type: {type(config)}" - - if config.url is not None: - impl = TGIAdapter(config) - elif config.is_inference_endpoint(): - impl = InferenceEndpointAdapter(config) +async def get_adapter_impl( + config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig], + _deps, +): + if isinstance(config, TGIImplConfig): + impl = TGIAdapter() + elif isinstance(config, InferenceAPIImplConfig): + impl = InferenceAPIAdapter() + elif isinstance(config, InferenceEndpointImplConfig): + impl = InferenceEndpointAdapter() else: raise ValueError( - "Invalid configuration. Specify either an URL or HF Inference Endpoint details (namespace and endpoint name)." + f"Invalid configuration. Expected 'TGIAdapter', 'InferenceAPIImplConfig' or 'InferenceEndpointImplConfig'. Got {type(config)}." ) - await impl.initialize() + await impl.initialize(config) return impl diff --git a/llama_stack/providers/adapters/inference/tgi/config.py b/llama_stack/providers/adapters/inference/tgi/config.py index a0135dfdd..233205066 100644 --- a/llama_stack/providers/adapters/inference/tgi/config.py +++ b/llama_stack/providers/adapters/inference/tgi/config.py @@ -12,18 +12,32 @@ from pydantic import BaseModel, Field @json_schema_type class TGIImplConfig(BaseModel): - url: Optional[str] = Field( - default=None, - description="The URL for the local TGI endpoint (e.g., http://localhost:8080)", + url: str = Field( + description="The URL for the TGI endpoint (e.g. 'http://localhost:8080')", ) api_token: Optional[str] = Field( default=None, - description="The HF token for Hugging Face Inference Endpoints (will default to locally saved token if not provided)", - ) - hf_endpoint_name: Optional[str] = Field( - default=None, - description="The name of the Hugging Face Inference Endpoint : can be either in the format of '{namespace}/{endpoint_name}' (namespace can be the username or organization name) or just '{endpoint_name}' if logged into the same account as the namespace", + description="A bearer token if your TGI endpoint is protected.", ) - def is_inference_endpoint(self) -> bool: - return self.hf_endpoint_name is not None + +@json_schema_type +class InferenceEndpointImplConfig(BaseModel): + endpoint_name: str = Field( + description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.", + ) + api_token: Optional[str] = Field( + default=None, + description="Your Hugging Face user access token (will default to locally saved token if not provided)", + ) + + +@json_schema_type +class InferenceAPIImplConfig(BaseModel): + model_id: str = Field( + description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')", + ) + api_token: Optional[str] = Field( + default=None, + description="Your Hugging Face user access token (will default to locally saved token if not provided)", + ) diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py index 4919ff86a..66f57442f 100644 --- a/llama_stack/providers/adapters/inference/tgi/tgi.py +++ b/llama_stack/providers/adapters/inference/tgi/tgi.py @@ -5,54 +5,33 @@ # the root directory of this source tree. -from typing import Any, AsyncGenerator, Dict +import logging +from typing import AsyncGenerator -import requests - -from huggingface_hub import HfApi, InferenceClient +from huggingface_hub import AsyncInferenceClient, HfApi from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import StopReason from llama_models.llama3.api.tokenizer import Tokenizer + from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.utils.inference.augment_messages import ( augment_messages_for_tools, ) -from .config import TGIImplConfig +from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig + +logger = logging.getLogger(__name__) -class TGIAdapter(Inference): - def __init__(self, config: TGIImplConfig) -> None: - self.config = config +class _HfAdapter(Inference): + client: AsyncInferenceClient + max_tokens: int + model_id: str + + def __init__(self) -> None: self.tokenizer = Tokenizer.get_instance() self.formatter = ChatFormat(self.tokenizer) - @property - def client(self) -> InferenceClient: - return InferenceClient(model=self.config.url, token=self.config.api_token) - - def _get_endpoint_info(self) -> Dict[str, Any]: - return { - **self.client.get_endpoint_info(), - "inference_url": self.config.url, - } - - async def initialize(self) -> None: - try: - info = self._get_endpoint_info() - if "model_id" not in info: - raise RuntimeError("Missing model_id in model info") - if "max_total_tokens" not in info: - raise RuntimeError("Missing max_total_tokens in model info") - self.max_tokens = info["max_total_tokens"] - - self.inference_url = info["inference_url"] - except Exception as e: - import traceback - - traceback.print_exc() - raise RuntimeError(f"Error initializing TGIAdapter: {e}") from e - async def shutdown(self) -> None: pass @@ -111,7 +90,7 @@ class TGIAdapter(Inference): options = self.get_chat_options(request) if not request.stream: - response = self.client.text_generation( + response = await self.client.text_generation( prompt=prompt, stream=False, details=True, @@ -147,7 +126,7 @@ class TGIAdapter(Inference): stop_reason = None tokens = [] - for response in self.client.text_generation( + async for response in await self.client.text_generation( prompt=prompt, stream=True, details=True, @@ -239,46 +218,36 @@ class TGIAdapter(Inference): ) -class InferenceEndpointAdapter(TGIAdapter): - def __init__(self, config: TGIImplConfig) -> None: - super().__init__(config) - self.config.url = self._construct_endpoint_url() +class TGIAdapter(_HfAdapter): + async def initialize(self, config: TGIImplConfig) -> None: + self.client = AsyncInferenceClient(model=config.url, token=config.api_token) + endpoint_info = await self.client.get_endpoint_info() + self.max_tokens = endpoint_info["max_total_tokens"] + self.model_id = endpoint_info["model_id"] - def _construct_endpoint_url(self) -> str: - hf_endpoint_name = self.config.hf_endpoint_name - assert hf_endpoint_name.count("/") <= 1, ( - "Endpoint name must be in the format of 'namespace/endpoint_name' " - "or 'endpoint_name'" + +class InferenceAPIAdapter(_HfAdapter): + async def initialize(self, config: InferenceAPIImplConfig) -> None: + self.client = AsyncInferenceClient( + model=config.model_id, token=config.api_token ) - if "/" not in hf_endpoint_name: - hf_namespace: str = self.get_namespace() - endpoint_path = f"{hf_namespace}/{hf_endpoint_name}" - else: - endpoint_path = hf_endpoint_name - return f"https://api.endpoints.huggingface.cloud/v2/endpoint/{endpoint_path}" + endpoint_info = await self.client.get_endpoint_info() + self.max_tokens = endpoint_info["max_total_tokens"] + self.model_id = endpoint_info["model_id"] - def get_namespace(self) -> str: - return HfApi().whoami()["name"] - @property - def client(self) -> InferenceClient: - return InferenceClient(model=self.inference_url, token=self.config.api_token) +class InferenceEndpointAdapter(_HfAdapter): + async def initialize(self, config: InferenceEndpointImplConfig) -> None: + # Get the inference endpoint details + api = HfApi(token=config.api_token) + endpoint = api.get_inference_endpoint(config.endpoint_name) - def _get_endpoint_info(self) -> Dict[str, Any]: - headers = { - "accept": "application/json", - "authorization": f"Bearer {self.config.api_token}", - } - response = requests.get(self.config.url, headers=headers) - response.raise_for_status() - endpoint_info = response.json() - return { - "inference_url": endpoint_info["status"]["url"], - "model_id": endpoint_info["model"]["repository"], - "max_total_tokens": int( - endpoint_info["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"] - ), - } + # Wait for the endpoint to be ready (if not already) + endpoint.wait(timeout=60) - async def initialize(self) -> None: - await super().initialize() + # Initialize the adapter + self.client = endpoint.async_client + self.model_id = endpoint.repository + self.max_tokens = int( + endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"] + ) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index db0d95527..31b3e2c2d 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -48,11 +48,29 @@ def available_providers() -> List[ProviderSpec]: api=Api.inference, adapter=AdapterSpec( adapter_id="tgi", - pip_packages=["huggingface_hub"], + pip_packages=["huggingface_hub", "aiohttp"], module="llama_stack.providers.adapters.inference.tgi", config_class="llama_stack.providers.adapters.inference.tgi.TGIImplConfig", ), ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_id="hf::serverless", + pip_packages=["huggingface_hub", "aiohttp"], + module="llama_stack.providers.adapters.inference.tgi", + config_class="llama_stack.providers.adapters.inference.tgi.InferenceAPIImplConfig", + ), + ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_id="hf::endpoint", + pip_packages=["huggingface_hub", "aiohttp"], + module="llama_stack.providers.adapters.inference.tgi", + config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig", + ), + ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( From 37be3fb1844ce4f8a8879be420bf812e9358a503 Mon Sep 17 00:00:00 2001 From: machina-source <58921460+machina-source@users.noreply.github.com> Date: Wed, 25 Sep 2024 16:18:46 -0500 Subject: [PATCH 2/4] Fix links & format (#104) Fix broken examples link to llama-stack-apps repo Remove extra space in README.md --- README.md | 2 +- docs/cli_reference.md | 2 +- docs/getting_started.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 90665b480..7ac5abe0d 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ This repository contains the Llama Stack API specifications as well as API Providers and Llama Stack Distributions. -The Llama Stack defines and standardizes the building blocks needed to bring generative AI applications to market. These blocks span the entire development lifecycle: from model training and fine-tuning, through product evaluation, to building and running AI agents in production. Beyond definition, we are building providers for the Llama Stack APIs. These we're developing open-source versions and partnering with providers , ensuring developers can assemble AI solutions using consistent, interlocking pieces across platforms. The ultimate goal is to accelerate innovation in the AI space. +The Llama Stack defines and standardizes the building blocks needed to bring generative AI applications to market. These blocks span the entire development lifecycle: from model training and fine-tuning, through product evaluation, to building and running AI agents in production. Beyond definition, we are building providers for the Llama Stack APIs. These we're developing open-source versions and partnering with providers, ensuring developers can assemble AI solutions using consistent, interlocking pieces across platforms. The ultimate goal is to accelerate innovation in the AI space. The Stack APIs are rapidly improving, but still very much work in progress and we invite feedback as well as direct contributions. diff --git a/docs/cli_reference.md b/docs/cli_reference.md index 2ebdadd4f..1c62188ef 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -483,4 +483,4 @@ Similarly you can test safety (if you configured llama-guard and/or prompt-guard python -m llama_stack.apis.safety.client localhost 5000 ``` -You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo. +You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repo. diff --git a/docs/getting_started.md b/docs/getting_started.md index 5e2f21eac..83f08cfa6 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -433,4 +433,4 @@ Similarly you can test safety (if you configured llama-guard and/or prompt-guard python -m llama_stack.apis.safety.client localhost 5000 ``` -You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps) repo. +You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repo. From ca7602a64289d272c20f4e702767caecb1fcafcf Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 25 Sep 2024 15:11:51 -0700 Subject: [PATCH 3/4] fix #100 --- llama_stack/cli/stack/build.py | 4 ++-- llama_stack/distribution/build_container.sh | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index 2321c8f2f..132aef7e5 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -95,9 +95,9 @@ class StackBuild(Subcommand): # save build.yaml spec for building same distribution again if build_config.image_type == ImageType.docker.value: # docker needs build file to be in the llama-stack repo dir to be able to copy over to the image - llama_stack_path = Path(os.path.relpath(__file__)).parent.parent.parent + llama_stack_path = Path(os.path.abspath(__file__)).parent.parent.parent.parent build_dir = ( - llama_stack_path / "configs/distributions" / build_config.image_type + llama_stack_path / "tmp/configs/" ) else: build_dir = ( diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/distribution/build_container.sh index 3efef6c97..fec1e394f 100755 --- a/llama_stack/distribution/build_container.sh +++ b/llama_stack/distribution/build_container.sh @@ -103,7 +103,7 @@ add_to_docker < Date: Wed, 25 Sep 2024 17:29:17 -0700 Subject: [PATCH 4/4] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7ac5abe0d..be9aa320e 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,7 @@ A Distribution is where APIs and Providers are assembled together to provide a c | **Distribution Provider** | **Docker** | **Inference** | **Memory** | **Safety** | **Telemetry** | | :----: | :----: | :----: | :----: | :----: | :----: | | Meta Reference | [Local GPU](https://hub.docker.com/repository/docker/llamastack/llamastack-local-gpu/general), [Local CPU](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | -| Dell-TGI | | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| Dell-TGI | [Local TGI + Chroma](https://hub.docker.com/repository/docker/llamastack/llamastack-local-tgi-chroma/general) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | ## Installation