diff --git a/litellm/llms/triton.py b/litellm/llms/triton.py
index bc9a15a94..770898949 100644
--- a/litellm/llms/triton.py
+++ b/litellm/llms/triton.py
@@ -1,23 +1,28 @@
-import copy
-import json
 import os
-import time
-import types
+import json
 from enum import Enum
-from typing import Callable, List, Optional
-
-import httpx  # type: ignore
 import requests  # type: ignore
-
+import time
+from typing import Callable, Optional, List, Sequence, Any, Union, Dict
+from litellm.utils import (
+    ModelResponse,
+    Choices,
+    Delta,
+    Usage,
+    map_finish_reason,
+    CustomStreamWrapper,
+    Message,
+    EmbeddingResponse,
+)
 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
-
+from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from .base import BaseLLM
-from .prompt_templates.factory import custom_prompt, prompt_factory
+import httpx  # type: ignore
 
 
 class TritonError(Exception):
-    def __init__(self, status_code, message):
+    def __init__(self, status_code: int, message: str) -> None:
         self.status_code = status_code
         self.message = message
         self.request = httpx.Request(
@@ -41,8 +46,7 @@ class TritonChatCompletion(BaseLLM):
         api_base: str,
         logging_obj=None,
         api_key: Optional[str] = None,
-    ):
-
+    ) -> EmbeddingResponse:
         async_handler = AsyncHTTPHandler(
             timeout=httpx.Timeout(timeout=600.0, connect=5.0)
         )
@@ -79,10 +83,10 @@ class TritonChatCompletion(BaseLLM):
 
         return model_response
 
-    def embedding(
+    async def embedding(
         self,
         model: str,
-        input: list,
+        input: List[str],
         timeout: float,
         api_base: str,
         model_response: litellm.utils.EmbeddingResponse,
@@ -90,8 +94,8 @@ class TritonChatCompletion(BaseLLM):
         logging_obj=None,
         optional_params=None,
         client=None,
-        aembedding=None,
-    ):
+        aembedding: bool = False,
+    ) -> EmbeddingResponse:
         data_for_triton = {
             "inputs": [
                 {
@@ -103,8 +107,6 @@ class TritonChatCompletion(BaseLLM):
             ]
         }
 
-        ## LOGGING
-
         curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"
 
         logging_obj.pre_call(
@@ -116,8 +118,8 @@ class TritonChatCompletion(BaseLLM):
             },
         )
 
-        if aembedding == True:
-            response = self.aembedding(
+        if aembedding:
+            response = await self.aembedding(
                 data=data_for_triton,
                 model_response=model_response,
                 logging_obj=logging_obj,
@@ -130,6 +132,198 @@ class TritonChatCompletion(BaseLLM):
                 "Only async embedding supported for triton, please use litellm.aembedding() for now"
             )
 
+    def completion(
+        self,
+        model: str,
+        messages: List[dict],
+        timeout: float,
+        api_base: str,
+        model_response: ModelResponse,
+        api_key: Optional[str] = None,
+        logging_obj=None,
+        optional_params=None,
+        client=None,
+        stream: bool = False,
+        acompletion: bool = False,
+    ) -> ModelResponse:
+        type_of_model = ""
+        optional_params.pop("stream", False)
+        if api_base.endswith("generate"):  ### This is a trtllm model
+            text_input = messages[0]["content"]
+            data_for_triton: Dict[str, Any] = {
+                "text_input": prompt_factory(model=model, messages=messages),
+                "parameters": {
+                    "max_tokens": int(optional_params.get("max_tokens", 2000)),
+                    "bad_words": [""],
+                    "stop_words": [""],
+                },
+                "stream": bool(stream),
+            }
+            data_for_triton["parameters"].update(optional_params)
+            type_of_model = "trtllm"
+
+        elif api_base.endswith(
+            "infer"
+        ):  ### This is an infer model with a custom model on triton
+            text_input = messages[0]["content"]
+            data_for_triton = {
+                "inputs": [
+                    {
+                        "name": "text_input",
+                        "shape": [1],
+                        "datatype": "BYTES",
+                        "data": [text_input],
+                    }
+                ]
+            }
+
+            for k, v in optional_params.items():
+                if not (k == "stream" or k == "max_retries"):
+                    datatype = "INT32" if isinstance(v, int) else "BYTES"
+                    datatype = "FP32" if isinstance(v, float) else datatype
+                    data_for_triton["inputs"].append(
+                        {"name": k, "shape": [1], "datatype": datatype, "data": [v]}
+                    )
+
+            if "max_tokens" not in optional_params:
+                data_for_triton["inputs"].append(
+                    {
+                        "name": "max_tokens",
+                        "shape": [1],
+                        "datatype": "INT32",
+                        "data": [20],
+                    }
+                )
+
+            type_of_model = "infer"
+        else:  ## Unknown model type passthrough
+            data_for_triton = {
+                "inputs": [
+                    {
+                        "name": "text_input",
+                        "shape": [1],
+                        "datatype": "BYTES",
+                        "data": [messages[0]["content"]],
+                    }
+                ]
+            }
+
+        if logging_obj:
+            logging_obj.pre_call(
+                input=messages,
+                api_key=api_key,
+                additional_args={
+                    "complete_input_dict": optional_params,
+                    "api_base": api_base,
+                    "http_client": client,
+                },
+            )
+
+        headers = {"Content-Type": "application/json"}
+        data_for_triton = json.dumps(data_for_triton)
+
+        if acompletion:
+            return self.acompletion(
+                model,
+                data_for_triton,
+                headers=headers,
+                logging_obj=logging_obj,
+                api_base=api_base,
+                stream=stream,
+                model_response=model_response,
+                type_of_model=type_of_model,
+            )
+        else:
+            handler = HTTPHandler()
+            if stream:
+                return self._handle_stream(
+                    handler, api_base, data_for_triton, model, logging_obj
+                )
+            else:
+                response = handler.post(url=api_base, data=data_for_triton, headers=headers)
+                return self._handle_response(
+                    response, model_response, logging_obj, type_of_model=type_of_model
+                )
+
+    async def acompletion(
+        self,
+        model: str,
+        data_for_triton,
+        api_base,
+        stream,
+        logging_obj,
+        headers,
+        model_response,
+        type_of_model,
+    ) -> ModelResponse:
+        handler = AsyncHTTPHandler()
+        if stream:
+            return self._ahandle_stream(
+                handler, api_base, data_for_triton, model, logging_obj
+            )
+        else:
+            response = await handler.post(
+                url=api_base, data=data_for_triton, headers=headers
+            )
+
+            return self._handle_response(
+                response, model_response, logging_obj, type_of_model=type_of_model
+            )
+
+    def _handle_stream(self, handler, api_base, data_for_triton, model, logging_obj):
+        response = handler.post(
+            url=api_base + "_stream", data=data_for_triton, stream=True
+        )
+        streamwrapper = litellm.CustomStreamWrapper(
+            response.iter_lines(),
+            model=model,
+            custom_llm_provider="triton",
+            logging_obj=logging_obj,
+        )
+        for chunk in streamwrapper:
+            yield (chunk)
+
+    async def _ahandle_stream(
+        self, handler, api_base, data_for_triton, model, logging_obj
+    ):
+        response = await handler.post(
+            url=api_base + "_stream", data=data_for_triton, stream=True
+        )
+        streamwrapper = litellm.CustomStreamWrapper(
+            response.aiter_lines(),
+            model=model,
+            custom_llm_provider="triton",
+            logging_obj=logging_obj,
+        )
+        async for chunk in streamwrapper:
+            yield (chunk)
+
+    def _handle_response(self, response, model_response, logging_obj, type_of_model):
+        if logging_obj:
+            logging_obj.post_call(original_response=response)
+
+        if response.status_code != 200:
+            raise TritonError(status_code=response.status_code, message=response.text)
+
+        _json_response = response.json()
+        model_response.model = _json_response.get("model_name", "None")
+        if type_of_model == "trtllm":
+            model_response.choices = [
+                Choices(index=0, message=Message(content=_json_response["text_output"]))
+            ]
+        elif type_of_model == "infer":
+            model_response.choices = [
+                Choices(
+                    index=0,
+                    message=Message(content=_json_response["outputs"][0]["data"]),
+                )
+            ]
+        else:
+            model_response.choices = [
+                Choices(index=0, message=Message(content=_json_response["outputs"]))
+            ]
+        return model_response
+
     @staticmethod
     def split_embedding_by_shape(
         data: List[float], shape: List[int]
diff --git a/litellm/main.py b/litellm/main.py
index 4e2df72cd..fad2e15cc 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -375,6 +375,7 @@ async def acompletion(
             or custom_llm_provider == "predibase"
             or custom_llm_provider == "bedrock"
             or custom_llm_provider == "databricks"
+            or custom_llm_provider == "triton"
             or custom_llm_provider == "clarifai"
             or custom_llm_provider == "watsonx"
             or custom_llm_provider in litellm.openai_compatible_providers
@@ -2477,6 +2478,28 @@ def completion(
                return generator
 
            response = generator
+
+        elif custom_llm_provider == "triton":
+            api_base = (
+                litellm.api_base or api_base
+            )
+            model_response = triton_chat_completions.completion(
+                api_base=api_base,
+                timeout=timeout,  # type: ignore
+                model=model,
+                messages=messages,
+                model_response=model_response,
+                optional_params=optional_params,
+                logging_obj=logging,
+                stream=stream,
+                acompletion=acompletion
+            )
+
+            ## RESPONSE OBJECT
+            response = model_response
+            return response
+
+
        elif custom_llm_provider == "cloudflare":
            api_key = (
                api_key
diff --git a/litellm/utils.py b/litellm/utils.py
index 9d798f119..97eb874d6 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -9097,6 +9097,42 @@ class CustomStreamWrapper:
         except Exception as e:
             raise e
 
+    def handle_triton_stream(self, chunk):
+        try:
+            if isinstance(chunk, dict):
+                parsed_response = chunk
+            elif isinstance(chunk, (str, bytes)):
+                if isinstance(chunk, bytes):
+                    chunk = chunk.decode("utf-8")
+                if "text_output" in chunk:
+                    response = chunk.replace("data: ", "").strip()
+                    parsed_response = json.loads(response)
+                else:
+                    return {
+                        "text": "",
+                        "is_finished": False,
+                        "prompt_tokens": 0,
+                        "completion_tokens": 0,
+                    }
+            else:
+                print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
+                raise ValueError(
+                    f"Unable to parse response. Original response: {chunk}"
+                )
+            text = parsed_response.get("text_output", "")
+            finish_reason = parsed_response.get("stop_reason")
+            is_finished = parsed_response.get("is_finished", False)
+            return {
+                "text": text,
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+                "prompt_tokens": parsed_response.get("input_token_count", 0),
+                "completion_tokens": parsed_response.get("generated_token_count", 0),
+            }
+            return {"text": "", "is_finished": False}
+        except Exception as e:
+            raise e
+
     def handle_clarifai_completion_chunk(self, chunk):
         try:
             if isinstance(chunk, dict):
@@ -9516,6 +9552,12 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
+            elif self.custom_llm_provider == "triton":
+                response_obj = self.handle_triton_stream(chunk)
+                completion_obj["content"] = response_obj["text"]
+                print_verbose(f"completion obj content: {completion_obj['content']}")
+                if response_obj["is_finished"]:
+                    self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "text-completion-openai":
                 response_obj = self.handle_openai_text_completion_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -10071,6 +10113,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "predibase"
                 or self.custom_llm_provider == "databricks"
                 or self.custom_llm_provider == "bedrock"
+                or self.custom_llm_provider == "triton"
                 or self.custom_llm_provider == "watsonx"
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
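
Usage sketch (not part of the diff): a minimal example of the code path added above, assuming a reachable Triton server; the model name and endpoint URL are placeholders, and the provider is selected explicitly with custom_llm_provider="triton" as routed in litellm/main.py. The api_base suffix ("generate" vs "infer") chooses between the TRT-LLM and generic infer request formats built in TritonChatCompletion.completion().

# Illustrative only -- model name and URL below are placeholders.
import litellm

# Synchronous, non-streaming call; api_base ending in "generate" selects the
# TRT-LLM request shape in TritonChatCompletion.completion().
response = litellm.completion(
    model="my-triton-model",
    custom_llm_provider="triton",
    api_base="http://localhost:8000/v2/models/my-triton-model/generate",
    messages=[{"role": "user", "content": "Hello from Triton"}],
    max_tokens=64,
)
print(response.choices[0].message.content)

# Streaming call; chunks are parsed by CustomStreamWrapper.handle_triton_stream().
for chunk in litellm.completion(
    model="my-triton-model",
    custom_llm_provider="triton",
    api_base="http://localhost:8000/v2/models/my-triton-model/generate",
    messages=[{"role": "user", "content": "Stream a short answer."}],
    stream=True,
):
    print(chunk.choices[0].delta.content or "", end="")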