From a58dc684189d59623555599e6d5e8d4312fb2517 Mon Sep 17 00:00:00 2001
From: Giri Tatavarty
Date: Tue, 28 May 2024 07:54:11 -0700
Subject: [PATCH] Added support for Triton chat completion using the trtllm
 generate endpoint and the custom infer endpoint

---
 litellm/llms/triton.py | 149 +++++++++++++++++++++++++++++++++++++++--
 litellm/main.py        |  20 ++++++
 2 files changed, 165 insertions(+), 4 deletions(-)

diff --git a/litellm/llms/triton.py b/litellm/llms/triton.py
index 711186b3f..43220eec1 100644
--- a/litellm/llms/triton.py
+++ b/litellm/llms/triton.py
@@ -4,13 +4,13 @@ from enum import Enum
 import requests, copy  # type: ignore
 import time
 from typing import Callable, Optional, List
-from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
+from litellm.utils import ModelResponse, Choices, Usage, map_finish_reason, CustomStreamWrapper, Message
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from .base import BaseLLM
 import httpx  # type: ignore


 class TritonError(Exception):
     def __init__(self, status_code, message):
@@ -30,6 +30,49 @@ class TritonChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()

+    async def acompletion(
+        self,
+        data: dict,
+        model_response: ModelResponse,
+        api_base: str,
+        logging_obj=None,
+        api_key: Optional[str] = None,
+    ):
+        if api_base.endswith("generate"):  ### This is a trtllm model
+            async with httpx.AsyncClient(
+                timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+            ) as client:
+                response = await client.post(url=api_base, json=data)
+
+            if response.status_code != 200:
+                raise TritonError(
+                    status_code=response.status_code, message=response.text
+                )
+
+            if logging_obj:
+                logging_obj.post_call(original_response=response.text)
+
+            _json_response = response.json()
+
+            _output_text = _json_response["outputs"][0]["data"][0]
+            # decode the escaped byte string returned by the endpoint
+            _output_text = (
+                _output_text.encode("latin-1")
+                .decode("unicode_escape")
+                .encode("latin-1")
+                .decode("utf-8")
+            )
+
+            model_response.model = _json_response.get("model_name", "None")
+            model_response.choices[0].message.content = _output_text
+
+        return model_response
+
     async def aembedding(
         self,
         data: dict,
@@ -55,7 +98,7 @@ class TritonChatCompletion(BaseLLM):
         _json_response = response.json()

         _outputs = _json_response["outputs"]
-        _output_data = _outputs[0]["data"]
+        _output_data = [output["data"] for output in _outputs]
         _embedding_output = {
             "object": "embedding",
             "index": 0,
@@ -84,7 +127,7 @@ class TritonChatCompletion(BaseLLM):
             "inputs": [
                 {
                     "name": "input_text",
-                    "shape": [1],
+                    "shape": [len(input)],  # size of the input data
                     "datatype": "BYTES",
                     "data": input,
                 }
@@ -117,3 +160,101 @@ class TritonChatCompletion(BaseLLM):
             raise Exception(
                 "Only async embedding supported for triton, please use litellm.aembedding() for now"
             )
+
+    ## Using sync completion for now - async completion is not supported yet.
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        timeout: float,
+        api_base: str,
+        model_response: ModelResponse,
+        api_key: Optional[str] = None,
+        logging_obj=None,
+        optional_params=None,
+        client=None,
+        stream=False,
+    ):
+        # Build the request payload based on the endpoint type.
+        data_for_triton = {}
+        type_of_model = ""
+        if api_base.endswith("generate"):  ### This is a trtllm model
+            text_input = messages[0]["content"]
+            data_for_triton = {
+                "text_input": text_input,
+                "parameters": {
+                    "max_tokens": optional_params.get("max_tokens", 20),
+                    "bad_words": [""],
+                    "stop_words": [""],
+                },
+            }
+            for k, v in optional_params.items():
+                data_for_triton["parameters"][k] = v
+            type_of_model = "trtllm"
+
+        elif api_base.endswith("infer"):  ### This is an infer endpoint with a custom model on triton
+            text_input = messages[0]["content"]
+            data_for_triton = {
+                "inputs": [
+                    {
+                        "name": "text_input",
+                        "shape": [1],
+                        "datatype": "BYTES",
+                        "data": [text_input],
+                    }
+                ]
+            }
+
+            for k, v in optional_params.items():
+                if not (k == "stream" or k == "max_retries"):  ## skip these as they are added by litellm
+                    datatype = "INT32" if isinstance(v, int) else "BYTES"
+                    datatype = "FP32" if isinstance(v, float) else datatype
+                    data_for_triton["inputs"].append(
+                        {"name": k, "shape": [1], "datatype": datatype, "data": [v]}
+                    )
+
+            # max_tokens is required - default it if not provided
+            if "max_tokens" not in optional_params:
+                data_for_triton["inputs"].append(
+                    {"name": "max_tokens", "shape": [1], "datatype": "INT32", "data": [20]}
+                )
+
+            type_of_model = "infer"
+        else:  ## Unknown endpoint type - pass the prompt through as-is
+            data_for_triton = {"text_input": messages[0]["content"]}
+
+        if logging_obj:
+            logging_obj.pre_call(
+                input=messages,
+                api_key=api_key,
+                additional_args={
+                    "complete_input_dict": optional_params,
+                    "api_base": api_base,
+                    "http_client": client,
+                },
+            )
+
+        handler = requests.Session()
+        # requests expects a (connect, read) timeout tuple on the request itself
+        response = handler.post(url=api_base, json=data_for_triton, timeout=(5.0, 600.0))
+
+        if logging_obj:
+            logging_obj.post_call(original_response=response.text)
+
+        if response.status_code != 200:
+            raise TritonError(status_code=response.status_code, message=response.text)
+
+        _json_response = response.json()
+
+        model_response.model = _json_response.get("model_name", "None")
+        if type_of_model == "trtllm":
+            # The generate endpoint returns the completion under the text_output key
+            model_response["choices"] = [
+                Choices(index=0, message=Message(content=_json_response["text_output"]))
+            ]
+        elif type_of_model == "infer":
+            # The infer endpoint returns the completion under the outputs key
+            model_response["choices"] = [
+                Choices(index=0, message=Message(content=_json_response["outputs"][0]["data"][0]))
+            ]
+        else:
+            ## just pass the response through
+            model_response["choices"] = [
+                Choices(index=0, message=Message(content=_json_response["outputs"]))
+            ]
+
+        """
+        response = self.acompletion(
+            data=data_for_triton,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_base=api_base,
+            api_key=api_key,
+        )
+        """
+        return model_response
\ No newline at end of file
diff --git a/litellm/main.py b/litellm/main.py
index 5da2b4a52..51da76028 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2254,6 +2254,26 @@ def completion(
             return generator

         response = generator
+
+    elif custom_llm_provider == "triton":
+        api_base = litellm.api_base or api_base
+        model_response = triton_chat_completions.completion(
+            api_base=api_base,
+            timeout=timeout,
+            model=model,
+            messages=messages,
+            model_response=model_response,
+            optional_params=optional_params,
+            logging_obj=logging,
+        )
+
+        ## RESPONSE OBJECT
+        response = model_response
+        return response
+
     elif custom_llm_provider == "cloudflare":
         api_key = (
             api_key
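
Usage note (illustrative sketch, not part of the patch): assuming a Triton server at
http://localhost:8000 serving a TensorRT-LLM model named "llama" through its generate
endpoint, and assuming the "triton/" provider prefix routes to this new code path, a
call could look like the following; the model name and URL are placeholder assumptions.

    import litellm

    response = litellm.completion(
        model="triton/llama",  # hypothetical model name
        messages=[{"role": "user", "content": "What is machine learning?"}],
        api_base="http://localhost:8000/v2/models/llama/generate",  # hypothetical server URL
        max_tokens=100,
    )
    print(response.choices[0].message.content)

The same call with an api_base ending in "infer" would exercise the custom infer branch,
where optional parameters are packed into Triton "inputs" tensors instead of "parameters".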