From 99e506525c343595fcab803db0c7ee9e56bbde08 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Wed, 29 May 2024 13:42:49 -0700
Subject: [PATCH] =?UTF-8?q?Revert=20"Added=20support=20for=20Triton=20chat?=
 =?UTF-8?q?=20completion=20using=20trtlllm=20generate=20endpo=E2=80=A6"?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
 litellm/llms/triton.py | 149 ++---------------------------------------
 litellm/main.py        |  20 ------
 2 files changed, 4 insertions(+), 165 deletions(-)

diff --git a/litellm/llms/triton.py b/litellm/llms/triton.py
index 43220eec1..711186b3f 100644
--- a/litellm/llms/triton.py
+++ b/litellm/llms/triton.py
@@ -4,13 +4,13 @@ from enum import Enum
 import requests, copy  # type: ignore
 import time
 from typing import Callable, Optional, List
-from litellm.utils import ModelResponse, Choices,Usage, map_finish_reason, CustomStreamWrapper, Message
+from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from .base import BaseLLM
 import httpx  # type: ignore
-import requests
+

 class TritonError(Exception):
     def __init__(self, status_code, message):
@@ -30,49 +30,6 @@ class TritonChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()

-    async def acompletion(
-        self,
-        data: dict,
-        model_response: ModelResponse,
-        api_base: str,
-        logging_obj=None,
-        api_key: Optional[str] = None,
-    ):
-
-        async_handler = httpx.AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
-
-        if api_base.endswith("generate") : ### This is a trtllm model
-
-            async with httpx.AsyncClient() as client:
-                response = await client.post(url=api_base, json=data)
-
-
-
-            if response.status_code != 200:
-                raise TritonError(status_code=response.status_code, message=response.text)
-
-            _text_response = response.text
-
-
-            if logging_obj:
-                logging_obj.post_call(original_response=_text_response)
-
-            _json_response = response.json()
-
-            _output_text = _json_response["outputs"][0]["data"][0]
-            # decode the byte string
-            _output_text = _output_text.encode("latin-1").decode("unicode_escape").encode(
-                "latin-1"
-            ).decode("utf-8")
-
-            model_response.model = _json_response.get("model_name", "None")
-            model_response.choices[0].message.content = _output_text
-
-        return model_response
-
-
     async def aembedding(
         self,
         data: dict,
@@ -98,7 +55,7 @@ class TritonChatCompletion(BaseLLM):
         _json_response = response.json()

         _outputs = _json_response["outputs"]
-        _output_data = [ output["data"] for output in _outputs ]
+        _output_data = _outputs[0]["data"]
         _embedding_output = {
             "object": "embedding",
             "index": 0,
@@ -127,7 +84,7 @@ class TritonChatCompletion(BaseLLM):
             "inputs": [
                 {
                     "name": "input_text",
-                    "shape": [len(input)], #size of the input data
+                    "shape": [1],
                     "datatype": "BYTES",
                     "data": input,
                 }
@@ -160,101 +117,3 @@ class TritonChatCompletion(BaseLLM):
             raise Exception(
                 "Only async embedding supported for triton, please use litellm.aembedding() for now"
             )
-        ## Using Sync completion for now - Async completion not supported yet.
-    def completion(
-        self,
-        model: str,
-        messages: list,
-        timeout: float,
-        api_base: str,
-        model_response: ModelResponse,
-        api_key: Optional[str] = None,
-        logging_obj=None,
-        optional_params=None,
-        client=None,
-        stream=False,
-    ):
-        # check if model is llama
-        data_for_triton = {}
-        type_of_model = "" ""
-        if api_base.endswith("generate") : ### This is a trtllm model
-            # this is a llama model
-            text_input = messages[0]["content"]
-            data_for_triton = {
-                "text_input":f"{text_input}",
-                "parameters": {
-                    "max_tokens": optional_params.get("max_tokens", 20),
-                    "bad_words":[""],
-                    "stop_words":[""]
-                }}
-            for k,v in optional_params.items():
-                data_for_triton["parameters"][k] = v
-            type_of_model = "trtllm"
-
-        elif api_base.endswith("infer"): ### This is a infer model with a custom model on triton
-            # this is a custom model
-            text_input = messages[0]["content"]
-            data_for_triton = {
-                "inputs": [{"name": "text_input","shape": [1],"datatype": "BYTES","data": [text_input] }]
-            }
-
-            for k,v in optional_params.items():
-                if not (k=="stream" or k=="max_retries"): ## skip these as they are added by litellm
-                    datatype = "INT32" if type(v) == int else "BYTES"
-                    datatype = "FP32" if type(v) == float else datatype
-                    data_for_triton['inputs'].append({"name": k,"shape": [1],"datatype": datatype,"data": [v]})
-
-            # check for max_tokens which is required
-            if "max_tokens" not in optional_params:
-                data_for_triton['inputs'].append({"name": "max_tokens","shape": [1],"datatype": "INT32","data": [20]})
-
-            type_of_model = "infer"
-        else: ## Unknown model type passthrough
-            data_for_triton = {
-                messages[0]["content"]
-            }
-
-        if logging_obj:
-            logging_obj.pre_call(
-                input=messages,
-                api_key=api_key,
-                additional_args={
-                    "complete_input_dict": optional_params,
-                    "api_base": api_base,
-                    "http_client": client,
-                },
-            )
-        handler = requests.Session()
-        handler.timeout = (600.0, 5.0)
-
-        response = handler.post(url=api_base, json=data_for_triton)
-
-
-        if logging_obj:
-            logging_obj.post_call(original_response=response)
-
-        if response.status_code != 200:
-            raise TritonError(status_code=response.status_code, message=response.text)
-        _json_response=response.json()
-
-        model_response.model = _json_response.get("model_name", "None")
-        if type_of_model == "trtllm":
-            # The actual response is part of the text_output key in the response
-            model_response['choices'] = [ Choices(index=0, message= Message(content=_json_response['text_output']))]
-        elif type_of_model == "infer":
-            # The actual response is part of the outputs key in the response
-            model_response['choices'] = [ Choices(index=0, message= Message(content=_json_response['outputs'][0]['data']))]
-        else:
-            ## just passthrough the response
-            model_response['choices'] = [ Choices(index=0, message= Message(content=_json_response['outputs']))]
-
-        """
-        response = self.acompletion(
-            data=data_for_triton,
-            model_response=model_response,
-            logging_obj=logging_obj,
-            api_base=api_base,
-            api_key=api_key,
-        )
-        """
-        return model_response
\ No newline at end of file
diff --git a/litellm/main.py b/litellm/main.py
index 52b4193ea..a7fbbfa69 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2254,26 +2254,6 @@ def completion(
                 return generator

         response = generator
-
-        elif custom_llm_provider == "triton":
-            api_base = (
-                litellm.api_base or api_base
-            )
-            model_response = triton_chat_completions.completion(
-                api_base=api_base,
-                timeout=timeout,
-                model=model,
-                messages=messages,
-                model_response=model_response,
-                optional_params=optional_params,
-                logging_obj=logging,
-            )
-
-            ## RESPONSE OBJECT
-            response = model_response
-            return response
-
-
         elif custom_llm_provider == "cloudflare":
             api_key = (
                 api_key