diff --git a/README.md b/README.md
index 90665b480..be9aa320e 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
 
 This repository contains the Llama Stack API specifications as well as API Providers and Llama Stack Distributions.
 
-The Llama Stack defines and standardizes the building blocks needed to bring generative AI applications to market. These blocks span the entire development lifecycle: from model training and fine-tuning, through product evaluation, to building and running AI agents in production. Beyond definition, we are building providers for the Llama Stack APIs. These we're developing open-source versions and partnering with providers , ensuring developers can assemble AI solutions using consistent, interlocking pieces across platforms. The ultimate goal is to accelerate innovation in the AI space.
+The Llama Stack defines and standardizes the building blocks needed to bring generative AI applications to market. These blocks span the entire development lifecycle: from model training and fine-tuning, through product evaluation, to building and running AI agents in production. Beyond definition, we are building providers for the Llama Stack APIs. For these, we are developing open-source versions and partnering with providers, ensuring developers can assemble AI solutions using consistent, interlocking pieces across platforms. The ultimate goal is to accelerate innovation in the AI space.
 
 The Stack APIs are rapidly improving, but still very much work in progress and we invite feedback as well as direct contributions.
@@ -59,7 +59,7 @@ A Distribution is where APIs and Providers are assembled together to provide a c
 | **Distribution Provider** | **Docker** | **Inference** | **Memory** | **Safety** | **Telemetry** |
 | :----: | :----: | :----: | :----: | :----: | :----: |
 | Meta Reference | [Local GPU](https://hub.docker.com/repository/docker/llamastack/llamastack-local-gpu/general), [Local CPU](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
-| Dell-TGI | | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
+| Dell-TGI | [Local TGI + Chroma](https://hub.docker.com/repository/docker/llamastack/llamastack-local-tgi-chroma/general) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 
 ## Installation
diff --git a/docs/cli_reference.md b/docs/cli_reference.md
index 2ebdadd4f..1c62188ef 100644
--- a/docs/cli_reference.md
+++ b/docs/cli_reference.md
@@ -483,4 +483,4 @@ Similarly you can test safety (if you configured llama-guard and/or prompt-guard
 python -m llama_stack.apis.safety.client localhost 5000
 ```
 
-You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo.
+You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repo.
diff --git a/docs/getting_started.md b/docs/getting_started.md
index 5e2f21eac..83f08cfa6 100644
--- a/docs/getting_started.md
+++ b/docs/getting_started.md
@@ -433,4 +433,4 @@ Similarly you can test safety (if you configured llama-guard and/or prompt-guard
 python -m llama_stack.apis.safety.client localhost 5000
 ```
 
-You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps) repo.
+You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repo.
diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py
index 2321c8f2f..132aef7e5 100644
--- a/llama_stack/cli/stack/build.py
+++ b/llama_stack/cli/stack/build.py
@@ -95,9 +95,9 @@ class StackBuild(Subcommand):
         # save build.yaml spec for building same distribution again
         if build_config.image_type == ImageType.docker.value:
             # docker needs build file to be in the llama-stack repo dir to be able to copy over to the image
-            llama_stack_path = Path(os.path.relpath(__file__)).parent.parent.parent
+            llama_stack_path = Path(os.path.abspath(__file__)).parent.parent.parent.parent
             build_dir = (
-                llama_stack_path / "configs/distributions" / build_config.image_type
+                llama_stack_path / "tmp/configs/"
             )
         else:
             build_dir = (
diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/distribution/build_container.sh
index 3efef6c97..fec1e394f 100755
--- a/llama_stack/distribution/build_container.sh
+++ b/llama_stack/distribution/build_container.sh
@@ -103,7 +103,7 @@ add_to_docker <
diff --git a/llama_stack/providers/adapters/inference/tgi/config.py b/llama_stack/providers/adapters/inference/tgi/config.py
--- a/llama_stack/providers/adapters/inference/tgi/config.py
+++ b/llama_stack/providers/adapters/inference/tgi/config.py
-    def is_inference_endpoint(self) -> bool:
-        return self.hf_endpoint_name is not None
+
+@json_schema_type
+class InferenceEndpointImplConfig(BaseModel):
+    endpoint_name: str = Field(
+        description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.",
+    )
+    api_token: Optional[str] = Field(
+        default=None,
+        description="Your Hugging Face user access token (will default to locally saved token if not provided)",
+    )
+
+
+@json_schema_type
+class InferenceAPIImplConfig(BaseModel):
+    model_id: str = Field(
+        description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
+    )
+    api_token: Optional[str] = Field(
+        default=None,
+        description="Your Hugging Face user access token (will default to locally saved token if not provided)",
+    )
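For orientation, here is a minimal sketch (not part of the diff) of how the three config classes might be constructed once this change lands. The field values are placeholders, and `TGIImplConfig` is assumed to keep the `url`/`api_token` fields that `tgi.py` below reads from it.

```python
# Illustrative sketch only; the values below are placeholders, not shipped defaults.
from llama_stack.providers.adapters.inference.tgi.config import (
    InferenceAPIImplConfig,
    InferenceEndpointImplConfig,
    TGIImplConfig,
)

# Self-hosted TGI server, addressed by URL (assumed unchanged by this diff).
tgi_config = TGIImplConfig(url="http://localhost:8080")

# Hugging Face serverless Inference API, keyed by a model id on the Hub.
serverless_config = InferenceAPIImplConfig(
    model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
    api_token=None,  # falls back to the locally saved Hugging Face token
)

# Dedicated Hugging Face Inference Endpoint, keyed by '{namespace}/{endpoint_name}'.
endpoint_config = InferenceEndpointImplConfig(
    endpoint_name="my-org/my-llama-endpoint",  # placeholder endpoint name
    api_token=None,
)
```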
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index 4919ff86a..66f57442f 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -5,54 +5,33 @@
 # the root directory of this source tree.
 
-from typing import Any, AsyncGenerator, Dict
+import logging
+from typing import AsyncGenerator
 
-import requests
-
-from huggingface_hub import HfApi, InferenceClient
+from huggingface_hub import AsyncInferenceClient, HfApi
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
+
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
 
-from .config import TGIImplConfig
+from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
+
+logger = logging.getLogger(__name__)
 
 
-class TGIAdapter(Inference):
-    def __init__(self, config: TGIImplConfig) -> None:
-        self.config = config
+class _HfAdapter(Inference):
+    client: AsyncInferenceClient
+    max_tokens: int
+    model_id: str
+
+    def __init__(self) -> None:
         self.tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(self.tokenizer)
 
-    @property
-    def client(self) -> InferenceClient:
-        return InferenceClient(model=self.config.url, token=self.config.api_token)
-
-    def _get_endpoint_info(self) -> Dict[str, Any]:
-        return {
-            **self.client.get_endpoint_info(),
-            "inference_url": self.config.url,
-        }
-
-    async def initialize(self) -> None:
-        try:
-            info = self._get_endpoint_info()
-            if "model_id" not in info:
-                raise RuntimeError("Missing model_id in model info")
-            if "max_total_tokens" not in info:
-                raise RuntimeError("Missing max_total_tokens in model info")
-            self.max_tokens = info["max_total_tokens"]
-
-            self.inference_url = info["inference_url"]
-        except Exception as e:
-            import traceback
-
-            traceback.print_exc()
-            raise RuntimeError(f"Error initializing TGIAdapter: {e}") from e
-
     async def shutdown(self) -> None:
         pass
 
@@ -111,7 +90,7 @@ class TGIAdapter(Inference):
         options = self.get_chat_options(request)
 
         if not request.stream:
-            response = self.client.text_generation(
+            response = await self.client.text_generation(
                 prompt=prompt,
                 stream=False,
                 details=True,
@@ -147,7 +126,7 @@ class TGIAdapter(Inference):
 
         stop_reason = None
         tokens = []
 
-        for response in self.client.text_generation(
+        async for response in await self.client.text_generation(
             prompt=prompt,
             stream=True,
             details=True,
@@ -239,46 +218,36 @@ class TGIAdapter(Inference):
         )
 
 
-class InferenceEndpointAdapter(TGIAdapter):
-    def __init__(self, config: TGIImplConfig) -> None:
-        super().__init__(config)
-        self.config.url = self._construct_endpoint_url()
+class TGIAdapter(_HfAdapter):
+    async def initialize(self, config: TGIImplConfig) -> None:
+        self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
+        endpoint_info = await self.client.get_endpoint_info()
+        self.max_tokens = endpoint_info["max_total_tokens"]
+        self.model_id = endpoint_info["model_id"]
 
-    def _construct_endpoint_url(self) -> str:
-        hf_endpoint_name = self.config.hf_endpoint_name
-        assert hf_endpoint_name.count("/") <= 1, (
-            "Endpoint name must be in the format of 'namespace/endpoint_name' "
-            "or 'endpoint_name'"
+
+class InferenceAPIAdapter(_HfAdapter):
+    async def initialize(self, config: InferenceAPIImplConfig) -> None:
+        self.client = AsyncInferenceClient(
+            model=config.model_id, token=config.api_token
         )
-        if "/" not in hf_endpoint_name:
-            hf_namespace: str = self.get_namespace()
-            endpoint_path = f"{hf_namespace}/{hf_endpoint_name}"
-        else:
-            endpoint_path = hf_endpoint_name
-        return f"https://api.endpoints.huggingface.cloud/v2/endpoint/{endpoint_path}"
+        endpoint_info = await self.client.get_endpoint_info()
+        self.max_tokens = endpoint_info["max_total_tokens"]
+        self.model_id = endpoint_info["model_id"]
 
-    def get_namespace(self) -> str:
-        return HfApi().whoami()["name"]
 
-    @property
-    def client(self) -> InferenceClient:
-        return InferenceClient(model=self.inference_url, token=self.config.api_token)
+class InferenceEndpointAdapter(_HfAdapter):
+    async def initialize(self, config: InferenceEndpointImplConfig) -> None:
+        # Get the inference endpoint details
+        api = HfApi(token=config.api_token)
+        endpoint = api.get_inference_endpoint(config.endpoint_name)
 
-    def _get_endpoint_info(self) -> Dict[str, Any]:
-        headers = {
-            "accept": "application/json",
-            "authorization": f"Bearer {self.config.api_token}",
-        }
-        response = requests.get(self.config.url, headers=headers)
-        response.raise_for_status()
-        endpoint_info = response.json()
-        return {
-            "inference_url": endpoint_info["status"]["url"],
-            "model_id": endpoint_info["model"]["repository"],
-            "max_total_tokens": int(
-                endpoint_info["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
-            ),
-        }
+        # Wait for the endpoint to be ready (if not already)
+        endpoint.wait(timeout=60)
 
-    async def initialize(self) -> None:
-        await super().initialize()
+        # Initialize the adapter
+        self.client = endpoint.async_client
+        self.model_id = endpoint.repository
+        self.max_tokens = int(
+            endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
+        )
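The refactor above swaps the synchronous `InferenceClient` for `AsyncInferenceClient`, so both branches of generation now await the client. A minimal standalone sketch of that call pattern (not part of the diff; it assumes a TGI server is reachable at the placeholder URL):

```python
# Sketch of the awaited call pattern used by _HfAdapter; URL and prompt are placeholders.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient(model="http://localhost:8080")

    # Non-streaming: a single awaited call, mirroring the `if not request.stream` branch.
    response = await client.text_generation(
        prompt="Hello", stream=False, details=True, max_new_tokens=64
    )
    print(response.generated_text)

    # Streaming: awaiting the call with stream=True yields an async iterator of
    # partial outputs, hence `async for ... in await self.client.text_generation(...)`.
    async for chunk in await client.text_generation(
        prompt="Hello", stream=True, details=True, max_new_tokens=64
    ):
        print(chunk.token.text, end="", flush=True)


asyncio.run(main())
```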
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index db0d95527..31b3e2c2d 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -48,11 +48,29 @@ def available_providers() -> List[ProviderSpec]:
             api=Api.inference,
             adapter=AdapterSpec(
                 adapter_id="tgi",
-                pip_packages=["huggingface_hub"],
+                pip_packages=["huggingface_hub", "aiohttp"],
                 module="llama_stack.providers.adapters.inference.tgi",
                 config_class="llama_stack.providers.adapters.inference.tgi.TGIImplConfig",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_id="hf::serverless",
+                pip_packages=["huggingface_hub", "aiohttp"],
+                module="llama_stack.providers.adapters.inference.tgi",
+                config_class="llama_stack.providers.adapters.inference.tgi.InferenceAPIImplConfig",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_id="hf::endpoint",
+                pip_packages=["huggingface_hub", "aiohttp"],
+                module="llama_stack.providers.adapters.inference.tgi",
+                config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig",
+            ),
+        ),
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
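As a quick, hedged sanity check (not part of the diff), one might list the inference adapters the registry now returns; this sketch assumes remote provider specs expose their `AdapterSpec` via an `adapter` attribute.

```python
# Hypothetical check that tgi, hf::serverless and hf::endpoint are all registered.
from llama_stack.providers.registry.inference import available_providers

for spec in available_providers():
    adapter = getattr(spec, "adapter", None)  # assumed attribute on remote provider specs
    if adapter is not None:
        print(f"{adapter.adapter_id}: {adapter.config_class}")
```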