From b96e705680fa96c49e278763be9a63834c8e90b2 Mon Sep 17 00:00:00 2001
From: Celina Hanouti
Date: Mon, 9 Sep 2024 17:47:49 +0200
Subject: [PATCH] Fixes post-review and split TGI adapter into local and
 Inference Endpoints ones

---
 .../inference/adapters/tgi/__init__.py        | 16 ++--
 .../inference/adapters/tgi/config.py          | 24 ++++--
 llama_toolchain/inference/adapters/tgi/tgi.py | 75 +++++++++++++++++--
 3 files changed, 98 insertions(+), 17 deletions(-)

diff --git a/llama_toolchain/inference/adapters/tgi/__init__.py b/llama_toolchain/inference/adapters/tgi/__init__.py
index 86faf94cb..3185e04dc 100644
--- a/llama_toolchain/inference/adapters/tgi/__init__.py
+++ b/llama_toolchain/inference/adapters/tgi/__init__.py
@@ -5,14 +5,20 @@
 # the root directory of this source tree.
 
 from .config import TGIImplConfig
+from .tgi import InferenceEndpointAdapter, LocalTGIAdapter
 
 
 async def get_adapter_impl(config: TGIImplConfig, _deps):
-    from .tgi import TGIAdapter
+    assert isinstance(config, TGIImplConfig), f"Unexpected config type: {type(config)}"
+
+    if config.is_local_tgi():
+        impl = LocalTGIAdapter(config)
+    elif config.is_inference_endpoint():
+        impl = InferenceEndpointAdapter(config)
+    else:
+        raise ValueError(
+            "Invalid configuration. Specify either a local URL or Inference Endpoint details."
+        )
 
-    assert isinstance(
-        config, TGIImplConfig
-    ), f"Unexpected config type: {type(config)}"
-    impl = TGIAdapter(config)
     await impl.initialize()
     return impl
diff --git a/llama_toolchain/inference/adapters/tgi/config.py b/llama_toolchain/inference/adapters/tgi/config.py
index 00132fc6c..267ddbced 100644
--- a/llama_toolchain/inference/adapters/tgi/config.py
+++ b/llama_toolchain/inference/adapters/tgi/config.py
@@ -12,11 +12,25 @@ from pydantic import BaseModel, Field, field_validator
 
 @json_schema_type
 class TGIImplConfig(BaseModel):
-    url: str = Field(
-        default="https://huggingface.co/inference-endpoints/dedicated",
-        description="The URL for the TGI endpoint",
+    url: Optional[str] = Field(
+        default=None,
+        description="The URL for the local TGI endpoint (e.g., http://localhost:8080)",
     )
     api_token: Optional[str] = Field(
-        default="",
-        description="The HF token for Hugging Face Inference Endpoints",
+        default=None,
+        description="The HF token for Hugging Face Inference Endpoints (will default to locally saved token if not provided)",
     )
+    hf_namespace: Optional[str] = Field(
+        default=None,
+        description="The username/organization name for the Hugging Face Inference Endpoint",
+    )
+    hf_endpoint_name: Optional[str] = Field(
+        default=None,
+        description="The name of the Hugging Face Inference Endpoint",
+    )
+
+    def is_inference_endpoint(self) -> bool:
+        return self.hf_namespace is not None and self.hf_endpoint_name is not None
+
+    def is_local_tgi(self) -> bool:
+        return self.url is not None and self.url.startswith("http://localhost")
diff --git a/llama_toolchain/inference/adapters/tgi/tgi.py b/llama_toolchain/inference/adapters/tgi/tgi.py
index 4d32dbe5f..0557dfc04 100644
--- a/llama_toolchain/inference/adapters/tgi/tgi.py
+++ b/llama_toolchain/inference/adapters/tgi/tgi.py
@@ -7,6 +7,7 @@
 
 from typing import AsyncGenerator
 
+import requests
 from huggingface_hub import InferenceClient
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
@@ -29,7 +30,7 @@ HF_SUPPORTED_MODELS = {
 }
 
 
-class TGIAdapter(Inference):
+class LocalTGIAdapter(Inference):
     def __init__(self, config: TGIImplConfig) -> None:
         self.config = config
 
@@ -38,10 +39,36 @@ class TGIAdapter(Inference):
     @property
     def client(self) -> InferenceClient:
-        return InferenceClient(base_url=self.config.url, token=self.config.api_token)
+        return InferenceClient(model=self.config.url, token=self.config.api_token)
+
+    def _get_endpoint_info(self):
+        return {**self.client.get_endpoint_info(), "inference_url": self.config.url}
 
     async def initialize(self) -> None:
-        pass
+        try:
+            info = self._get_endpoint_info()
+            if "model_id" not in info:
+                raise RuntimeError("Missing model_id in model info")
+            if "max_total_tokens" not in info:
+                raise RuntimeError("Missing max_total_tokens in model info")
+            self.max_tokens = info["max_total_tokens"]
+
+            model_id = info["model_id"]
+            model_name = next(
+                (name for name, id in HF_SUPPORTED_MODELS.items() if id == model_id),
+                None,
+            )
+            if model_name is None:
+                raise RuntimeError(
+                    f"TGI is serving model: {model_id}, use one of the supported models: {', '.join(HF_SUPPORTED_MODELS.values())}"
+                )
+            self.model_name = model_name
+            self.inference_url = info["inference_url"]
+        except Exception as e:
+            import traceback
+
+            traceback.print_exc()
+            raise RuntimeError(f"Error initializing TGIAdapter: {e}")
 
     async def shutdown(self) -> None:
         pass
 
@@ -63,14 +90,19 @@ class TGIAdapter(Inference):
 
         model_input = self.formatter.encode_dialog_prompt(messages)
         prompt = self.tokenizer.decode(model_input.tokens)
 
-        model_info = self.client.get_endpoint_info(model=self.config.url)
+        input_tokens = len(model_input.tokens)
         max_new_tokens = min(
-            request.sampling_params.max_tokens or model_info["max_total_tokens"],
-            model_info["max_total_tokens"] - len(model_input.tokens) - 1,
+            request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
+            self.max_tokens - input_tokens - 1,
         )
 
-        options = self.get_chat_options(request)
+        print(f"Calculated max_new_tokens: {max_new_tokens}")
+        assert (
+            request.model == self.model_name
+        ), f"Model mismatch, expected {self.model_name}, got {request.model}"
+
+        options = self.get_chat_options(request)
         if not request.stream:
             response = self.client.text_generation(
                 prompt=prompt,
@@ -198,3 +230,32 @@ class TGIAdapter(Inference):
                     stop_reason=stop_reason,
                 )
             )
+
+
+class InferenceEndpointAdapter(LocalTGIAdapter):
+    def __init__(self, config: TGIImplConfig) -> None:
+        super().__init__(config)
+        self.config.url = f"https://api.endpoints.huggingface.cloud/v2/endpoint/{config.hf_namespace}/{config.hf_endpoint_name}"
+
+    @property
+    def client(self) -> InferenceClient:
+        return InferenceClient(model=self.inference_url, token=self.config.api_token)
+
+    def _get_endpoint_info(self) -> Dict[str, Any]:
+        headers = {
+            "accept": "application/json",
+            "authorization": f"Bearer {self.config.api_token}",
+        }
+        response = requests.get(self.config.url, headers=headers)
+        response.raise_for_status()
+        endpoint_info = response.json()
+        return {
+            "inference_url": endpoint_info["status"]["url"],
+            "model_id": endpoint_info["model"]["repository"],
+            "max_total_tokens": int(
+                endpoint_info["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
+            ),
+        }
+
+    async def initialize(self) -> None:
+        await super().initialize()
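
A minimal usage sketch for the configuration split introduced above (not part of the patch itself): it shows which TGIImplConfig shape routes to which adapter in get_adapter_impl(). The module path follows the file paths in the diff; the namespace, endpoint name, and token values are placeholders, not real endpoints.

    from llama_toolchain.inference.adapters.tgi.config import TGIImplConfig

    # Local TGI server: only `url` is set, so is_local_tgi() is True and
    # get_adapter_impl() would build a LocalTGIAdapter.
    local_cfg = TGIImplConfig(url="http://localhost:8080")
    assert local_cfg.is_local_tgi() and not local_cfg.is_inference_endpoint()

    # Hugging Face Inference Endpoint: namespace + endpoint name (plus a token),
    # so is_inference_endpoint() is True and get_adapter_impl() would build an
    # InferenceEndpointAdapter. All three values below are placeholders.
    endpoint_cfg = TGIImplConfig(
        hf_namespace="my-org",
        hf_endpoint_name="my-endpoint",
        api_token="hf_xxx",
    )
    assert endpoint_cfg.is_inference_endpoint() and not endpoint_cfg.is_local_tgi()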