From e5bcfdac21da7796ee23fb4e5b19960766dd708b Mon Sep 17 00:00:00 2001
From: Celina Hanouti
Date: Thu, 5 Sep 2024 18:29:04 +0200
Subject: [PATCH] Use huggingface_hub inference client for TGI inference

---
 docs/cli_reference.md                         |  83 ++++----
 llama_toolchain/core/distribution_registry.py |  10 +
 .../inference/adapters/tgi/__init__.py        |  11 +-
 .../inference/adapters/tgi/config.py          |  22 ++
 llama_toolchain/inference/adapters/tgi/tgi.py | 192 +++++++++---------
 llama_toolchain/inference/providers.py        |   3 +-
 6 files changed, 179 insertions(+), 142 deletions(-)
 create mode 100644 llama_toolchain/inference/adapters/tgi/config.py

diff --git a/docs/cli_reference.md b/docs/cli_reference.md
index d46cf722a..68942d552 100644
--- a/docs/cli_reference.md
+++ b/docs/cli_reference.md
@@ -248,44 +248,51 @@ llama stack list-distributions
 ```
-+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
-| Distribution ID                | Providers                             | Description                                                          |
-+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
-| local                          | {                                     | Use code from `llama_toolchain` itself to serve all llama stack APIs |
-|                                |   "inference": "meta-reference",      |                                                                      |
-|                                |   "memory": "meta-reference-faiss",   |                                                                      |
-|                                |   "safety": "meta-reference",         |                                                                      |
-|                                |   "agentic_system": "meta-reference"  |                                                                      |
-|                                | }                                     |                                                                      |
-+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
-| remote                         | {                                     | Point to remote services for all llama stack APIs                    |
-|                                |   "inference": "remote",              |                                                                      |
-|                                |   "safety": "remote",                 |                                                                      |
-|                                |   "agentic_system": "remote",         |                                                                      |
-|                                |   "memory": "remote"                  |                                                                      |
-|                                | }                                     |                                                                      |
-+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
-| local-ollama                   | {                                     | Like local, but use ollama for running LLM inference                 |
-|                                |   "inference": "remote::ollama",      |                                                                      |
-|                                |   "safety": "meta-reference",         |                                                                      |
-|                                |   "agentic_system": "meta-reference", |                                                                      |
-|                                |   "memory": "meta-reference-faiss"    |                                                                      |
-|                                | }                                     |                                                                      |
-+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
-| local-plus-fireworks-inference | {                                     | Use Fireworks.ai for running LLM inference                           |
-|                                |   "inference": "remote::fireworks",   |                                                                      |
-|                                |   "safety": "meta-reference",         |                                                                      |
-|                                |   "agentic_system": "meta-reference", |                                                                      |
-|                                |   "memory": "meta-reference-faiss"    |                                                                      |
-|                                | }                                     |                                                                      |
-+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
-| local-plus-together-inference  | {                                     | Use Together.ai for running LLM inference                            |
-|                                |   "inference": "remote::together",    |                                                                      |
-|                                |   "safety": "meta-reference",         |                                                                      |
-|                                |   "agentic_system": "meta-reference", |                                                                      |
-|                                |   "memory": "meta-reference-faiss"    |                                                                      |
-|                                | }                                     |                                                                      |
-+--------------------------------+---------------------------------------+----------------------------------------------------------------------+
++--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
+| Distribution ID                | Providers                             | Description                                                                               |
++--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
+| local                          | {                                     | Use code from `llama_toolchain` itself to serve all llama stack APIs                      |
+|                                |   "inference": "meta-reference",      |                                                                                           |
+|                                |   "memory": "meta-reference-faiss",   |                                                                                           |
+|                                |   "safety": "meta-reference",         |                                                                                           |
+|                                |   "agentic_system": "meta-reference"  |                                                                                           |
+|                                | }                                     |                                                                                           |
++--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
+| remote                         | {                                     | Point to remote services for all llama stack APIs                                         |
+|                                |   "inference": "remote",              |                                                                                           |
+|                                |   "safety": "remote",                 |                                                                                           |
+|                                |   "agentic_system": "remote",         |                                                                                           |
+|                                |   "memory": "remote"                  |                                                                                           |
+|                                | }                                     |                                                                                           |
++--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
+| local-ollama                   | {                                     | Like local, but use ollama for running LLM inference                                      |
+|                                |   "inference": "remote::ollama",      |                                                                                           |
+|                                |   "safety": "meta-reference",         |                                                                                           |
+|                                |   "agentic_system": "meta-reference", |                                                                                           |
+|                                |   "memory": "meta-reference-faiss"    |                                                                                           |
+|                                | }                                     |                                                                                           |
++--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
+| local-plus-fireworks-inference | {                                     | Use Fireworks.ai for running LLM inference                                                |
+|                                |   "inference": "remote::fireworks",   |                                                                                           |
+|                                |   "safety": "meta-reference",         |                                                                                           |
+|                                |   "agentic_system": "meta-reference", |                                                                                           |
+|                                |   "memory": "meta-reference-faiss"    |                                                                                           |
+|                                | }                                     |                                                                                           |
++--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
+| local-plus-together-inference  | {                                     | Use Together.ai for running LLM inference                                                 |
+|                                |   "inference": "remote::together",    |                                                                                           |
+|                                |   "safety": "meta-reference",         |                                                                                           |
+|                                |   "agentic_system": "meta-reference", |                                                                                           |
+|                                |   "memory": "meta-reference-faiss"    |                                                                                           |
+|                                | }                                     |                                                                                           |
++--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
+| local-plus-tgi-inference       | {                                     | Use TGI (local or with Hugging Face Inference Endpoints) for running LLM inference        |
+|                                |   "inference": "remote::tgi",         |                                                                                           |
+|                                |   "safety": "meta-reference",         |                                                                                           |
+|                                |   "agentic_system": "meta-reference", |                                                                                           |
+|                                |   "memory": "meta-reference-faiss"    |                                                                                           |
+|                                | }                                     |                                                                                           |
++--------------------------------+---------------------------------------+-------------------------------------------------------------------------------------------+
 
 ```

 As you can see above, each “distribution” details the “providers” it is composed of. For example, `local` uses the “meta-reference” provider for inference while local-ollama relies on a different provider (Ollama) for inference. Similarly, you can use Fireworks or Together.AI for running inference as well.
diff --git a/llama_toolchain/core/distribution_registry.py b/llama_toolchain/core/distribution_registry.py
index e134fdab6..9413e1374 100644
--- a/llama_toolchain/core/distribution_registry.py
+++ b/llama_toolchain/core/distribution_registry.py
@@ -58,6 +58,16 @@ def available_distribution_specs() -> List[DistributionSpec]:
                 Api.memory: "meta-reference-faiss",
             },
         ),
+        DistributionSpec(
+            distribution_id="local-plus-tgi-inference",
+            description="Use TGI for running LLM inference",
+            providers={
+                Api.inference: remote_provider_id("tgi"),
+                Api.safety: "meta-reference",
+                Api.agentic_system: "meta-reference",
+                Api.memory: "meta-reference-faiss",
+            },
+        ),
     ]
diff --git a/llama_toolchain/inference/adapters/tgi/__init__.py b/llama_toolchain/inference/adapters/tgi/__init__.py
index 4940667b4..86faf94cb 100644
--- a/llama_toolchain/inference/adapters/tgi/__init__.py
+++ b/llama_toolchain/inference/adapters/tgi/__init__.py
@@ -4,12 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_toolchain.core.datatypes import RemoteProviderConfig
+from .config import TGIImplConfig
 
 
-async def get_adapter_impl(config: RemoteProviderConfig, _deps):
-    from .tgi import TGIInferenceAdapter
+async def get_adapter_impl(config: TGIImplConfig, _deps):
+    from .tgi import TGIAdapter
 
-    impl = TGIInferenceAdapter(config.url)
+    assert isinstance(
+        config, TGIImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = TGIAdapter(config)
     await impl.initialize()
     return impl
diff --git a/llama_toolchain/inference/adapters/tgi/config.py b/llama_toolchain/inference/adapters/tgi/config.py
new file mode 100644
index 000000000..fe969f4e6
--- /dev/null
+++ b/llama_toolchain/inference/adapters/tgi/config.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Optional
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field, field_validator
+
+
+@json_schema_type
+class TGIImplConfig(BaseModel):
+    url: str = Field(
+        default="https://api-inference.huggingface.co",
+        description="The URL for the TGI endpoint",
+    )
+    api_token: Optional[str] = Field(
+        default="",
+        description="The HF token for Hugging Face Inference Endpoints",
+    )
diff --git a/llama_toolchain/inference/adapters/tgi/tgi.py b/llama_toolchain/inference/adapters/tgi/tgi.py
index 7eb36ac36..849f67e69 100644
--- a/llama_toolchain/inference/adapters/tgi/tgi.py
+++ b/llama_toolchain/inference/adapters/tgi/tgi.py
@@ -4,63 +4,44 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, List
-import httpx
+from typing import AsyncGenerator
 
+from huggingface_hub import InferenceClient
 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.sku_list import resolve_model
 
-from text_generation import Client
+from llama_toolchain.inference.api import *
+from llama_toolchain.inference.api.api import (  # noqa: F403
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseStreamChunk,
+)
 
-from llama_toolchain.inference.api import *  # noqa: F403
-from llama_toolchain.inference.prepare_messages import prepare_messages
+from .config import TGIImplConfig
 
-
-SUPPORTED_MODELS = {
+HF_SUPPORTED_MODELS = {
     "Meta-Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
     "Meta-Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
     "Meta-Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
 }
 
 
-class TGIInferenceAdapter(Inference):
-    def __init__(self, url: str) -> None:
-        self.url = url.rstrip("/")
+class TGIAdapter(Inference):
+
+    def __init__(self, config: TGIImplConfig) -> None:
+        self.config = config
         self.tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(self.tokenizer)
-        self.model = None
-        self.max_tokens = None
+
+    @property
+    def client(self) -> InferenceClient:
+        return InferenceClient(base_url=self.config.url, token=self.config.api_token)
 
     async def initialize(self) -> None:
-        hf_models = {v: k for k, v in SUPPORTED_MODELS.items()}
-
-        try:
-            print(f"Connecting to TGI server at: {self.url}")
-            async with httpx.AsyncClient() as client:
-                response = await client.get(f"{self.url}/info")
-                response.raise_for_status()
-                info = response.json()
-                if "model_id" not in info:
-                    raise RuntimeError("Missing model_id in model info")
-                if "max_total_tokens" not in info:
-                    raise RuntimeError("Missing max_total_tokens in model info")
-                self.max_tokens = info["max_total_tokens"]
-
-                model_id = info["model_id"]
-                if model_id not in hf_models:
-                    raise RuntimeError(
-                        f"TGI is serving model: {model_id}, use one of the supported models: {','.join(hf_models.keys())}"
-                    )
-
-                self.model = hf_models[model_id]
-        except Exception as e:
-            import traceback
-
-            traceback.print_exc()
-            raise RuntimeError("Could not connect to TGI server") from e
+        pass
 
     async def shutdown(self) -> None:
         pass
@@ -68,15 +49,25 @@ class TGIInferenceAdapter(Inference):
     async def completion(self, request: CompletionRequest) -> AsyncGenerator:
         raise NotImplementedError()
 
-    def _convert_messages(self, messages: List[Message]) -> List[Message]:
-        ret = []
+    def _convert_messages(self, messages: list[Message]) -> List[Message]:  # type: ignore
+        tgi_messages = []
         for message in messages:
             if message.role == "ipython":
                 role = "tool"
             else:
                 role = message.role
-            ret.append({"role": role, "content": message.content})
-        return ret
+            tgi_messages.append({"role": role, "content": message.content})
+
+        return tgi_messages
+
+    def resolve_hf_model(self, model_name: str) -> str:
+        model = resolve_model(model_name)
+        assert (
+            model is not None
+            and model.descriptor(shorten_default_variant=True) in HF_SUPPORTED_MODELS
+        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(HF_SUPPORTED_MODELS.keys())}"
+
+        return HF_SUPPORTED_MODELS.get(model.descriptor(shorten_default_variant=True))
 
     def get_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
@@ -88,48 +79,34 @@ class TGIInferenceAdapter(Inference):
         return options
 
     async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
-        messages = prepare_messages(request)
-
-        model_input = self.formatter.encode_dialog_prompt(messages)
-        prompt = self.tokenizer.decode(model_input.tokens)
-        max_new_tokens = min(
-            request.sampling_params.max_tokens or self.max_tokens,
-            self.max_tokens - len(model_input.tokens) - 1,
-        )
-
-        if request.model != self.model:
-            raise ValueError(
-                f"Model mismatch, expected: {self.model}, got: {request.model}"
-            )
-
         options = self.get_chat_options(request)
+        messages = self._convert_messages(request.messages)
 
-        client = Client(base_url=self.url)
         if not request.stream:
-            r = client.generate(
-                prompt,
-                max_new_tokens=max_new_tokens,
-                stop_sequences=["<|eom_id|>", "<|eot_id|>"],
+            response = self.client.chat_completion(
+                messages=messages,
+                stream=False,
                 **options,
             )
-
-            if r.details.finish_reason:
-                if r.details.finish_reason == "stop":
+            stop_reason = None
+            if response.choices[0].finish_reason:
+                if (
+                    response.choices[0].finish_reason == "stop_sequence"
+                    or response.choices[0].finish_reason == "eos_token"
+                ):
                     stop_reason = StopReason.end_of_turn
-                elif r.details.finish_reason == "length":
+                elif response.choices[0].finish_reason == "length":
                     stop_reason = StopReason.out_of_tokens
-                else:
-                    stop_reason = StopReason.end_of_message
-            else:
-                stop_reason = StopReason.out_of_tokens
 
             completion_message = self.formatter.decode_assistant_message_from_content(
-                r.generated_text, stop_reason
+                response.choices[0].message.content,
+                stop_reason,
             )
             yield ChatCompletionResponse(
                 completion_message=completion_message,
                 logprobs=None,
             )
+        else:
 
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
@@ -137,24 +114,35 @@ class TGIInferenceAdapter(Inference):
                     delta="",
                 )
             )
-
             buffer = ""
             ipython = False
             stop_reason = None
-            tokens = []
 
-            for response in client.generate_stream(
-                prompt,
-                max_new_tokens=max_new_tokens,
-                stop_sequences=["<|eom_id|>", "<|eot_id|>"],
-                **options,
+            for chunk in self.client.chat_completion(
+                messages=messages, stream=True, **options
             ):
-                token_result = response.token
+                if chunk.choices[0].finish_reason:
+                    if (
+                        stop_reason is None
+                        and chunk.choices[0].finish_reason == "stop_sequence"
+                    ) or (
+                        stop_reason is None
+                        and chunk.choices[0].finish_reason == "eos_token"
+                    ):
+                        stop_reason = StopReason.end_of_turn
+                    elif (
+                        stop_reason is None
+                        and chunk.choices[0].finish_reason == "length"
+                    ):
+                        stop_reason = StopReason.out_of_tokens
+                    break
 
-                buffer += token_result.text
-                tokens.append(token_result.id)
+                text = chunk.choices[0].delta.content
+                if text is None:
+                    continue
 
-                if not ipython and buffer.startswith("<|python_tag|>"):
+                # check if its a tool call ( aka starts with <|python_tag|> )
+                if not ipython and text.startswith("<|python_tag|>"):
                     ipython = True
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
@@ -165,27 +153,25 @@ class TGIInferenceAdapter(Inference):
                             ),
                         )
                     )
-                    buffer = buffer[len("<|python_tag|>") :]
+                    buffer += text
                     continue
 
-                if token_result.text == "<|eot_id|>":
-                    stop_reason = StopReason.end_of_turn
-                    text = ""
-                elif token_result.text == "<|eom_id|>":
-                    stop_reason = StopReason.end_of_message
-                    text = ""
-                else:
-                    text = token_result.text
-
-                if ipython:
+                if text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                    text = ""
+                    continue
+                elif text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+                    text = ""
+                    continue
+
+                if ipython:
+                    buffer += text
                     delta = ToolCallDelta(
                         content=text,
                         parse_status=ToolCallParseStatus.in_progress,
                     )
-                else:
-                    delta = text
-
-                if stop_reason is None:
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
@@ -193,12 +179,20 @@ class TGIInferenceAdapter(Inference):
                             stop_reason=stop_reason,
                         )
                     )
-
-            if stop_reason is None:
-                stop_reason = StopReason.out_of_tokens
+                else:
+                    buffer += text
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=text,
+                            stop_reason=stop_reason,
+                        )
+                    )
 
             # parse tool calls and report errors
-            message = self.formatter.decode_assistant_message(tokens, stop_reason)
+            message = self.formatter.decode_assistant_message_from_content(
+                buffer, stop_reason
+            )
             parsed_tool_calls = len(message.tool_calls) > 0
             if ipython and not parsed_tool_calls:
                 yield ChatCompletionResponseStreamChunk(
diff --git a/llama_toolchain/inference/providers.py b/llama_toolchain/inference/providers.py
index b469cb29b..e9f6b4072 100644
--- a/llama_toolchain/inference/providers.py
+++ b/llama_toolchain/inference/providers.py
@@ -39,8 +39,9 @@ def available_inference_providers() -> List[ProviderSpec]:
             api=Api.inference,
             adapter=AdapterSpec(
                 adapter_id="tgi",
-                pip_packages=["text-generation"],
+                pip_packages=["huggingface_hub"],
                 module="llama_toolchain.inference.adapters.tgi",
+                config_class="llama_toolchain.inference.adapters.tgi.TGIImplConfig",
             ),
         ),
         remote_provider_spec(
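
For readers who want to exercise the endpoint this patch targets, the sketch below uses the same `huggingface_hub` calls the new `TGIAdapter` is built on: `InferenceClient(base_url=..., token=...)` plus `chat_completion` in non-streaming and streaming modes. It is a minimal illustration, not part of the patch: the local TGI URL, the `HF_API_TOKEN` environment variable, and the `max_tokens` value are placeholder assumptions; the adapter itself wires these through `TGIImplConfig` and `TGIAdapter.chat_completion` as shown above.

```python
# Minimal sketch of the huggingface_hub calls the new TGIAdapter relies on.
# Assumptions: a TGI server (or Hugging Face Inference Endpoint) is reachable
# at TGI_URL, and HF_API_TOKEN is set if the endpoint requires authentication.
import os

from huggingface_hub import InferenceClient

TGI_URL = "http://localhost:8080"  # placeholder: default local TGI port

client = InferenceClient(base_url=TGI_URL, token=os.environ.get("HF_API_TOKEN"))

messages = [{"role": "user", "content": "Hello!"}]

# Non-streaming call, mirroring the adapter's request.stream=False branch.
response = client.chat_completion(messages=messages, stream=False, max_tokens=256)
print(response.choices[0].finish_reason)
print(response.choices[0].message.content)

# Streaming call, mirroring the stream=True branch that accumulates delta.content.
for chunk in client.chat_completion(messages=messages, stream=True, max_tokens=256):
    text = chunk.choices[0].delta.content
    if text:  # the final chunk carries finish_reason and may have no content
        print(text, end="", flush=True)
```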