diff --git a/llama_toolchain/inference/adapters/tgi/tgi.py b/llama_toolchain/inference/adapters/tgi/tgi.py
index 849f67e69..4d32dbe5f 100644
--- a/llama_toolchain/inference/adapters/tgi/tgi.py
+++ b/llama_toolchain/inference/adapters/tgi/tgi.py
@@ -9,9 +9,8 @@ from typing import AsyncGenerator
 
 from huggingface_hub import InferenceClient
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message, StopReason
+from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model
 
 from llama_toolchain.inference.api import *
 from llama_toolchain.inference.api.api import (  # noqa: F403
@@ -19,6 +18,7 @@ from llama_toolchain.inference.api.api import (  # noqa: F403
     ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
 )
+from llama_toolchain.inference.prepare_messages import prepare_messages
 
 from .config import TGIImplConfig
 
@@ -49,26 +49,6 @@ class TGIAdapter(Inference):
     async def completion(self, request: CompletionRequest) -> AsyncGenerator:
         raise NotImplementedError()
 
-    def _convert_messages(self, messages: list[Message]) -> List[Message]:  # type: ignore
-        tgi_messages = []
-        for message in messages:
-            if message.role == "ipython":
-                role = "tool"
-            else:
-                role = message.role
-            tgi_messages.append({"role": role, "content": message.content})
-
-        return tgi_messages
-
-    def resolve_hf_model(self, model_name: str) -> str:
-        model = resolve_model(model_name)
-        assert (
-            model is not None
-            and model.descriptor(shorten_default_variant=True) in HF_SUPPORTED_MODELS
-        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(HF_SUPPORTED_MODELS.keys())}"
-
-        return HF_SUPPORTED_MODELS.get(model.descriptor(shorten_default_variant=True))
-
     def get_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
         if request.sampling_params is not None:
@@ -79,27 +59,36 @@ class TGIAdapter(Inference):
         return options
 
     async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+        messages = prepare_messages(request)
+        model_input = self.formatter.encode_dialog_prompt(messages)
+        prompt = self.tokenizer.decode(model_input.tokens)
+
+        model_info = self.client.get_endpoint_info(model=self.config.url)
+        max_new_tokens = min(
+            request.sampling_params.max_tokens or model_info["max_total_tokens"],
+            model_info["max_total_tokens"] - len(model_input.tokens) - 1,
+        )
+
         options = self.get_chat_options(request)
-        messages = self._convert_messages(request.messages)
 
         if not request.stream:
-            response = self.client.chat_completion(
-                messages=messages,
+            response = self.client.text_generation(
+                prompt=prompt,
                 stream=False,
+                details=True,
+                max_new_tokens=max_new_tokens,
+                stop_sequences=["<|eom_id|>", "<|eot_id|>"],
                 **options,
             )
             stop_reason = None
-            if response.choices[0].finish_reason:
-                if (
-                    response.choices[0].finish_reason == "stop_sequence"
-                    or response.choices[0].finish_reason == "eos_token"
-                ):
+            if response.details.finish_reason:
+                if response.details.finish_reason == "stop":
                     stop_reason = StopReason.end_of_turn
-                elif response.choices[0].finish_reason == "length":
+                elif response.details.finish_reason == "length":
                     stop_reason = StopReason.out_of_tokens
 
             completion_message = self.formatter.decode_assistant_message_from_content(
-                response.choices[0].message.content,
+                response.generated_text,
                 stop_reason,
             )
             yield ChatCompletionResponse(
@@ -117,32 +106,22 @@ class TGIAdapter(Inference):
             buffer = ""
             ipython = False
             stop_reason = None
+            tokens = []
 
-            for chunk in self.client.chat_completion(
-                messages=messages, stream=True, **options
+            for response in self.client.text_generation(
+                prompt=prompt,
+                stream=True,
+                details=True,
+                max_new_tokens=max_new_tokens,
+                stop_sequences=["<|eom_id|>", "<|eot_id|>"],
+                **options,
             ):
-                if chunk.choices[0].finish_reason:
-                    if (
-                        stop_reason is None
-                        and chunk.choices[0].finish_reason == "stop_sequence"
-                    ) or (
-                        stop_reason is None
-                        and chunk.choices[0].finish_reason == "eos_token"
-                    ):
-                        stop_reason = StopReason.end_of_turn
-                    elif (
-                        stop_reason is None
-                        and chunk.choices[0].finish_reason == "length"
-                    ):
-                        stop_reason = StopReason.out_of_tokens
-                    break
+                token_result = response.token
 
-                text = chunk.choices[0].delta.content
-                if text is None:
-                    continue
+                buffer += token_result.text
+                tokens.append(token_result.id)
 
-                # check if its a tool call ( aka starts with <|python_tag|> )
-                if not ipython and text.startswith("<|python_tag|>"):
+                if not ipython and buffer.startswith("<|python_tag|>"):
                     ipython = True
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
@@ -153,25 +132,27 @@ class TGIAdapter(Inference):
                             ),
                         )
                     )
-                    buffer += text
+                    buffer = buffer[len("<|python_tag|>") :]
                     continue
 
-                if ipython:
-                    if text == "<|eot_id|>":
-                        stop_reason = StopReason.end_of_turn
-                        text = ""
-                        continue
-                    elif text == "<|eom_id|>":
-                        stop_reason = StopReason.end_of_message
-                        text = ""
-                        continue
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                    text = ""
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+                    text = ""
+                else:
+                    text = token_result.text
 
-                    buffer += text
+                if ipython:
                     delta = ToolCallDelta(
                         content=text,
                         parse_status=ToolCallParseStatus.in_progress,
                     )
+                else:
+                    delta = text
 
+                if stop_reason is None:
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
@@ -179,20 +160,12 @@ class TGIAdapter(Inference):
                             stop_reason=stop_reason,
                         )
                     )
-                else:
-                    buffer += text
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=text,
-                            stop_reason=stop_reason,
-                        )
-                    )
+
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
 
             # parse tool calls and report errors
-            message = self.formatter.decode_assistant_message_from_content(
-                buffer, stop_reason
-            )
+            message = self.formatter.decode_assistant_message(tokens, stop_reason)
             parsed_tool_calls = len(message.tool_calls) > 0
             if ipython and not parsed_tool_calls:
                 yield ChatCompletionResponseStreamChunk(