Add support for Triton streaming & Triton async completions

Sophia Loris 2024-07-19 09:35:27 -05:00
parent 1b3050477a
commit d5c65c6be2
3 changed files with 199 additions and 33 deletions
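
For context, the new code paths can be exercised through litellm's public completion APIs roughly as sketched below. This is a minimal sketch assuming a local Triton deployment; the "triton/" model prefix, model name, and endpoint URL are placeholders, not values taken from this commit.

import asyncio
import litellm

# Placeholder assumptions: adjust the model name and generate endpoint
# to match your own Triton Inference Server deployment.
API_BASE = "http://localhost:8000/v2/models/ensemble/generate"
MODEL = "triton/ensemble"
MESSAGES = [{"role": "user", "content": "What is Triton Inference Server?"}]


def stream_sync() -> None:
    # stream=True now routes through TritonChatCompletion._handle_stream.
    response = litellm.completion(
        model=MODEL, messages=MESSAGES, api_base=API_BASE, stream=True
    )
    for chunk in response:
        print(chunk.choices[0].delta.content or "", end="", flush=True)


async def complete_async() -> None:
    # acompletion now routes through the new async TritonChatCompletion.acompletion.
    response = await litellm.acompletion(
        model=MODEL, messages=MESSAGES, api_base=API_BASE
    )
    print(response.choices[0].message.content)


if __name__ == "__main__":
    stream_sync()
    asyncio.run(complete_async())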


@@ -4,15 +4,23 @@ from enum import Enum
 import requests  # type: ignore
 import time
 from typing import Callable, Optional, List, Sequence, Any, Union, Dict
-from litellm.utils import ModelResponse, Choices, Usage, map_finish_reason, CustomStreamWrapper, Message, EmbeddingResponse
+from litellm.utils import (
+    ModelResponse,
+    Choices,
+    Delta,
+    Usage,
+    map_finish_reason,
+    CustomStreamWrapper,
+    Message,
+    EmbeddingResponse,
+)
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from .base import BaseLLM
 import httpx  # type: ignore


 class TritonError(Exception):
     def __init__(self, status_code: int, message: str) -> None:
         self.status_code = status_code
@@ -26,6 +34,7 @@ class TritonError(Exception):
             self.message
         )  # Call the base class constructor with the parameters it needs

+
 class TritonChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
@@ -127,41 +136,68 @@ class TritonChatCompletion(BaseLLM):
         optional_params=None,
         client=None,
         stream: bool = False,
+        acompletion: bool = False,
     ) -> ModelResponse:
         type_of_model = ""
+        optional_params.pop("stream", False)
         if api_base.endswith("generate"):  ### This is a trtllm model
             text_input = messages[0]["content"]
             data_for_triton: Dict[str, Any] = {
-                "text_input": str(text_input),
+                "text_input": prompt_factory(model=model, messages=messages),
                 "parameters": {
-                    "max_tokens": int(optional_params.get("max_tokens", 20)),
+                    "max_tokens": int(optional_params.get("max_tokens", 2000)),
                     "bad_words": [""],
-                    "stop_words": [""]
-                }
+                    "stop_words": [""],
+                },
+                "stream": bool(stream),
             }
             data_for_triton["parameters"].update(optional_params)
             type_of_model = "trtllm"
-        elif api_base.endswith("infer"):  ### This is an infer model with a custom model on triton
+        elif api_base.endswith(
+            "infer"
+        ):  ### This is an infer model with a custom model on triton
             text_input = messages[0]["content"]
             data_for_triton = {
-                "inputs": [{"name": "text_input", "shape": [1], "datatype": "BYTES", "data": [text_input]}]
+                "inputs": [
+                    {
+                        "name": "text_input",
+                        "shape": [1],
+                        "datatype": "BYTES",
+                        "data": [text_input],
+                    }
+                ]
             }
             for k, v in optional_params.items():
                 if not (k == "stream" or k == "max_retries"):
                     datatype = "INT32" if isinstance(v, int) else "BYTES"
                     datatype = "FP32" if isinstance(v, float) else datatype
-                    data_for_triton['inputs'].append({"name": k, "shape": [1], "datatype": datatype, "data": [v]})
+                    data_for_triton["inputs"].append(
+                        {"name": k, "shape": [1], "datatype": datatype, "data": [v]}
+                    )
             if "max_tokens" not in optional_params:
-                data_for_triton['inputs'].append({"name": "max_tokens", "shape": [1], "datatype": "INT32", "data": [20]})
+                data_for_triton["inputs"].append(
+                    {
+                        "name": "max_tokens",
+                        "shape": [1],
+                        "datatype": "INT32",
+                        "data": [20],
+                    }
+                )
             type_of_model = "infer"
         else:  ## Unknown model type passthrough
             data_for_triton = {
-                "inputs": [{"name": "text_input", "shape": [1], "datatype": "BYTES", "data": [messages[0]["content"]]}]
+                "inputs": [
+                    {
+                        "name": "text_input",
+                        "shape": [1],
+                        "datatype": "BYTES",
+                        "data": [messages[0]["content"]],
+                    }
+                ]
             }

         if logging_obj:
@@ -174,24 +210,108 @@ class TritonChatCompletion(BaseLLM):
                     "http_client": client,
                 },
             )
-        handler = requests.Session()
-        handler.timeout = (600.0, 5.0)
-        response = handler.post(url=api_base, json=data_for_triton)
+        headers = {"Content-Type": "application/json"}
+        data_for_triton = json.dumps(data_for_triton)
+
+        if acompletion:
+            return self.acompletion(
+                model,
+                data_for_triton,
+                headers=headers,
+                logging_obj=logging_obj,
+                api_base=api_base,
+                stream=stream,
+                model_response=model_response,
+                type_of_model=type_of_model,
+            )
+        else:
+            handler = HTTPHandler()
+            if stream:
+                return self._handle_stream(
+                    handler, api_base, data_for_triton, model, logging_obj
+                )
+            else:
+                response = handler.post(url=api_base, data=data_for_triton, headers=headers)
+                return self._handle_response(
+                    response, model_response, logging_obj, type_of_model=type_of_model
+                )
+
+    async def acompletion(
+        self,
+        model: str,
+        data_for_triton,
+        api_base,
+        stream,
+        logging_obj,
+        headers,
+        model_response,
+        type_of_model,
+    ) -> ModelResponse:
+        handler = AsyncHTTPHandler()
+        if stream:
+            return self._ahandle_stream(
+                handler, api_base, data_for_triton, model, logging_obj
+            )
+        else:
+            response = await handler.post(
+                url=api_base, data=data_for_triton, headers=headers
+            )
+            return self._handle_response(
+                response, model_response, logging_obj, type_of_model=type_of_model
+            )
+
+    def _handle_stream(self, handler, api_base, data_for_triton, model, logging_obj):
+        response = handler.post(
+            url=api_base + "_stream", data=data_for_triton, stream=True
+        )
+        streamwrapper = litellm.CustomStreamWrapper(
+            response.iter_lines(),
+            model=model,
+            custom_llm_provider="triton",
+            logging_obj=logging_obj,
+        )
+        for chunk in streamwrapper:
+            yield (chunk)
+
+    async def _ahandle_stream(
+        self, handler, api_base, data_for_triton, model, logging_obj
+    ):
+        response = await handler.post(
+            url=api_base + "_stream", data=data_for_triton, stream=True
+        )
+        streamwrapper = litellm.CustomStreamWrapper(
+            response.aiter_lines(),
+            model=model,
+            custom_llm_provider="triton",
+            logging_obj=logging_obj,
+        )
+        async for chunk in streamwrapper:
+            yield (chunk)
+
+    def _handle_response(self, response, model_response, logging_obj, type_of_model):
         if logging_obj:
             logging_obj.post_call(original_response=response)

         if response.status_code != 200:
             raise TritonError(status_code=response.status_code, message=response.text)

         _json_response = response.json()
         model_response.model = _json_response.get("model_name", "None")
         if type_of_model == "trtllm":
-            model_response.choices = [Choices(index=0, message=Message(content=_json_response['text_output']))]
+            model_response.choices = [
+                Choices(index=0, message=Message(content=_json_response["text_output"]))
+            ]
         elif type_of_model == "infer":
-            model_response.choices = [Choices(index=0, message=Message(content=_json_response['outputs'][0]['data']))]
+            model_response.choices = [
+                Choices(
+                    index=0,
+                    message=Message(content=_json_response["outputs"][0]["data"]),
+                )
+            ]
         else:
-            model_response.choices = [Choices(index=0, message=Message(content=_json_response['outputs']))]
+            model_response.choices = [
+                Choices(index=0, message=Message(content=_json_response["outputs"]))
+            ]
         return model_response
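
For reference, the trtllm branch above now serializes a request body of roughly the following shape before POSTing it to the generate endpoint (or to its "_stream" variant when streaming). This is an illustrative sketch: the prompt text and parameter values are examples, and the final dict also absorbs whatever optional_params the caller passed.

# Illustrative payload produced by the trtllm branch; "text_input" comes from
# prompt_factory() and extra keys are merged in via parameters.update(optional_params).
data_for_triton = {
    "text_input": "What is Triton Inference Server?",
    "parameters": {
        "max_tokens": 2000,
        "bad_words": [""],
        "stop_words": [""],
    },
    "stream": True,
}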


@@ -333,6 +333,7 @@ async def acompletion(
             or custom_llm_provider == "predibase"
             or custom_llm_provider == "bedrock"
             or custom_llm_provider == "databricks"
+            or custom_llm_provider == "triton"
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
@@ -2267,6 +2268,8 @@ def completion(
                 model_response=model_response,
                 optional_params=optional_params,
                 logging_obj=logging,
+                stream=stream,
+                acompletion=acompletion
             )

             ## RESPONSE OBJECT


@@ -11013,6 +11013,42 @@ class CustomStreamWrapper:
         except Exception as e:
             raise e

+    def handle_triton_stream(self, chunk):
+        try:
+            if isinstance(chunk, dict):
+                parsed_response = chunk
+            elif isinstance(chunk, (str, bytes)):
+                if isinstance(chunk, bytes):
+                    chunk = chunk.decode("utf-8")
+                if "text_output" in chunk:
+                    response = chunk.replace("data: ", "").strip()
+                    parsed_response = json.loads(response)
+                else:
+                    return {
+                        "text": "",
+                        "is_finished": False,
+                        "prompt_tokens": 0,
+                        "completion_tokens": 0,
+                    }
+            else:
+                print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
+                raise ValueError(
+                    f"Unable to parse response. Original response: {chunk}"
+                )
+            text = parsed_response.get("text_output", "")
+            finish_reason = parsed_response.get("stop_reason")
+            is_finished = parsed_response.get("is_finished", False)
+            return {
+                "text": text,
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+                "prompt_tokens": parsed_response.get("input_token_count", 0),
+                "completion_tokens": parsed_response.get("generated_token_count", 0),
+            }
+            return {"text": "", "is_finished": False}
+        except Exception as e:
+            raise e
+
     def handle_clarifai_completion_chunk(self, chunk):
         try:
             if isinstance(chunk, dict):
@@ -11337,6 +11373,12 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
+            elif self.custom_llm_provider == "triton":
+                response_obj = self.handle_triton_stream(chunk)
+                completion_obj["content"] = response_obj["text"]
+                print_verbose(f"completion obj content: {completion_obj['content']}")
+                if response_obj["is_finished"]:
+                    self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "text-completion-openai":
                 response_obj = self.handle_openai_text_completion_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -11773,6 +11815,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "predibase"
                 or self.custom_llm_provider == "databricks"
                 or self.custom_llm_provider == "bedrock"
+                or self.custom_llm_provider == "triton"
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
                 async for chunk in self.completion_stream:
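
Finally, with "triton" added to the async iteration branch above, an async streaming call can be consumed directly. Again a sketch under the same assumptions as earlier (placeholder model name and api_base); each yielded chunk is built from a "data: {...}" line parsed by handle_triton_stream.

import asyncio
import litellm


async def stream_async() -> None:
    response = await litellm.acompletion(
        model="triton/ensemble",  # placeholder model name
        messages=[{"role": "user", "content": "Hello from Triton"}],
        api_base="http://localhost:8000/v2/models/ensemble/generate",  # placeholder
        stream=True,
    )
    async for chunk in response:
        print(chunk.choices[0].delta.content or "", end="", flush=True)


asyncio.run(stream_async())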