From d5c65c6be22e1a6e3d2ebbae4f028e57aef58320 Mon Sep 17 00:00:00 2001
From: Sophia Loris
Date: Fri, 19 Jul 2024 09:35:27 -0500
Subject: [PATCH] Add support for Triton streaming & triton async completions

---
 litellm/llms/triton.py | 186 +++++++++++++++++++++++++++++++++--------
 litellm/main.py        |   3 +
 litellm/utils.py       |  43 ++++++++++
 3 files changed, 199 insertions(+), 33 deletions(-)

diff --git a/litellm/llms/triton.py b/litellm/llms/triton.py
index c681fd072..95cf38f1f 100644
--- a/litellm/llms/triton.py
+++ b/litellm/llms/triton.py
@@ -1,16 +1,24 @@
 import os
 import json
 from enum import Enum
-import requests # type: ignore
+import requests  # type: ignore
 import time
 from typing import Callable, Optional, List, Sequence, Any, Union, Dict
-from litellm.utils import ModelResponse, Choices, Usage, map_finish_reason, CustomStreamWrapper, Message, EmbeddingResponse
+from litellm.utils import (
+    ModelResponse,
+    Choices,
+    Delta,
+    Usage,
+    map_finish_reason,
+    CustomStreamWrapper,
+    Message,
+    EmbeddingResponse,
+)
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from .base import BaseLLM
-import httpx # type: ignore
-
+import httpx  # type: ignore
 
 
 class TritonError(Exception):
@@ -26,6 +34,7 @@ class TritonError(Exception):
             self.message
         )  # Call the base class constructor with the parameters it needs
 
+
 class TritonChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
@@ -127,71 +136,182 @@ class TritonChatCompletion(BaseLLM):
         optional_params=None,
         client=None,
         stream: bool = False,
+        acompletion: bool = False,
     ) -> ModelResponse:
-        type_of_model = ""
+        optional_params.pop("stream", False)
         if api_base.endswith("generate"): ### This is a trtllm model
             text_input = messages[0]["content"]
-            data_for_triton: Dict[str, Any] = {
-                "text_input": str(text_input),
+            data_for_triton: Dict[str, Any] = {
+                "text_input": prompt_factory(model=model, messages=messages),
                 "parameters": {
-                    "max_tokens": int(optional_params.get("max_tokens", 20)),
+                    "max_tokens": int(optional_params.get("max_tokens", 2000)),
                     "bad_words": [""],
-                    "stop_words": [""]
-                }
+                    "stop_words": [""],
+                },
+                "stream": bool(stream),
             }
-            data_for_triton["parameters"].update( optional_params)
+            data_for_triton["parameters"].update(optional_params)
             type_of_model = "trtllm"
 
-        elif api_base.endswith("infer"): ### This is an infer model with a custom model on triton
+        elif api_base.endswith(
+            "infer"
+        ):  ### This is an infer model with a custom model on triton
             text_input = messages[0]["content"]
             data_for_triton = {
-                "inputs": [{"name": "text_input", "shape": [1], "datatype": "BYTES", "data": [text_input]}]
+                "inputs": [
+                    {
+                        "name": "text_input",
+                        "shape": [1],
+                        "datatype": "BYTES",
+                        "data": [text_input],
+                    }
+                ]
             }
 
             for k, v in optional_params.items():
                 if not (k == "stream" or k == "max_retries"):
                     datatype = "INT32" if isinstance(v, int) else "BYTES"
                     datatype = "FP32" if isinstance(v, float) else datatype
-                    data_for_triton['inputs'].append({"name": k, "shape": [1], "datatype": datatype, "data": [v]})
+                    data_for_triton["inputs"].append(
+                        {"name": k, "shape": [1], "datatype": datatype, "data": [v]}
+                    )
 
             if "max_tokens" not in optional_params:
-                data_for_triton['inputs'].append({"name": "max_tokens", "shape": [1], "datatype": "INT32", "data": [20]})
+                data_for_triton["inputs"].append(
+                    {
+                        "name": "max_tokens",
+                        "shape": [1],
+                        "datatype": "INT32",
+                        "data": [20],
+                    }
+                )
 
             type_of_model = "infer"
 
         else: ## Unknown model type passthrough
             data_for_triton = {
-                "inputs": [{"name": "text_input", "shape": [1], "datatype": "BYTES", "data": [messages[0]["content"]]}]
+                "inputs": [
+                    {
+                        "name": "text_input",
+                        "shape": [1],
+                        "datatype": "BYTES",
+                        "data": [messages[0]["content"]],
+                    }
+                ]
             }
 
         if logging_obj:
             logging_obj.pre_call(
-            input=messages,
-            api_key=api_key,
-            additional_args={
-                "complete_input_dict": optional_params,
-                "api_base": api_base,
-                "http_client": client,
-            },
-            )
-        handler = requests.Session()
-        handler.timeout = (600.0, 5.0)
+                input=messages,
+                api_key=api_key,
+                additional_args={
+                    "complete_input_dict": optional_params,
+                    "api_base": api_base,
+                    "http_client": client,
+                },
+            )
 
-        response = handler.post(url=api_base, json=data_for_triton)
+        headers = {"Content-Type": "application/json"}
+        data_for_triton = json.dumps(data_for_triton)
+        if acompletion:
+            return self.acompletion(
+                model,
+                data_for_triton,
+                headers=headers,
+                logging_obj=logging_obj,
+                api_base=api_base,
+                stream=stream,
+                model_response=model_response,
+                type_of_model=type_of_model,
+            )
+        else:
+            handler = HTTPHandler()
+            if stream:
+                return self._handle_stream(
+                    handler, api_base, data_for_triton, model, logging_obj
+                )
+            else:
+                response = handler.post(url=api_base, data=data_for_triton, headers=headers)
+                return self._handle_response(
+                    response, model_response, logging_obj, type_of_model=type_of_model
+                )
+
+    async def acompletion(
+        self,
+        model: str,
+        data_for_triton,
+        api_base,
+        stream,
+        logging_obj,
+        headers,
+        model_response,
+        type_of_model,
+    ) -> ModelResponse:
+        handler = AsyncHTTPHandler()
+        if stream:
+            return self._ahandle_stream(
+                handler, api_base, data_for_triton, model, logging_obj
+            )
+        else:
+            response = await handler.post(
+                url=api_base, data=data_for_triton, headers=headers
+            )
+
+            return self._handle_response(
+                response, model_response, logging_obj, type_of_model=type_of_model
+            )
+
+    def _handle_stream(self, handler, api_base, data_for_triton, model, logging_obj):
+        response = handler.post(
+            url=api_base + "_stream", data=data_for_triton, stream=True
+        )
+        streamwrapper = litellm.CustomStreamWrapper(
+            response.iter_lines(),
+            model=model,
+            custom_llm_provider="triton",
+            logging_obj=logging_obj,
+        )
+        for chunk in streamwrapper:
+            yield (chunk)
+
+    async def _ahandle_stream(
+        self, handler, api_base, data_for_triton, model, logging_obj
+    ):
+        response = await handler.post(
+            url=api_base + "_stream", data=data_for_triton, stream=True
+        )
+        streamwrapper = litellm.CustomStreamWrapper(
+            response.aiter_lines(),
+            model=model,
+            custom_llm_provider="triton",
+            logging_obj=logging_obj,
+        )
+        async for chunk in streamwrapper:
+            yield (chunk)
+
+    def _handle_response(self, response, model_response, logging_obj, type_of_model):
         if logging_obj:
             logging_obj.post_call(original_response=response)
 
         if response.status_code != 200:
             raise TritonError(status_code=response.status_code, message=response.text)
 
-        _json_response = response.json()
+        _json_response = response.json()
         model_response.model = _json_response.get("model_name", "None")
         if type_of_model == "trtllm":
-            model_response.choices = [Choices(index=0, message=Message(content=_json_response['text_output']))]
+            model_response.choices = [
+                Choices(index=0, message=Message(content=_json_response["text_output"]))
+            ]
         elif type_of_model == "infer":
-            model_response.choices = [Choices(index=0, message=Message(content=_json_response['outputs'][0]['data']))]
+            model_response.choices = [
+                Choices(
+                    index=0,
+                    message=Message(content=_json_response["outputs"][0]["data"]),
+                )
+            ]
         else:
-            model_response.choices = [Choices(index=0, message=Message(content=_json_response['outputs']))]
-
+            model_response.choices = [
+                Choices(index=0, message=Message(content=_json_response["outputs"]))
+            ]
         return model_response
diff --git a/litellm/main.py b/litellm/main.py
index d30f2e95d..06d1abf82 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -333,6 +333,7 @@ async def acompletion(
             or custom_llm_provider == "predibase"
             or custom_llm_provider == "bedrock"
             or custom_llm_provider == "databricks"
+            or custom_llm_provider == "triton"
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
@@ -2267,6 +2268,8 @@ def completion(
                 model_response=model_response,
                 optional_params=optional_params,
                 logging_obj=logging,
+                stream=stream,
+                acompletion=acompletion
             )
 
             ## RESPONSE OBJECT
diff --git a/litellm/utils.py b/litellm/utils.py
index ea0f46c14..64964364c 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -11013,6 +11013,42 @@ class CustomStreamWrapper:
         except Exception as e:
             raise e
 
+    def handle_triton_stream(self, chunk):
+        try:
+            if isinstance(chunk, dict):
+                parsed_response = chunk
+            elif isinstance(chunk, (str, bytes)):
+                if isinstance(chunk, bytes):
+                    chunk = chunk.decode("utf-8")
+                if "text_output" in chunk:
+                    response = chunk.replace("data: ", "").strip()
+                    parsed_response = json.loads(response)
+                else:
+                    return {
+                        "text": "",
+                        "is_finished": False,
+                        "prompt_tokens": 0,
+                        "completion_tokens": 0,
+                    }
+            else:
+                print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
+                raise ValueError(
+                    f"Unable to parse response. Original response: {chunk}"
+                )
+            text = parsed_response.get("text_output", "")
+            finish_reason = parsed_response.get("stop_reason")
+            is_finished = parsed_response.get("is_finished", False)
+            return {
+                "text": text,
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+                "prompt_tokens": parsed_response.get("input_token_count", 0),
+                "completion_tokens": parsed_response.get("generated_token_count", 0),
+            }
+            return {"text": "", "is_finished": False}
+        except Exception as e:
+            raise e
+
     def handle_clarifai_completion_chunk(self, chunk):
         try:
             if isinstance(chunk, dict):
@@ -11337,6 +11373,12 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
+            elif self.custom_llm_provider == "triton":
+                response_obj = self.handle_triton_stream(chunk)
+                completion_obj["content"] = response_obj["text"]
+                print_verbose(f"completion obj content: {completion_obj['content']}")
+                if response_obj["is_finished"]:
+                    self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "text-completion-openai":
                 response_obj = self.handle_openai_text_completion_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -11773,6 +11815,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "predibase"
                 or self.custom_llm_provider == "databricks"
                 or self.custom_llm_provider == "bedrock"
+                or self.custom_llm_provider == "triton"
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
                 async for chunk in self.completion_stream:
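
A usage sketch of the two code paths this patch adds, expressed against litellm's
public API: synchronous streaming (which flows through _handle_stream and
CustomStreamWrapper.handle_triton_stream) and async completion (which flows
through TritonChatCompletion.acompletion). The "triton/" model prefix, the model
name, and the api_base URL below are placeholder assumptions about a local Triton
deployment; the streaming call also assumes the server exposes the matching
"_stream" endpoint that _handle_stream appends to api_base.

# Sketch only: placeholder model name and endpoint URL for a local Triton server.
import asyncio

import litellm


def sync_streaming_example() -> None:
    # stream=True returns a CustomStreamWrapper; chunks are parsed by
    # handle_triton_stream and exposed in the usual delta format.
    response = litellm.completion(
        model="triton/llama-3-8b-instruct",  # placeholder model name
        messages=[{"role": "user", "content": "What is Triton Inference Server?"}],
        api_base="http://localhost:8000/v2/models/llama/generate",  # placeholder URL
        stream=True,
    )
    for chunk in response:
        print(chunk.choices[0].delta.content or "", end="", flush=True)


async def async_example() -> None:
    # acompletion now reaches TritonChatCompletion.acompletion, since "triton"
    # was added to the async provider list in main.py.
    response = await litellm.acompletion(
        model="triton/llama-3-8b-instruct",
        messages=[{"role": "user", "content": "What is Triton Inference Server?"}],
        api_base="http://localhost:8000/v2/models/llama/generate",
    )
    print(response.choices[0].message.content)


if __name__ == "__main__":
    sync_streaming_example()
    asyncio.run(async_example())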