From 046afcb94568467f77763a0079bbf4e37949454e Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 4 Sep 2024 17:36:45 -0700
Subject: [PATCH] Use the lower-level `generate_stream()` method for correct tool calling

---
 llama_toolchain/inference/adapters/tgi/tgi.py | 148 ++++++++++--------
 llama_toolchain/inference/providers.py        |   2 +-
 2 files changed, 80 insertions(+), 70 deletions(-)

diff --git a/llama_toolchain/inference/adapters/tgi/tgi.py b/llama_toolchain/inference/adapters/tgi/tgi.py
index 9c90d8ef4..7eb36ac36 100644
--- a/llama_toolchain/inference/adapters/tgi/tgi.py
+++ b/llama_toolchain/inference/adapters/tgi/tgi.py
@@ -8,14 +8,15 @@
 from typing import AsyncGenerator, List
 
 import httpx
 
-from huggingface_hub import InferenceClient
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 
+from text_generation import Client
+
 from llama_toolchain.inference.api import *  # noqa: F403
+from llama_toolchain.inference.prepare_messages import prepare_messages
 
 
 SUPPORTED_MODELS = {
@@ -28,26 +29,38 @@ SUPPORTED_MODELS = {
 class TGIInferenceAdapter(Inference):
     def __init__(self, url: str) -> None:
         self.url = url.rstrip("/")
-        tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(tokenizer)
+        self.tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(self.tokenizer)
         self.model = None
+        self.max_tokens = None
 
     async def initialize(self) -> None:
         hf_models = {v: k for k, v in SUPPORTED_MODELS.items()}
 
-        async with httpx.AsyncClient() as client:
-            response = await client.get(f"{self.url}/info")
-            response.raise_for_status()
-            info = response.json()
-            if "model_id" not in info:
-                raise RuntimeError("Missing model_id in model info")
-            model_id = info["model_id"]
-            if model_id not in hf_models:
-                raise RuntimeError(
-                    f"TGI is serving model: {model_id}, use one of the supported models: {','.join(hf_models.keys())}"
-                )
+        try:
+            print(f"Connecting to TGI server at: {self.url}")
+            async with httpx.AsyncClient() as client:
+                response = await client.get(f"{self.url}/info")
+                response.raise_for_status()
+                info = response.json()
+                if "model_id" not in info:
+                    raise RuntimeError("Missing model_id in model info")
+                if "max_total_tokens" not in info:
+                    raise RuntimeError("Missing max_total_tokens in model info")
+                self.max_tokens = info["max_total_tokens"]
 
-        self.model = hf_models[model_id]
+                model_id = info["model_id"]
+                if model_id not in hf_models:
+                    raise RuntimeError(
+                        f"TGI is serving model: {model_id}, use one of the supported models: {','.join(hf_models.keys())}"
+                    )
+
+                self.model = hf_models[model_id]
+        except Exception as e:
+            import traceback
+
+            traceback.print_exc()
+            raise RuntimeError("Could not connect to TGI server") from e
 
     async def shutdown(self) -> None:
         pass
@@ -75,6 +88,15 @@ class TGIInferenceAdapter(Inference):
         return options
 
     async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+        messages = prepare_messages(request)
+
+        model_input = self.formatter.encode_dialog_prompt(messages)
+        prompt = self.tokenizer.decode(model_input.tokens)
+        max_new_tokens = min(
+            request.sampling_params.max_tokens or self.max_tokens,
+            self.max_tokens - len(model_input.tokens) - 1,
+        )
+
         if request.model != self.model:
             raise ValueError(
                 f"Model mismatch, expected: {self.model}, got: {request.model}"
@@ -82,23 +104,27 @@
 
         options = self.get_chat_options(request)
 
-        client = InferenceClient(base_url=self.url)
+        client = Client(base_url=self.url)
         if not request.stream:
-            r = client.chat.completions.create(
-                model=SUPPORTED_MODELS[self.model],
-                messages=self._convert_messages(request.messages),
-                stream=False,
+            r = client.generate(
+                prompt,
+                max_new_tokens=max_new_tokens,
+                stop_sequences=["<|eom_id|>", "<|eot_id|>"],
                 **options,
             )
-            stop_reason = None
-            if r.choices[0].finish_reason:
-                if r.choices[0].finish_reason == "stop":
+
+            if r.details.finish_reason:
+                if r.details.finish_reason == "stop":
                     stop_reason = StopReason.end_of_turn
-                elif r.choices[0].finish_reason == "length":
+                elif r.details.finish_reason == "length":
                     stop_reason = StopReason.out_of_tokens
+                else:
+                    stop_reason = StopReason.end_of_message
+            else:
+                stop_reason = StopReason.out_of_tokens
 
             completion_message = self.formatter.decode_assistant_message_from_content(
-                r.choices[0].message.content, stop_reason
+                r.generated_text, stop_reason
             )
             yield ChatCompletionResponse(
                 completion_message=completion_message,
@@ -115,30 +141,20 @@
             buffer = ""
             ipython = False
             stop_reason = None
+            tokens = []
 
-            response = client.chat.completions.create(
-                model=SUPPORTED_MODELS[self.model],
-                messages=self._convert_messages(request.messages),
-                stream=True,
+            for response in client.generate_stream(
+                prompt,
+                max_new_tokens=max_new_tokens,
+                stop_sequences=["<|eom_id|>", "<|eot_id|>"],
                 **options,
-            )
-            for chunk in response:
-                if chunk.choices[0].finish_reason:
-                    if stop_reason is None and chunk.choices[0].finish_reason == "stop":
-                        stop_reason = StopReason.end_of_turn
-                    elif (
-                        stop_reason is None
-                        and chunk.choices[0].finish_reason == "length"
-                    ):
-                        stop_reason = StopReason.out_of_tokens
-                    break
+            ):
+                token_result = response.token
 
-                text = chunk.choices[0].delta.content
-                if text is None:
-                    continue
+                buffer += token_result.text
+                tokens.append(token_result.id)
 
-                # check if its a tool call ( aka starts with <|python_tag|> )
-                if not ipython and text.startswith("<|python_tag|>"):
+                if not ipython and buffer.startswith("<|python_tag|>"):
                     ipython = True
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
@@ -149,25 +165,27 @@
                             ),
                         )
                     )
-                    buffer += text
+                    buffer = buffer[len("<|python_tag|>") :]
                     continue
 
-                if ipython:
-                    if text == "<|eot_id|>":
-                        stop_reason = StopReason.end_of_turn
-                        text = ""
-                        continue
-                    elif text == "<|eom_id|>":
-                        stop_reason = StopReason.end_of_message
-                        text = ""
-                        continue
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                    text = ""
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+                    text = ""
+                else:
+                    text = token_result.text
 
-                    buffer += text
+                if ipython:
                     delta = ToolCallDelta(
                         content=text,
                         parse_status=ToolCallParseStatus.in_progress,
                     )
-
+                else:
+                    delta = text
+
+                if stop_reason is None:
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
@@ -175,20 +193,12 @@
                             stop_reason=stop_reason,
                         )
                     )
-                else:
-                    buffer += text
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=text,
-                            stop_reason=stop_reason,
-                        )
-                    )
+
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
 
             # parse tool calls and report errors
-            message = self.formatter.decode_assistant_message_from_content(
-                buffer, stop_reason
-            )
+            message = self.formatter.decode_assistant_message(tokens, stop_reason)
             parsed_tool_calls = len(message.tool_calls) > 0
             if ipython and not parsed_tool_calls:
                 yield ChatCompletionResponseStreamChunk(
diff --git a/llama_toolchain/inference/providers.py b/llama_toolchain/inference/providers.py
index ae2a4bd16..b469cb29b 100644
--- a/llama_toolchain/inference/providers.py
+++ b/llama_toolchain/inference/providers.py
@@ -39,7 +39,7 @@ def available_inference_providers() -> List[ProviderSpec]:
             api=Api.inference,
             adapter=AdapterSpec(
                 adapter_id="tgi",
-                pip_packages=["huggingface-hub"],
+                pip_packages=["text-generation"],
                 module="llama_toolchain.inference.adapters.tgi",
             ),
         ),
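
The sketch below, which is not part of the patch, walks the same path the new adapter code takes: build the raw Llama 3 prompt with ChatFormat, stream tokens from TGI through the lower-level text_generation client with <|eom_id|>/<|eot_id|> as stop sequences, and decode the assistant message from the accumulated token ids so a tool call emitted after <|python_tag|> can be parsed. The server URL, the example UserMessage, and the max_new_tokens value are illustrative assumptions rather than values taken from the commit; the library calls themselves are the ones the patch uses.

# Minimal end-to-end sketch of the generate_stream() flow (assumed setup:
# a local TGI server at http://localhost:8080 serving a supported Llama 3.1 model).
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason, UserMessage
from llama_models.llama3.api.tokenizer import Tokenizer
from text_generation import Client

tokenizer = Tokenizer.get_instance()
formatter = ChatFormat(tokenizer)

# Encode the dialog with the Llama 3 chat template, then decode back to text so
# TGI generates from the exact prompt, special tokens included.
messages = [UserMessage(content="What is 2 + 2? Use the calculator tool.")]
model_input = formatter.encode_dialog_prompt(messages)
prompt = tokenizer.decode(model_input.tokens)

client = Client(base_url="http://localhost:8080")  # hypothetical endpoint
tokens = []
stop_reason = None

for response in client.generate_stream(
    prompt,
    max_new_tokens=512,  # illustrative; the adapter derives this from /info max_total_tokens
    stop_sequences=["<|eom_id|>", "<|eot_id|>"],
):
    token = response.token
    tokens.append(token.id)
    if token.text == "<|eot_id|>":
        stop_reason = StopReason.end_of_turn
    elif token.text == "<|eom_id|>":
        stop_reason = StopReason.end_of_message

if stop_reason is None:
    stop_reason = StopReason.out_of_tokens

# Decoding from token ids (not accumulated text) is what lets the formatter
# recover tool calls that follow <|python_tag|>.
message = formatter.decode_assistant_message(tokens, stop_reason)
print(message.content, message.tool_calls)

Compared with the chat-completions endpoint used before, streaming raw tokens preserves both the token ids and the special tokens, which is exactly what decode_assistant_message() needs to detect tool calls.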