Merge pull request #3905 from giritatavarty-8451/litellm_triton_chatcompletion_support

Litellm triton chatcompletion support - Resubmit of #3895
Ishaan Jaff 2024-07-23 10:30:26 -07:00 committed by GitHub
commit 1355932bf4
3 changed files with 282 additions and 22 deletions
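
For orientation before the per-file diffs: a minimal usage sketch of what this PR enables, i.e. routing a chat completion call to a Triton Inference Server through litellm. The server URL, model name, and endpoint path below are placeholders; the "triton/" prefix assumes litellm's usual provider-prefix routing, and an api_base ending in "generate" selects the TensorRT-LLM code path added below.

import litellm

# Placeholder server and model; an api_base ending in "generate" hits the trtllm branch.
response = litellm.completion(
    model="triton/llama-2-7b",
    messages=[{"role": "user", "content": "What is Triton Inference Server?"}],
    api_base="http://localhost:8000/v2/models/llama-2-7b/generate",
    max_tokens=100,
)
print(response.choices[0].message.content)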

litellm/llms/triton.py

@@ -1,23 +1,28 @@
-import copy
-import json
 import os
-import time
-import types
+import json
 from enum import Enum
-from typing import Callable, List, Optional
-import httpx  # type: ignore
 import requests  # type: ignore
+import time
+from typing import Callable, Optional, List, Sequence, Any, Union, Dict
+from litellm.utils import (
+    ModelResponse,
+    Choices,
+    Delta,
+    Usage,
+    map_finish_reason,
+    CustomStreamWrapper,
+    Message,
+    EmbeddingResponse,
+)
 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from .base import BaseLLM
-from .prompt_templates.factory import custom_prompt, prompt_factory
+import httpx  # type: ignore


 class TritonError(Exception):
-    def __init__(self, status_code, message):
+    def __init__(self, status_code: int, message: str) -> None:
         self.status_code = status_code
         self.message = message
         self.request = httpx.Request(
@@ -41,8 +46,7 @@ class TritonChatCompletion(BaseLLM):
         api_base: str,
         logging_obj=None,
         api_key: Optional[str] = None,
-    ):
+    ) -> EmbeddingResponse:
         async_handler = AsyncHTTPHandler(
             timeout=httpx.Timeout(timeout=600.0, connect=5.0)
         )
@@ -79,10 +83,10 @@ class TritonChatCompletion(BaseLLM):
         return model_response

-    def embedding(
+    async def embedding(
         self,
         model: str,
-        input: list,
+        input: List[str],
         timeout: float,
         api_base: str,
         model_response: litellm.utils.EmbeddingResponse,
@@ -90,8 +94,8 @@ class TritonChatCompletion(BaseLLM):
         logging_obj=None,
         optional_params=None,
         client=None,
-        aembedding=None,
-    ):
+        aembedding: bool = False,
+    ) -> EmbeddingResponse:
         data_for_triton = {
             "inputs": [
                 {
@@ -103,8 +107,6 @@ class TritonChatCompletion(BaseLLM):
             ]
         }

-        ## LOGGING
         curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"

         logging_obj.pre_call(
@@ -116,8 +118,8 @@ class TritonChatCompletion(BaseLLM):
             },
         )

-        if aembedding == True:
-            response = self.aembedding(
+        if aembedding:
+            response = await self.aembedding(
                 data=data_for_triton,
                 model_response=model_response,
                 logging_obj=logging_obj,
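
An aside on the embedding path touched in the hunk above: it stays async-only (the sync branch raises and points callers at litellm.aembedding). A minimal sketch of how it might be invoked, with a placeholder model name and endpoint:

import asyncio
import litellm

async def main():
    # Placeholder model and endpoint; Triton embedding support in litellm is async-only.
    resp = await litellm.aembedding(
        model="triton/my-embedding-model",
        input=["hello from litellm"],
        api_base="http://localhost:8000/v2/models/my-embedding-model/infer",
    )
    print(len(resp.data[0]["embedding"]))

asyncio.run(main())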
@@ -130,6 +132,198 @@ class TritonChatCompletion(BaseLLM):
             "Only async embedding supported for triton, please use litellm.aembedding() for now"
         )

+    def completion(
+        self,
+        model: str,
+        messages: List[dict],
+        timeout: float,
+        api_base: str,
+        model_response: ModelResponse,
+        api_key: Optional[str] = None,
+        logging_obj=None,
+        optional_params=None,
+        client=None,
+        stream: bool = False,
+        acompletion: bool = False,
+    ) -> ModelResponse:
+        type_of_model = ""
+        optional_params.pop("stream", False)
+        if api_base.endswith("generate"):  ### This is a trtllm model
+            text_input = messages[0]["content"]
+            data_for_triton: Dict[str, Any] = {
+                "text_input": prompt_factory(model=model, messages=messages),
+                "parameters": {
+                    "max_tokens": int(optional_params.get("max_tokens", 2000)),
+                    "bad_words": [""],
+                    "stop_words": [""],
+                },
+                "stream": bool(stream),
+            }
+            data_for_triton["parameters"].update(optional_params)
+            type_of_model = "trtllm"
+        elif api_base.endswith(
+            "infer"
+        ):  ### This is an infer model with a custom model on triton
+            text_input = messages[0]["content"]
+            data_for_triton = {
+                "inputs": [
+                    {
+                        "name": "text_input",
+                        "shape": [1],
+                        "datatype": "BYTES",
+                        "data": [text_input],
+                    }
+                ]
+            }
+            for k, v in optional_params.items():
+                if not (k == "stream" or k == "max_retries"):
+                    datatype = "INT32" if isinstance(v, int) else "BYTES"
+                    datatype = "FP32" if isinstance(v, float) else datatype
+                    data_for_triton["inputs"].append(
+                        {"name": k, "shape": [1], "datatype": datatype, "data": [v]}
+                    )
+            if "max_tokens" not in optional_params:
+                data_for_triton["inputs"].append(
+                    {
+                        "name": "max_tokens",
+                        "shape": [1],
+                        "datatype": "INT32",
+                        "data": [20],
+                    }
+                )
+            type_of_model = "infer"
+        else:  ## Unknown model type passthrough
+            data_for_triton = {
+                "inputs": [
+                    {
+                        "name": "text_input",
+                        "shape": [1],
+                        "datatype": "BYTES",
+                        "data": [messages[0]["content"]],
+                    }
+                ]
+            }
+
+        if logging_obj:
+            logging_obj.pre_call(
+                input=messages,
+                api_key=api_key,
+                additional_args={
+                    "complete_input_dict": optional_params,
+                    "api_base": api_base,
+                    "http_client": client,
+                },
+            )
+
+        headers = {"Content-Type": "application/json"}
+        data_for_triton = json.dumps(data_for_triton)
+
+        if acompletion:
+            return self.acompletion(
+                model,
+                data_for_triton,
+                headers=headers,
+                logging_obj=logging_obj,
+                api_base=api_base,
+                stream=stream,
+                model_response=model_response,
+                type_of_model=type_of_model,
+            )
+        else:
+            handler = HTTPHandler()
+            if stream:
+                return self._handle_stream(
+                    handler, api_base, data_for_triton, model, logging_obj
+                )
+            else:
+                response = handler.post(url=api_base, data=data_for_triton, headers=headers)
+                return self._handle_response(
+                    response, model_response, logging_obj, type_of_model=type_of_model
+                )
+
+    async def acompletion(
+        self,
+        model: str,
+        data_for_triton,
+        api_base,
+        stream,
+        logging_obj,
+        headers,
+        model_response,
+        type_of_model,
+    ) -> ModelResponse:
+        handler = AsyncHTTPHandler()
+        if stream:
+            return self._ahandle_stream(
+                handler, api_base, data_for_triton, model, logging_obj
+            )
+        else:
+            response = await handler.post(
+                url=api_base, data=data_for_triton, headers=headers
+            )
+            return self._handle_response(
+                response, model_response, logging_obj, type_of_model=type_of_model
+            )
+
+    def _handle_stream(self, handler, api_base, data_for_triton, model, logging_obj):
+        response = handler.post(
+            url=api_base + "_stream", data=data_for_triton, stream=True
+        )
+        streamwrapper = litellm.CustomStreamWrapper(
+            response.iter_lines(),
+            model=model,
+            custom_llm_provider="triton",
+            logging_obj=logging_obj,
+        )
+        for chunk in streamwrapper:
+            yield (chunk)
+
+    async def _ahandle_stream(
+        self, handler, api_base, data_for_triton, model, logging_obj
+    ):
+        response = await handler.post(
+            url=api_base + "_stream", data=data_for_triton, stream=True
+        )
+        streamwrapper = litellm.CustomStreamWrapper(
+            response.aiter_lines(),
+            model=model,
+            custom_llm_provider="triton",
+            logging_obj=logging_obj,
+        )
+        async for chunk in streamwrapper:
+            yield (chunk)
+
+    def _handle_response(self, response, model_response, logging_obj, type_of_model):
+        if logging_obj:
+            logging_obj.post_call(original_response=response)
+        if response.status_code != 200:
+            raise TritonError(status_code=response.status_code, message=response.text)
+        _json_response = response.json()
+        model_response.model = _json_response.get("model_name", "None")
+        if type_of_model == "trtllm":
+            model_response.choices = [
+                Choices(index=0, message=Message(content=_json_response["text_output"]))
+            ]
+        elif type_of_model == "infer":
+            model_response.choices = [
+                Choices(
+                    index=0,
+                    message=Message(content=_json_response["outputs"][0]["data"]),
+                )
+            ]
+        else:
+            model_response.choices = [
+                Choices(index=0, message=Message(content=_json_response["outputs"]))
+            ]
+        return model_response
+
     @staticmethod
     def split_embedding_by_shape(
         data: List[float], shape: List[int]
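
To make the branching in the new completion() concrete, here is a sketch of the two request bodies it builds, depending on whether api_base ends in "generate" (TensorRT-LLM backend) or "infer" (a custom model behind the KServe-style endpoint). The prompt text and parameter values are illustrative.

# "generate" (trtllm) style payload, as assembled in the trtllm branch:
generate_payload = {
    "text_input": "What is Triton Inference Server?",  # prompt_factory() output in the real code
    "parameters": {"max_tokens": 100, "bad_words": [""], "stop_words": [""]},
    "stream": False,
}

# "infer" style payload, as assembled in the infer branch:
infer_payload = {
    "inputs": [
        {"name": "text_input", "shape": [1], "datatype": "BYTES", "data": ["What is Triton Inference Server?"]},
        # remaining optional_params are appended with an inferred datatype (INT32 / FP32 / BYTES)
        {"name": "max_tokens", "shape": [1], "datatype": "INT32", "data": [100]},
    ]
}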

litellm/main.py

@@ -375,6 +375,7 @@ async def acompletion(
             or custom_llm_provider == "predibase"
             or custom_llm_provider == "bedrock"
             or custom_llm_provider == "databricks"
+            or custom_llm_provider == "triton"
             or custom_llm_provider == "clarifai"
             or custom_llm_provider == "watsonx"
             or custom_llm_provider in litellm.openai_compatible_providers
@@ -2477,6 +2478,28 @@ def completion(
                 return generator

             response = generator
+        elif custom_llm_provider == "triton":
+            api_base = (
+                litellm.api_base or api_base
+            )
+            model_response = triton_chat_completions.completion(
+                api_base=api_base,
+                timeout=timeout,  # type: ignore
+                model=model,
+                messages=messages,
+                model_response=model_response,
+                optional_params=optional_params,
+                logging_obj=logging,
+                stream=stream,
+                acompletion=acompletion
+            )
+
+            ## RESPONSE OBJECT
+            response = model_response
+            return response
+
         elif custom_llm_provider == "cloudflare":
             api_key = (
                 api_key
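
With the routing above in place, the async and streaming paths might be exercised as follows. Names and URLs are placeholders; note that when streaming, the Triton handler posts to api_base + "_stream" (e.g. ".../generate_stream").

import asyncio
import litellm

async def main():
    # Placeholder model and endpoint; stream=True goes through CustomStreamWrapper.
    stream = await litellm.acompletion(
        model="triton/llama-2-7b",
        messages=[{"role": "user", "content": "Tell me a joke"}],
        api_base="http://localhost:8000/v2/models/llama-2-7b/generate",
        stream=True,
    )
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())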

litellm/utils.py

@@ -9097,6 +9097,42 @@ class CustomStreamWrapper:
         except Exception as e:
             raise e

+    def handle_triton_stream(self, chunk):
+        try:
+            if isinstance(chunk, dict):
+                parsed_response = chunk
+            elif isinstance(chunk, (str, bytes)):
+                if isinstance(chunk, bytes):
+                    chunk = chunk.decode("utf-8")
+                if "text_output" in chunk:
+                    response = chunk.replace("data: ", "").strip()
+                    parsed_response = json.loads(response)
+                else:
+                    return {
+                        "text": "",
+                        "is_finished": False,
+                        "prompt_tokens": 0,
+                        "completion_tokens": 0,
+                    }
+            else:
+                print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
+                raise ValueError(
+                    f"Unable to parse response. Original response: {chunk}"
+                )
+            text = parsed_response.get("text_output", "")
+            finish_reason = parsed_response.get("stop_reason")
+            is_finished = parsed_response.get("is_finished", False)
+            return {
+                "text": text,
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+                "prompt_tokens": parsed_response.get("input_token_count", 0),
+                "completion_tokens": parsed_response.get("generated_token_count", 0),
+            }
+            return {"text": "", "is_finished": False}
+        except Exception as e:
+            raise e
+
     def handle_clarifai_completion_chunk(self, chunk):
         try:
             if isinstance(chunk, dict):
@@ -9516,6 +9552,12 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
+            elif self.custom_llm_provider == "triton":
+                response_obj = self.handle_triton_stream(chunk)
+                completion_obj["content"] = response_obj["text"]
+                print_verbose(f"completion obj content: {completion_obj['content']}")
+                if response_obj["is_finished"]:
+                    self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "text-completion-openai":
                 response_obj = self.handle_openai_text_completion_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -10071,6 +10113,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "predibase"
                 or self.custom_llm_provider == "databricks"
                 or self.custom_llm_provider == "bedrock"
+                or self.custom_llm_provider == "triton"
                 or self.custom_llm_provider == "watsonx"
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
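
For reference, handle_triton_stream expects the server-sent lines that Triton's generate_stream endpoint emits, i.e. data: {...} with a text_output field; the handler strips the "data: " prefix and maps the result into litellm's normalized chunk dict. A standalone sketch of that mapping, using a fabricated chunk:

import json

# Fabricated SSE line in the shape returned by Triton's generate_stream endpoint.
raw_chunk = 'data: {"model_name": "llama-2-7b", "text_output": "Hello", "is_finished": false}'

parsed = json.loads(raw_chunk.replace("data: ", "").strip())
normalized = {
    "text": parsed.get("text_output", ""),
    "is_finished": parsed.get("is_finished", False),
    "finish_reason": parsed.get("stop_reason"),
    "prompt_tokens": parsed.get("input_token_count", 0),
    "completion_tokens": parsed.get("generated_token_count", 0),
}
print(normalized)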