Use huggingface_hub inference client for TGI inference

This commit is contained in:
Celina Hanouti 2024-09-05 18:29:04 +02:00
parent 21bedc1596
commit e5bcfdac21
6 changed files with 179 additions and 142 deletions

View file

@@ -248,44 +248,51 @@ llama stack list-distributions
 ```
 <pre style="font-family: monospace;">
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------------------------+
 | Distribution ID                | Providers                             | Description                                                                              |
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------------------------+
 | local                          | {                                     | Use code from `llama_toolchain` itself to serve all llama stack APIs                    |
 |                                |   "inference": "meta-reference",      |                                                                                          |
 |                                |   "memory": "meta-reference-faiss",   |                                                                                          |
 |                                |   "safety": "meta-reference",         |                                                                                          |
 |                                |   "agentic_system": "meta-reference"  |                                                                                          |
 |                                | }                                     |                                                                                          |
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------------------------+
 | remote                         | {                                     | Point to remote services for all llama stack APIs                                       |
 |                                |   "inference": "remote",              |                                                                                          |
 |                                |   "safety": "remote",                 |                                                                                          |
 |                                |   "agentic_system": "remote",         |                                                                                          |
 |                                |   "memory": "remote"                  |                                                                                          |
 |                                | }                                     |                                                                                          |
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------------------------+
 | local-ollama                   | {                                     | Like local, but use ollama for running LLM inference                                    |
 |                                |   "inference": "remote::ollama",      |                                                                                          |
 |                                |   "safety": "meta-reference",         |                                                                                          |
 |                                |   "agentic_system": "meta-reference", |                                                                                          |
 |                                |   "memory": "meta-reference-faiss"    |                                                                                          |
 |                                | }                                     |                                                                                          |
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------------------------+
 | local-plus-fireworks-inference | {                                     | Use Fireworks.ai for running LLM inference                                              |
 |                                |   "inference": "remote::fireworks",   |                                                                                          |
 |                                |   "safety": "meta-reference",         |                                                                                          |
 |                                |   "agentic_system": "meta-reference", |                                                                                          |
 |                                |   "memory": "meta-reference-faiss"    |                                                                                          |
 |                                | }                                     |                                                                                          |
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------------------------+
 | local-plus-together-inference  | {                                     | Use Together.ai for running LLM inference                                               |
 |                                |   "inference": "remote::together",    |                                                                                          |
 |                                |   "safety": "meta-reference",         |                                                                                          |
 |                                |   "agentic_system": "meta-reference", |                                                                                          |
 |                                |   "memory": "meta-reference-faiss"    |                                                                                          |
 |                                | }                                     |                                                                                          |
 +--------------------------------+---------------------------------------+----------------------------------------------------------------------------------------+
+| local-plus-tgi-inference       | {                                     | Use TGI (local or with <a href="https://huggingface.co/inference-endpoints/dedicated">  |
+|                                |   "inference": "remote::tgi",         | Hugging Face Inference Endpoints</a>) for running LLM inference                         |
+|                                |   "safety": "meta-reference",         |                                                                                          |
+|                                |   "agentic_system": "meta-reference", |                                                                                          |
+|                                |   "memory": "meta-reference-faiss"    |                                                                                          |
+|                                | }                                     |                                                                                          |
++--------------------------------+---------------------------------------+----------------------------------------------------------------------------------------+
 </pre>
 As you can see above, each “distribution” details the “providers” it is composed of. For example, `local` uses the “meta-reference” provider for inference while local-ollama relies on a different provider (Ollama) for inference. Similarly, you can use Fireworks or Together.AI for running inference as well.

View file

@@ -58,6 +58,16 @@ def available_distribution_specs() -> List[DistributionSpec]:
                 Api.memory: "meta-reference-faiss",
             },
         ),
+        DistributionSpec(
+            distribution_id="local-plus-tgi-inference",
+            description="Use TGI for running LLM inference",
+            providers={
+                Api.inference: remote_provider_id("tgi"),
+                Api.safety: "meta-reference",
+                Api.agentic_system: "meta-reference",
+                Api.memory: "meta-reference-faiss",
+            },
+        ),
     ]
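For orientation, the new registry entry can be read back from the distribution list. Below is a minimal sketch; the import path is an assumption (only the function name `available_distribution_specs` and the fields `distribution_id`, `description`, and `providers` come from the hunk above):

```python
# Sketch only: the module path is assumed, not confirmed by this diff.
from llama_toolchain.core.distribution_registry import available_distribution_specs

tgi_spec = next(
    spec
    for spec in available_distribution_specs()
    if spec.distribution_id == "local-plus-tgi-inference"
)
print(tgi_spec.description)  # "Use TGI for running LLM inference"
print(tgi_spec.providers)    # inference -> "remote::tgi", everything else meta-reference
```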

View file

@@ -4,12 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_toolchain.core.datatypes import RemoteProviderConfig
+from .config import TGIImplConfig
 
 
-async def get_adapter_impl(config: RemoteProviderConfig, _deps):
-    from .tgi import TGIInferenceAdapter
+async def get_adapter_impl(config: TGIImplConfig, _deps):
+    from .tgi import TGIAdapter
 
-    impl = TGIInferenceAdapter(config.url)
+    assert isinstance(
+        config, TGIImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = TGIAdapter(config)
     await impl.initialize()
     return impl

View file

@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Optional
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field, field_validator
+
+
+@json_schema_type
+class TGIImplConfig(BaseModel):
+    url: str = Field(
+        default="https://api-inference.huggingface.co",
+        description="The URL for the TGI endpoint",
+    )
+    api_token: Optional[str] = Field(
+        default="",
+        description="The HF token for Hugging Face Inference Endpoints",
+    )
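Combined with the `__init__.py` change above, the adapter would presumably be constructed along these lines. This is a sketch only: the local TGI URL is a placeholder, and passing an empty dict for the unused `_deps` argument is an assumption:

```python
import asyncio

# TGIImplConfig and get_adapter_impl are exposed by the adapter package (see __init__.py above).
from llama_toolchain.inference.adapters.tgi import TGIImplConfig, get_adapter_impl


async def main() -> None:
    # Placeholder URL for a locally running TGI server; keep the default for the HF serverless endpoint.
    config = TGIImplConfig(url="http://localhost:8080", api_token=None)
    impl = await get_adapter_impl(config, {})  # deps are not used by this adapter
    print(type(impl).__name__)  # TGIAdapter


asyncio.run(main())
```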

View file

@@ -4,63 +4,44 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, List
-
-import httpx
+from typing import AsyncGenerator
 
+from huggingface_hub import InferenceClient
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.sku_list import resolve_model
 
-from text_generation import Client
-
-from llama_toolchain.inference.api.api import (  # noqa: F403
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseStreamChunk,
-)
-from llama_toolchain.inference.api import *  # noqa: F403
-from llama_toolchain.inference.prepare_messages import prepare_messages
+from llama_toolchain.inference.api import *  # noqa: F403
+
+from .config import TGIImplConfig
 
-SUPPORTED_MODELS = {
+HF_SUPPORTED_MODELS = {
     "Meta-Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
     "Meta-Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
     "Meta-Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
 }
 
 
-class TGIInferenceAdapter(Inference):
-    def __init__(self, url: str) -> None:
-        self.url = url.rstrip("/")
+class TGIAdapter(Inference):
+    def __init__(self, config: TGIImplConfig) -> None:
+        self.config = config
         self.tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(self.tokenizer)
-        self.model = None
-        self.max_tokens = None
+
+    @property
+    def client(self) -> InferenceClient:
+        return InferenceClient(base_url=self.config.url, token=self.config.api_token)
 
     async def initialize(self) -> None:
-        hf_models = {v: k for k, v in SUPPORTED_MODELS.items()}
-        try:
-            print(f"Connecting to TGI server at: {self.url}")
-            async with httpx.AsyncClient() as client:
-                response = await client.get(f"{self.url}/info")
-                response.raise_for_status()
-                info = response.json()
-                if "model_id" not in info:
-                    raise RuntimeError("Missing model_id in model info")
-                if "max_total_tokens" not in info:
-                    raise RuntimeError("Missing max_total_tokens in model info")
-                self.max_tokens = info["max_total_tokens"]
-                model_id = info["model_id"]
-                if model_id not in hf_models:
-                    raise RuntimeError(
-                        f"TGI is serving model: {model_id}, use one of the supported models: {','.join(hf_models.keys())}"
-                    )
-                self.model = hf_models[model_id]
-        except Exception as e:
-            import traceback
-
-            traceback.print_exc()
-            raise RuntimeError("Could not connect to TGI server") from e
+        pass
 
     async def shutdown(self) -> None:
         pass
@@ -68,15 +49,25 @@ class TGIInferenceAdapter(Inference):
     async def completion(self, request: CompletionRequest) -> AsyncGenerator:
         raise NotImplementedError()
 
-    def _convert_messages(self, messages: List[Message]) -> List[Message]:
-        ret = []
+    def _convert_messages(self, messages: list[Message]) -> List[Message]:  # type: ignore
+        tgi_messages = []
         for message in messages:
             if message.role == "ipython":
                 role = "tool"
             else:
                 role = message.role
-            ret.append({"role": role, "content": message.content})
-        return ret
+            tgi_messages.append({"role": role, "content": message.content})
+
+        return tgi_messages
+
+    def resolve_hf_model(self, model_name: str) -> str:
+        model = resolve_model(model_name)
+        assert (
+            model is not None
+            and model.descriptor(shorten_default_variant=True) in HF_SUPPORTED_MODELS
+        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(HF_SUPPORTED_MODELS.keys())}"
+
+        return HF_SUPPORTED_MODELS.get(model.descriptor(shorten_default_variant=True))
 
     def get_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
@@ -88,48 +79,34 @@ class TGIInferenceAdapter(Inference):
         return options
 
     async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
-        messages = prepare_messages(request)
-        model_input = self.formatter.encode_dialog_prompt(messages)
-        prompt = self.tokenizer.decode(model_input.tokens)
-
-        max_new_tokens = min(
-            request.sampling_params.max_tokens or self.max_tokens,
-            self.max_tokens - len(model_input.tokens) - 1,
-        )
-
-        if request.model != self.model:
-            raise ValueError(
-                f"Model mismatch, expected: {self.model}, got: {request.model}"
-            )
-
         options = self.get_chat_options(request)
-
-        client = Client(base_url=self.url)
+        messages = self._convert_messages(request.messages)
+
         if not request.stream:
-            r = client.generate(
-                prompt,
-                max_new_tokens=max_new_tokens,
-                stop_sequences=["<|eom_id|>", "<|eot_id|>"],
+            response = self.client.chat_completion(
+                messages=messages,
+                stream=False,
                 **options,
             )
-
-            stop_reason = None
-            if r.details.finish_reason:
-                if r.details.finish_reason == "stop":
+            if response.choices[0].finish_reason:
+                if (
+                    response.choices[0].finish_reason == "stop_sequence"
+                    or response.choices[0].finish_reason == "eos_token"
+                ):
                     stop_reason = StopReason.end_of_turn
-                elif r.details.finish_reason == "length":
+                elif response.choices[0].finish_reason == "length":
                     stop_reason = StopReason.out_of_tokens
+                else:
+                    stop_reason = StopReason.end_of_message
+            else:
+                stop_reason = StopReason.out_of_tokens
 
             completion_message = self.formatter.decode_assistant_message_from_content(
-                r.generated_text, stop_reason
+                response.choices[0].message.content,
+                stop_reason,
             )
             yield ChatCompletionResponse(
                 completion_message=completion_message,
                 logprobs=None,
             )
         else:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
@@ -137,24 +114,35 @@ class TGIInferenceAdapter(Inference):
                     delta="",
                 )
             )
 
             buffer = ""
             ipython = False
             stop_reason = None
-            tokens = []
 
-            for response in client.generate_stream(
-                prompt,
-                max_new_tokens=max_new_tokens,
-                stop_sequences=["<|eom_id|>", "<|eot_id|>"],
-                **options,
+            for chunk in self.client.chat_completion(
+                messages=messages, stream=True, **options
             ):
-                token_result = response.token
+                if chunk.choices[0].finish_reason:
+                    if (
+                        stop_reason is None
+                        and chunk.choices[0].finish_reason == "stop_sequence"
+                    ) or (
+                        stop_reason is None
+                        and chunk.choices[0].finish_reason == "eos_token"
+                    ):
+                        stop_reason = StopReason.end_of_turn
+                    elif (
+                        stop_reason is None
+                        and chunk.choices[0].finish_reason == "length"
+                    ):
+                        stop_reason = StopReason.out_of_tokens
+                    break
 
-                buffer += token_result.text
-                tokens.append(token_result.id)
+                text = chunk.choices[0].delta.content
+                if text is None:
+                    continue
 
-                if not ipython and buffer.startswith("<|python_tag|>"):
+                # check if its a tool call ( aka starts with <|python_tag|> )
+                if not ipython and text.startswith("<|python_tag|>"):
                     ipython = True
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
@@ -165,27 +153,25 @@ class TGIInferenceAdapter(Inference):
                             ),
                         )
                     )
-                    buffer = buffer[len("<|python_tag|>") :]
+                    buffer += text
                     continue
 
-                if token_result.text == "<|eot_id|>":
-                    stop_reason = StopReason.end_of_turn
-                    text = ""
-                elif token_result.text == "<|eom_id|>":
-                    stop_reason = StopReason.end_of_message
-                    text = ""
-                else:
-                    text = token_result.text
-
                 if ipython:
+                    if text == "<|eot_id|>":
+                        stop_reason = StopReason.end_of_turn
+                        text = ""
+                        continue
+                    elif text == "<|eom_id|>":
+                        stop_reason = StopReason.end_of_message
+                        text = ""
+                        continue
+
+                    buffer += text
                     delta = ToolCallDelta(
                         content=text,
                         parse_status=ToolCallParseStatus.in_progress,
                     )
-                else:
-                    delta = text
 
-                if stop_reason is None:
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
@@ -193,12 +179,20 @@ class TGIInferenceAdapter(Inference):
                             stop_reason=stop_reason,
                         )
                     )
+                else:
+                    buffer += text
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=text,
+                            stop_reason=stop_reason,
+                        )
+                    )
 
-            if stop_reason is None:
-                stop_reason = StopReason.out_of_tokens
-
             # parse tool calls and report errors
-            message = self.formatter.decode_assistant_message(tokens, stop_reason)
+            message = self.formatter.decode_assistant_message_from_content(
+                buffer, stop_reason
+            )
             parsed_tool_calls = len(message.tool_calls) > 0
             if ipython and not parsed_tool_calls:
                 yield ChatCompletionResponseStreamChunk(
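For reference, the `huggingface_hub` call pattern the adapter now builds on looks roughly like this. It is a sketch only: the base URL, token, message content, and `max_tokens` value are placeholders, and a running TGI server (or Inference Endpoint) is assumed:

```python
from huggingface_hub import InferenceClient

# Placeholders: point base_url at a local TGI server or a dedicated Inference Endpoint.
client = InferenceClient(base_url="http://localhost:8080", token=None)
messages = [{"role": "user", "content": "Hello!"}]

# Non-streaming, mirroring the adapter's `not request.stream` branch.
response = client.chat_completion(messages=messages, stream=False, max_tokens=64)
print(response.choices[0].message.content, response.choices[0].finish_reason)

# Streaming, mirroring the adapter's loop: chunks carry delta.content and finish_reason.
for chunk in client.chat_completion(messages=messages, stream=True, max_tokens=64):
    if chunk.choices[0].finish_reason:
        break
    print(chunk.choices[0].delta.content or "", end="")
```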

View file

@@ -39,8 +39,9 @@ def available_inference_providers() -> List[ProviderSpec]:
             api=Api.inference,
             adapter=AdapterSpec(
                 adapter_id="tgi",
-                pip_packages=["text-generation"],
+                pip_packages=["huggingface_hub"],
                 module="llama_toolchain.inference.adapters.tgi",
+                config_class="llama_toolchain.inference.adapters.tgi.TGIImplConfig",
             ),
         ),
         remote_provider_spec(
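Since the adapter's pip dependency is now `huggingface_hub` rather than `text-generation`, environments built for the TGI provider would presumably pull it in via something like `pip install huggingface_hub`.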