[Inference] Use huggingface_hub inference client for TGI adapter (#53)

* Use huggingface_hub inference client for TGI inference

* Update the default value for TGI URL

* Use InferenceClient.text_generation for TGI inference

* Post-review fixes; split the TGI adapter into local and Inference Endpoints variants

* Update CLI reference and add typing

* Rename TGI Adapter class

* Use HfApi to get the namespace when not provided in the HF endpoint name

* Remove unnecessary method argument

* Improve TGI adapter initialization condition

* Move helper into impl file + fix merge conflicts
Celina Hanouti 2024-09-12 18:11:35 +02:00 committed by GitHub
parent 191cd28831
commit 736092f6bc
6 changed files with 171 additions and 72 deletions


@@ -286,6 +286,13 @@
 |                                | "memory": "meta-reference-faiss"      |                                                                      |
 |                                | }                                     |                                                                      |
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------+
+| local-plus-tgi-inference       | {                                     | Use TGI (local or with [Hugging Face Inference Endpoints](https://  |
+|                                | "inference": "remote::tgi",           | huggingface.co/inference-endpoints/dedicated)) for running LLM      |
+|                                | "safety": "meta-reference",           | inference. When using HF Inference Endpoints, you must provide the  |
+|                                | "agentic_system": "meta-reference",   | name of the endpoint.                                                |
+|                                | "memory": "meta-reference-faiss"      |                                                                      |
+|                                | }                                     |                                                                      |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------+
 </pre>

 As you can see above, each “distribution” details the “providers” it is composed of. For example, `local` uses the “meta-reference” provider for inference while local-ollama relies on a different provider (Ollama) for inference. Similarly, you can use Fireworks or Together.AI for running inference as well.
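For illustration, here is a minimal sketch of looking up the new distribution programmatically with the resolve_distribution_spec helper changed below; the import path is an assumption, since this commit does not show which module the registry helpers live in.

    # Hedged sketch: the module path below is assumed, not shown in this commit.
    from llama_toolchain.core.distribution import resolve_distribution_spec  # hypothetical path

    spec = resolve_distribution_spec("local-plus-tgi-inference")
    if spec is not None:
        # Per the table above, spec.providers maps inference -> remote::tgi and
        # safety / agentic_system / memory -> the meta-reference providers.
        print(spec.description)  # "Use TGI for running LLM inference"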


@@ -65,11 +65,23 @@ def available_distribution_specs() -> List[DistributionSpec]:
                 Api.telemetry: "console",
             },
         ),
+        DistributionSpec(
+            distribution_type="local-plus-tgi-inference",
+            description="Use TGI for running LLM inference",
+            providers={
+                Api.inference: remote_provider_type("tgi"),
+                Api.safety: "meta-reference",
+                Api.agentic_system: "meta-reference",
+                Api.memory: "meta-reference-faiss",
+            },
+        ),
     ]


 @lru_cache()
-def resolve_distribution_spec(distribution_type: str) -> Optional[DistributionSpec]:
+def resolve_distribution_spec(
+    distribution_type: str,
+) -> Optional[DistributionSpec]:
     for spec in available_distribution_specs():
         if spec.distribution_type == distribution_type:
             return spec


@@ -4,12 +4,21 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_toolchain.core.datatypes import RemoteProviderConfig
+from .config import TGIImplConfig
+from .tgi import InferenceEndpointAdapter, TGIAdapter


-async def get_adapter_impl(config: RemoteProviderConfig, _deps):
-    from .tgi import TGIInferenceAdapter
+async def get_adapter_impl(config: TGIImplConfig, _deps):
+    assert isinstance(config, TGIImplConfig), f"Unexpected config type: {type(config)}"

-    impl = TGIInferenceAdapter(config.url)
+    if config.url is not None:
+        impl = TGIAdapter(config)
+    elif config.is_inference_endpoint():
+        impl = InferenceEndpointAdapter(config)
+    else:
+        raise ValueError(
+            "Invalid configuration. Specify either an URL or HF Inference Endpoint details (namespace and endpoint name)."
+        )
     await impl.initialize()
     return impl
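To make the new dispatch concrete, a minimal, hedged sketch of driving get_adapter_impl with a URL-based config; the address is a placeholder and a running TGI server is assumed.

    import asyncio

    from llama_toolchain.inference.adapters.tgi import TGIImplConfig, get_adapter_impl

    # URL set -> TGIAdapter is selected (placeholder address, local TGI server assumed).
    local_config = TGIImplConfig(url="http://localhost:8080")

    async def main() -> None:
        # The second argument (_deps) is unused by this adapter.
        impl = await get_adapter_impl(local_config, None)
        # initialize() has already validated model_id / max_total_tokens from the server.

    asyncio.run(main())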


@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Optional
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field
+
+
+@json_schema_type
+class TGIImplConfig(BaseModel):
+    url: Optional[str] = Field(
+        default=None,
+        description="The URL for the local TGI endpoint (e.g., http://localhost:8080)",
+    )
+    api_token: Optional[str] = Field(
+        default=None,
+        description="The HF token for Hugging Face Inference Endpoints (will default to locally saved token if not provided)",
+    )
+    hf_endpoint_name: Optional[str] = Field(
+        default=None,
+        description="The name of the Hugging Face Inference Endpoint : can be either in the format of '{namespace}/{endpoint_name}' (namespace can be the username or organization name) or just '{endpoint_name}' if logged into the same account as the namespace",
+    )
+
+    def is_inference_endpoint(self) -> bool:
+        return self.hf_endpoint_name is not None
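The two modes are distinguished purely by which optional fields are populated; a short sketch with placeholder values:

    from llama_toolchain.inference.adapters.tgi import TGIImplConfig

    local = TGIImplConfig(url="http://localhost:8080")              # local TGI server
    remote = TGIImplConfig(hf_endpoint_name="my-org/my-endpoint")   # placeholder endpoint name

    assert not local.is_inference_endpoint()
    assert remote.is_inference_endpoint()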


@@ -4,63 +4,68 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import AsyncGenerator, List
+from typing import Any, AsyncGenerator, Dict

-import httpx
+import requests
+from huggingface_hub import HfApi, InferenceClient

 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message, StopReason
+from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
-from text_generation import Client

 from llama_toolchain.inference.api import *  # noqa: F403
 from llama_toolchain.inference.prepare_messages import prepare_messages

+from .config import TGIImplConfig

-SUPPORTED_MODELS = {
+HF_SUPPORTED_MODELS = {
     "Meta-Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
     "Meta-Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
     "Meta-Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
 }


-class TGIInferenceAdapter(Inference):
-    def __init__(self, url: str) -> None:
-        self.url = url.rstrip("/")
+class TGIAdapter(Inference):
+    def __init__(self, config: TGIImplConfig) -> None:
+        self.config = config
         self.tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(self.tokenizer)
-        self.model = None
-        self.max_tokens = None
+
+    @property
+    def client(self) -> InferenceClient:
+        return InferenceClient(model=self.config.url, token=self.config.api_token)
+
+    def _get_endpoint_info(self) -> Dict[str, Any]:
+        return {
+            **self.client.get_endpoint_info(),
+            "inference_url": self.config.url,
+        }

     async def initialize(self) -> None:
-        hf_models = {v: k for k, v in SUPPORTED_MODELS.items()}
-
         try:
-            print(f"Connecting to TGI server at: {self.url}")
-            async with httpx.AsyncClient() as client:
-                response = await client.get(f"{self.url}/info")
-                response.raise_for_status()
-                info = response.json()
-                if "model_id" not in info:
-                    raise RuntimeError("Missing model_id in model info")
-                if "max_total_tokens" not in info:
-                    raise RuntimeError("Missing max_total_tokens in model info")
-                self.max_tokens = info["max_total_tokens"]
-
-                model_id = info["model_id"]
-                if model_id not in hf_models:
-                    raise RuntimeError(
-                        f"TGI is serving model: {model_id}, use one of the supported models: {','.join(hf_models.keys())}"
-                    )
-
-                self.model = hf_models[model_id]
+            info = self._get_endpoint_info()
+            if "model_id" not in info:
+                raise RuntimeError("Missing model_id in model info")
+            if "max_total_tokens" not in info:
+                raise RuntimeError("Missing max_total_tokens in model info")
+            self.max_tokens = info["max_total_tokens"]
+
+            model_id = info["model_id"]
+            model_name = next(
+                (name for name, id in HF_SUPPORTED_MODELS.items() if id == model_id),
+                None,
+            )
+            if model_name is None:
+                raise RuntimeError(
+                    f"TGI is serving model: {model_id}, use one of the supported models: {', '.join(HF_SUPPORTED_MODELS.values())}"
+                )
+            self.model_name = model_name
+            self.inference_url = info["inference_url"]
         except Exception as e:
             import traceback

             traceback.print_exc()
-            raise RuntimeError("Could not connect to TGI server") from e
+            raise RuntimeError(f"Error initializing TGIAdapter: {e}") from e

     async def shutdown(self) -> None:
         pass
@@ -68,16 +73,6 @@ class TGIInferenceAdapter(Inference):
     async def completion(self, request: CompletionRequest) -> AsyncGenerator:
         raise NotImplementedError()

-    def _convert_messages(self, messages: List[Message]) -> List[Message]:
-        ret = []
-        for message in messages:
-            if message.role == "ipython":
-                role = "tool"
-            else:
-                role = message.role
-            ret.append({"role": role, "content": message.content})
-        return ret
-
     def get_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
         if request.sampling_params is not None:
@@ -89,47 +84,47 @@ class TGIInferenceAdapter(Inference):
     async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         messages = prepare_messages(request)
         model_input = self.formatter.encode_dialog_prompt(messages)
         prompt = self.tokenizer.decode(model_input.tokens)

+        input_tokens = len(model_input.tokens)
         max_new_tokens = min(
-            request.sampling_params.max_tokens or self.max_tokens,
-            self.max_tokens - len(model_input.tokens) - 1,
+            request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
+            self.max_tokens - input_tokens - 1,
         )
+        print(f"Calculated max_new_tokens: {max_new_tokens}")

-        if request.model != self.model:
-            raise ValueError(
-                f"Model mismatch, expected: {self.model}, got: {request.model}"
-            )
+        assert (
+            request.model == self.model_name
+        ), f"Model mismatch, expected {self.model_name}, got {request.model}"

         options = self.get_chat_options(request)
-        client = Client(base_url=self.url)
         if not request.stream:
-            r = client.generate(
-                prompt,
+            response = self.client.text_generation(
+                prompt=prompt,
+                stream=False,
+                details=True,
                 max_new_tokens=max_new_tokens,
                 stop_sequences=["<|eom_id|>", "<|eot_id|>"],
                 **options,
             )
-            stop_reason = None
-            if r.details.finish_reason:
-                if r.details.finish_reason == "stop":
+            if response.details.finish_reason:
+                if response.details.finish_reason == "stop":
                     stop_reason = StopReason.end_of_turn
-                elif r.details.finish_reason == "length":
+                elif response.details.finish_reason == "length":
                     stop_reason = StopReason.out_of_tokens
+                else:
+                    stop_reason = StopReason.end_of_message
+            else:
+                stop_reason = StopReason.out_of_tokens

             completion_message = self.formatter.decode_assistant_message_from_content(
-                r.generated_text, stop_reason
+                response.generated_text,
+                stop_reason,
             )
             yield ChatCompletionResponse(
                 completion_message=completion_message,
                 logprobs=None,
             )
         else:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
@@ -137,14 +132,15 @@ class TGIInferenceAdapter(Inference):
                     delta="",
                 )
             )

             buffer = ""
             ipython = False
             stop_reason = None
             tokens = []

-            for response in client.generate_stream(
-                prompt,
+            for response in self.client.text_generation(
+                prompt=prompt,
+                stream=True,
+                details=True,
                 max_new_tokens=max_new_tokens,
                 stop_sequences=["<|eom_id|>", "<|eot_id|>"],
                 **options,
@@ -231,3 +227,48 @@ class TGIInferenceAdapter(Inference):
                     stop_reason=stop_reason,
                 )
             )
+
+
+class InferenceEndpointAdapter(TGIAdapter):
+    def __init__(self, config: TGIImplConfig) -> None:
+        super().__init__(config)
+        self.config.url = self._construct_endpoint_url()
+
+    def _construct_endpoint_url(self) -> str:
+        hf_endpoint_name = self.config.hf_endpoint_name
+        assert hf_endpoint_name.count("/") <= 1, (
+            "Endpoint name must be in the format of 'namespace/endpoint_name' "
+            "or 'endpoint_name'"
+        )
+        if "/" not in hf_endpoint_name:
+            hf_namespace: str = self.get_namespace()
+            endpoint_path = f"{hf_namespace}/{hf_endpoint_name}"
+        else:
+            endpoint_path = hf_endpoint_name
+        return f"https://api.endpoints.huggingface.cloud/v2/endpoint/{endpoint_path}"
+
+    def get_namespace(self) -> str:
+        return HfApi().whoami()["name"]
+
+    @property
+    def client(self) -> InferenceClient:
+        return InferenceClient(model=self.inference_url, token=self.config.api_token)
+
+    def _get_endpoint_info(self) -> Dict[str, Any]:
+        headers = {
+            "accept": "application/json",
+            "authorization": f"Bearer {self.config.api_token}",
+        }
+        response = requests.get(self.config.url, headers=headers)
+        response.raise_for_status()
+        endpoint_info = response.json()
+        return {
+            "inference_url": endpoint_info["status"]["url"],
+            "model_id": endpoint_info["model"]["repository"],
+            "max_total_tokens": int(
+                endpoint_info["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
+            ),
+        }
+
+    async def initialize(self) -> None:
+        await super().initialize()
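For reference, a minimal sketch of the huggingface_hub call the adapter now wraps; the URL is a placeholder standing in for any reachable TGI server (local, or the resolved Inference Endpoints URL).

    from huggingface_hub import InferenceClient

    client = InferenceClient(model="http://localhost:8080")  # placeholder TGI URL

    # Non-streaming: details=True returns generated_text plus finish_reason and token info.
    out = client.text_generation(
        prompt="Hello",
        max_new_tokens=32,
        details=True,
        stream=False,
    )
    print(out.generated_text, out.details.finish_reason)

    # Streaming: yields token-level chunks; the final chunk carries the details object.
    for chunk in client.text_generation(
        prompt="Hello",
        max_new_tokens=32,
        details=True,
        stream=True,
    ):
        print(chunk.token.text, end="")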


@@ -39,8 +39,9 @@ def available_providers() -> List[ProviderSpec]:
         api=Api.inference,
         adapter=AdapterSpec(
             adapter_id="tgi",
-            pip_packages=["text-generation"],
+            pip_packages=["huggingface_hub"],
             module="llama_toolchain.inference.adapters.tgi",
+            config_class="llama_toolchain.inference.adapters.tgi.TGIImplConfig",
         ),
     ),
     remote_provider_spec(