Make TGI adapter compatible with HF Inference API (#97)

Lucain authored 2024-09-25 23:08:31 +02:00, committed by GitHub
parent 851c30597a
commit 615ed4bfbc
GPG key ID: B5690EEEBB952194
7 changed files with 122 additions and 96 deletions

View file

@@ -0,0 +1,10 @@
+name: local-hf-endpoint
+distribution_spec:
+  description: "Like local, but use Hugging Face Inference Endpoints for running LLM inference.\nSee https://hf.co/docs/api-endpoints."
+  providers:
+    inference: remote::hf::endpoint
+    memory: meta-reference
+    safety: meta-reference
+    agents: meta-reference
+    telemetry: meta-reference
+image_type: conda

View file

@@ -0,0 +1,10 @@
+name: local-hf-serverless
+distribution_spec:
+  description: "Like local, but use Hugging Face Inference API (serverless) for running LLM inference.\nSee https://hf.co/docs/api-inference."
+  providers:
+    inference: remote::hf::serverless
+    memory: meta-reference
+    safety: meta-reference
+    agents: meta-reference
+    telemetry: meta-reference
+image_type: conda
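
For orientation, a rough sketch (not code from this commit) of how the provider ids used by these two new templates line up with the adapter classes introduced further down in this diff; the dict itself is illustrative only, the real dispatch lives in get_adapter_impl below.

# Illustrative mapping only: provider id -> adapter class, as wired up later in this commit.
PROVIDER_TO_ADAPTER = {
    "remote::tgi": "TGIAdapter",                         # self-hosted TGI server
    "remote::hf::serverless": "InferenceAPIAdapter",     # HF Inference API (serverless)
    "remote::hf::endpoint": "InferenceEndpointAdapter",  # HF Inference Endpoints
}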

View file

@ -1,6 +1,6 @@
name: local-tgi name: local-tgi
distribution_spec: distribution_spec:
description: Use TGI (local or with Hugging Face Inference Endpoints for running LLM inference. When using HF Inference Endpoints, you must provide the name of the endpoint). description: Like local, but use a TGI server for running LLM inference.
providers: providers:
inference: remote::tgi inference: remote::tgi
memory: meta-reference memory: meta-reference

View file

@@ -4,21 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .config import TGIImplConfig
-from .tgi import InferenceEndpointAdapter, TGIAdapter
+from typing import Union
+
+from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
+from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
 
 
-async def get_adapter_impl(config: TGIImplConfig, _deps):
-    assert isinstance(config, TGIImplConfig), f"Unexpected config type: {type(config)}"
-
-    if config.url is not None:
-        impl = TGIAdapter(config)
-    elif config.is_inference_endpoint():
-        impl = InferenceEndpointAdapter(config)
+async def get_adapter_impl(
+    config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
+    _deps,
+):
+    if isinstance(config, TGIImplConfig):
+        impl = TGIAdapter()
+    elif isinstance(config, InferenceAPIImplConfig):
+        impl = InferenceAPIAdapter()
+    elif isinstance(config, InferenceEndpointImplConfig):
+        impl = InferenceEndpointAdapter()
     else:
         raise ValueError(
-            "Invalid configuration. Specify either an URL or HF Inference Endpoint details (namespace and endpoint name)."
+            f"Invalid configuration. Expected 'TGIAdapter', 'InferenceAPIImplConfig' or 'InferenceEndpointImplConfig'. Got {type(config)}."
         )
 
-    await impl.initialize()
+    await impl.initialize(config)
     return impl
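
A minimal usage sketch of the new dispatch, assuming the adapter package path given in the provider registry at the end of this diff; the URL and empty deps dict are placeholders, and running it requires a reachable TGI server at that URL.

import asyncio

# Package path taken from the registry entries below; __init__.py re-exports the config classes.
from llama_stack.providers.adapters.inference.tgi import TGIImplConfig, get_adapter_impl


async def main() -> None:
    # A TGIImplConfig selects TGIAdapter; the other config types select the other adapters.
    config = TGIImplConfig(url="http://localhost:8080")  # placeholder local TGI server
    impl = await get_adapter_impl(config, {})            # _deps is unused by this adapter
    print(type(impl).__name__)


asyncio.run(main())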

View file

@@ -12,18 +12,32 @@ from pydantic import BaseModel, Field
 @json_schema_type
 class TGIImplConfig(BaseModel):
-    url: Optional[str] = Field(
-        default=None,
-        description="The URL for the local TGI endpoint (e.g., http://localhost:8080)",
+    url: str = Field(
+        description="The URL for the TGI endpoint (e.g. 'http://localhost:8080')",
     )
     api_token: Optional[str] = Field(
         default=None,
-        description="The HF token for Hugging Face Inference Endpoints (will default to locally saved token if not provided)",
-    )
-    hf_endpoint_name: Optional[str] = Field(
-        default=None,
-        description="The name of the Hugging Face Inference Endpoint : can be either in the format of '{namespace}/{endpoint_name}' (namespace can be the username or organization name) or just '{endpoint_name}' if logged into the same account as the namespace",
+        description="A bearer token if your TGI endpoint is protected.",
     )
 
-    def is_inference_endpoint(self) -> bool:
-        return self.hf_endpoint_name is not None
+
+@json_schema_type
+class InferenceEndpointImplConfig(BaseModel):
+    endpoint_name: str = Field(
+        description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.",
+    )
+    api_token: Optional[str] = Field(
+        default=None,
+        description="Your Hugging Face user access token (will default to locally saved token if not provided)",
+    )
+
+
+@json_schema_type
+class InferenceAPIImplConfig(BaseModel):
+    model_id: str = Field(
+        description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
+    )
+    api_token: Optional[str] = Field(
+        default=None,
+        description="Your Hugging Face user access token (will default to locally saved token if not provided)",
+    )
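
To make the three config shapes concrete, a small sketch instantiating each one; the values are placeholders taken from the field descriptions above, and the imports go through the adapter package, which re-exports these classes.

from llama_stack.providers.adapters.inference.tgi import (
    InferenceAPIImplConfig,
    InferenceEndpointImplConfig,
    TGIImplConfig,
)

# Self-hosted TGI server: only the URL is required now.
tgi = TGIImplConfig(url="http://localhost:8080")

# HF Inference API (serverless): pick a model id on the Hub.
serverless = InferenceAPIImplConfig(
    model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
    api_token=None,  # falls back to the locally saved token
)

# HF Inference Endpoints: name the endpoint, optionally namespaced.
endpoint = InferenceEndpointImplConfig(
    endpoint_name="my-cool-org/meta-llama-3-1-8b-instruct-rce",
)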

View file

@@ -5,54 +5,33 @@
 # the root directory of this source tree.
 
-from typing import Any, AsyncGenerator, Dict
-
-import requests
-
-from huggingface_hub import HfApi, InferenceClient
+import logging
+from typing import AsyncGenerator
+
+from huggingface_hub import AsyncInferenceClient, HfApi
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
 
-from .config import TGIImplConfig
+from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
+
+logger = logging.getLogger(__name__)
 
 
-class TGIAdapter(Inference):
-    def __init__(self, config: TGIImplConfig) -> None:
-        self.config = config
+class _HfAdapter(Inference):
+    client: AsyncInferenceClient
+    max_tokens: int
+    model_id: str
+
+    def __init__(self) -> None:
         self.tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(self.tokenizer)
 
-    @property
-    def client(self) -> InferenceClient:
-        return InferenceClient(model=self.config.url, token=self.config.api_token)
-
-    def _get_endpoint_info(self) -> Dict[str, Any]:
-        return {
-            **self.client.get_endpoint_info(),
-            "inference_url": self.config.url,
-        }
-
-    async def initialize(self) -> None:
-        try:
-            info = self._get_endpoint_info()
-            if "model_id" not in info:
-                raise RuntimeError("Missing model_id in model info")
-            if "max_total_tokens" not in info:
-                raise RuntimeError("Missing max_total_tokens in model info")
-
-            self.max_tokens = info["max_total_tokens"]
-            self.inference_url = info["inference_url"]
-        except Exception as e:
-            import traceback
-
-            traceback.print_exc()
-            raise RuntimeError(f"Error initializing TGIAdapter: {e}") from e
-
     async def shutdown(self) -> None:
         pass
@@ -111,7 +90,7 @@ class TGIAdapter(Inference):
         options = self.get_chat_options(request)
         if not request.stream:
-            response = self.client.text_generation(
+            response = await self.client.text_generation(
                 prompt=prompt,
                 stream=False,
                 details=True,
@@ -147,7 +126,7 @@ class TGIAdapter(Inference):
         stop_reason = None
         tokens = []
 
-        for response in self.client.text_generation(
+        async for response in await self.client.text_generation(
             prompt=prompt,
             stream=True,
             details=True,
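
As an aside, this is the async streaming pattern the adapter now relies on, sketched against huggingface_hub's AsyncInferenceClient directly; the endpoint URL and prompt are placeholders.

import asyncio

from huggingface_hub import AsyncInferenceClient


async def stream_demo() -> None:
    client = AsyncInferenceClient(model="http://localhost:8080")  # placeholder TGI URL
    # With stream=True, the awaited call returns an async iterable of chunks.
    stream = await client.text_generation(
        prompt="Hello", stream=True, details=True, max_new_tokens=16
    )
    async for chunk in stream:
        print(chunk.token.text, end="")  # details=True exposes per-token info


asyncio.run(stream_demo())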
@@ -239,46 +218,36 @@ class TGIAdapter(Inference):
         )
 
 
-class InferenceEndpointAdapter(TGIAdapter):
-    def __init__(self, config: TGIImplConfig) -> None:
-        super().__init__(config)
-        self.config.url = self._construct_endpoint_url()
-
-    def _construct_endpoint_url(self) -> str:
-        hf_endpoint_name = self.config.hf_endpoint_name
-        assert hf_endpoint_name.count("/") <= 1, (
-            "Endpoint name must be in the format of 'namespace/endpoint_name' "
-            "or 'endpoint_name'"
-        )
-        if "/" not in hf_endpoint_name:
-            hf_namespace: str = self.get_namespace()
-            endpoint_path = f"{hf_namespace}/{hf_endpoint_name}"
-        else:
-            endpoint_path = hf_endpoint_name
-        return f"https://api.endpoints.huggingface.cloud/v2/endpoint/{endpoint_path}"
-
-    def get_namespace(self) -> str:
-        return HfApi().whoami()["name"]
-
-    @property
-    def client(self) -> InferenceClient:
-        return InferenceClient(model=self.inference_url, token=self.config.api_token)
-
-    def _get_endpoint_info(self) -> Dict[str, Any]:
-        headers = {
-            "accept": "application/json",
-            "authorization": f"Bearer {self.config.api_token}",
-        }
-        response = requests.get(self.config.url, headers=headers)
-        response.raise_for_status()
-        endpoint_info = response.json()
-        return {
-            "inference_url": endpoint_info["status"]["url"],
-            "model_id": endpoint_info["model"]["repository"],
-            "max_total_tokens": int(
-                endpoint_info["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
-            ),
-        }
-
-    async def initialize(self) -> None:
-        await super().initialize()
+class TGIAdapter(_HfAdapter):
+    async def initialize(self, config: TGIImplConfig) -> None:
+        self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
+        endpoint_info = await self.client.get_endpoint_info()
+        self.max_tokens = endpoint_info["max_total_tokens"]
+        self.model_id = endpoint_info["model_id"]
+
+
+class InferenceAPIAdapter(_HfAdapter):
+    async def initialize(self, config: InferenceAPIImplConfig) -> None:
+        self.client = AsyncInferenceClient(
+            model=config.model_id, token=config.api_token
+        )
+        endpoint_info = await self.client.get_endpoint_info()
+        self.max_tokens = endpoint_info["max_total_tokens"]
+        self.model_id = endpoint_info["model_id"]
+
+
+class InferenceEndpointAdapter(_HfAdapter):
+    async def initialize(self, config: InferenceEndpointImplConfig) -> None:
+        # Get the inference endpoint details
+        api = HfApi(token=config.api_token)
+        endpoint = api.get_inference_endpoint(config.endpoint_name)
+
+        # Wait for the endpoint to be ready (if not already)
+        endpoint.wait(timeout=60)
+
+        # Initialize the adapter
+        self.client = endpoint.async_client
+        self.model_id = endpoint.repository
+        self.max_tokens = int(
+            endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
+        )
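
A short sketch of driving one of the new adapters directly, mirroring InferenceAPIAdapter.initialize above; the model id is a placeholder and a valid Hugging Face token must be available locally or passed in the config.

import asyncio

from llama_stack.providers.adapters.inference.tgi import (
    InferenceAPIAdapter,
    InferenceAPIImplConfig,
)


async def main() -> None:
    adapter = InferenceAPIAdapter()
    await adapter.initialize(
        InferenceAPIImplConfig(model_id="meta-llama/Meta-Llama-3.1-70B-Instruct")
    )
    # initialize() fills these via get_endpoint_info(), as shown in the diff above.
    print(adapter.model_id, adapter.max_tokens)


asyncio.run(main())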

View file

@@ -48,11 +48,29 @@ def available_providers() -> List[ProviderSpec]:
         api=Api.inference,
         adapter=AdapterSpec(
             adapter_id="tgi",
-            pip_packages=["huggingface_hub"],
+            pip_packages=["huggingface_hub", "aiohttp"],
             module="llama_stack.providers.adapters.inference.tgi",
             config_class="llama_stack.providers.adapters.inference.tgi.TGIImplConfig",
         ),
     ),
+    remote_provider_spec(
+        api=Api.inference,
+        adapter=AdapterSpec(
+            adapter_id="hf::serverless",
+            pip_packages=["huggingface_hub", "aiohttp"],
+            module="llama_stack.providers.adapters.inference.tgi",
+            config_class="llama_stack.providers.adapters.inference.tgi.InferenceAPIImplConfig",
+        ),
+    ),
+    remote_provider_spec(
+        api=Api.inference,
+        adapter=AdapterSpec(
+            adapter_id="hf::endpoint",
+            pip_packages=["huggingface_hub", "aiohttp"],
+            module="llama_stack.providers.adapters.inference.tgi",
+            config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig",
+        ),
+    ),
     remote_provider_spec(
         api=Api.inference,
         adapter=AdapterSpec(