forked from phoenix/litellm-mirror
Add support for Triton streaming & triton async completions
This commit is contained in: parent 1b3050477a, commit d5c65c6be2
3 changed files with 199 additions and 33 deletions
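
A minimal usage sketch of what this enables, assuming a local Triton server and a placeholder model name (neither is taken from this commit); an api_base ending in "generate" is routed through the TensorRT-LLM branch of the handler below:

import litellm

# Hypothetical endpoint and model name, for illustration only.
response = litellm.completion(
    model="triton/llama-trtllm",
    messages=[{"role": "user", "content": "Hello"}],
    api_base="http://localhost:8000/v2/models/llama/generate",
    max_tokens=100,
    stream=True,  # new: sync streaming via the "_stream" endpoint
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")

# The same call through litellm.acompletion(...) now takes the async path
# added here (acompletion=True is forwarded to the Triton handler).

The first changed file extends the Triton handler itself: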
@@ -1,16 +1,24 @@
 import os
 import json
 from enum import Enum
 import requests # type: ignore
 import time
 from typing import Callable, Optional, List, Sequence, Any, Union, Dict
-from litellm.utils import ModelResponse, Choices, Usage, map_finish_reason, CustomStreamWrapper, Message, EmbeddingResponse
+from litellm.utils import (
+    ModelResponse,
+    Choices,
+    Delta,
+    Usage,
+    map_finish_reason,
+    CustomStreamWrapper,
+    Message,
+    EmbeddingResponse,
+)
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from .base import BaseLLM
 import httpx # type: ignore


 class TritonError(Exception):
@@ -26,6 +34,7 @@ class TritonError(Exception):
             self.message
         ) # Call the base class constructor with the parameters it needs


 class TritonChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
@@ -127,71 +136,182 @@ class TritonChatCompletion(BaseLLM):
         optional_params=None,
         client=None,
         stream: bool = False,
+        acompletion: bool = False,
     ) -> ModelResponse:

         type_of_model = ""
+        optional_params.pop("stream", False)
         if api_base.endswith("generate"): ### This is a trtllm model
             text_input = messages[0]["content"]
             data_for_triton: Dict[str, Any] = {
-                "text_input": str(text_input),
+                "text_input": prompt_factory(model=model, messages=messages),
                 "parameters": {
-                    "max_tokens": int(optional_params.get("max_tokens", 20)),
+                    "max_tokens": int(optional_params.get("max_tokens", 2000)),
                     "bad_words": [""],
-                    "stop_words": [""]
-                }
+                    "stop_words": [""],
+                },
+                "stream": bool(stream),
             }
-            data_for_triton["parameters"].update( optional_params)
+            data_for_triton["parameters"].update(optional_params)
             type_of_model = "trtllm"

-        elif api_base.endswith("infer"): ### This is an infer model with a custom model on triton
+        elif api_base.endswith(
+            "infer"
+        ): ### This is an infer model with a custom model on triton
             text_input = messages[0]["content"]
             data_for_triton = {
-                "inputs": [{"name": "text_input", "shape": [1], "datatype": "BYTES", "data": [text_input]}]
+                "inputs": [
+                    {
+                        "name": "text_input",
+                        "shape": [1],
+                        "datatype": "BYTES",
+                        "data": [text_input],
+                    }
+                ]
             }

             for k, v in optional_params.items():
                 if not (k == "stream" or k == "max_retries"):
                     datatype = "INT32" if isinstance(v, int) else "BYTES"
                     datatype = "FP32" if isinstance(v, float) else datatype
-                    data_for_triton['inputs'].append({"name": k, "shape": [1], "datatype": datatype, "data": [v]})
+                    data_for_triton["inputs"].append(
+                        {"name": k, "shape": [1], "datatype": datatype, "data": [v]}
+                    )

             if "max_tokens" not in optional_params:
-                data_for_triton['inputs'].append({"name": "max_tokens", "shape": [1], "datatype": "INT32", "data": [20]})
+                data_for_triton["inputs"].append(
+                    {
+                        "name": "max_tokens",
+                        "shape": [1],
+                        "datatype": "INT32",
+                        "data": [20],
+                    }
+                )

             type_of_model = "infer"
         else: ## Unknown model type passthrough
             data_for_triton = {
-                "inputs": [{"name": "text_input", "shape": [1], "datatype": "BYTES", "data": [messages[0]["content"]]}]
+                "inputs": [
+                    {
+                        "name": "text_input",
+                        "shape": [1],
+                        "datatype": "BYTES",
+                        "data": [messages[0]["content"]],
+                    }
+                ]
             }

         if logging_obj:
             logging_obj.pre_call(
                 input=messages,
                 api_key=api_key,
                 additional_args={
                     "complete_input_dict": optional_params,
                     "api_base": api_base,
                     "http_client": client,
                 },
             )
-        handler = requests.Session()
-        handler.timeout = (600.0, 5.0)
-
-        response = handler.post(url=api_base, json=data_for_triton)
+        headers = {"Content-Type": "application/json"}
+        data_for_triton = json.dumps(data_for_triton)

+        if acompletion:
+            return self.acompletion(
+                model,
+                data_for_triton,
+                headers=headers,
+                logging_obj=logging_obj,
+                api_base=api_base,
+                stream=stream,
+                model_response=model_response,
+                type_of_model=type_of_model,
+            )
+        else:
+            handler = HTTPHandler()
+            if stream:
+                return self._handle_stream(
+                    handler, api_base, data_for_triton, model, logging_obj
+                )
+            else:
+                response = handler.post(url=api_base, data=data_for_triton, headers=headers)
+                return self._handle_response(
+                    response, model_response, logging_obj, type_of_model=type_of_model
+                )
+
+    async def acompletion(
+        self,
+        model: str,
+        data_for_triton,
+        api_base,
+        stream,
+        logging_obj,
+        headers,
+        model_response,
+        type_of_model,
+    ) -> ModelResponse:
+        handler = AsyncHTTPHandler()
+        if stream:
+            return self._ahandle_stream(
+                handler, api_base, data_for_triton, model, logging_obj
+            )
+        else:
+            response = await handler.post(
+                url=api_base, data=data_for_triton, headers=headers
+            )
+
+            return self._handle_response(
+                response, model_response, logging_obj, type_of_model=type_of_model
+            )
+
+    def _handle_stream(self, handler, api_base, data_for_triton, model, logging_obj):
+        response = handler.post(
+            url=api_base + "_stream", data=data_for_triton, stream=True
+        )
+        streamwrapper = litellm.CustomStreamWrapper(
+            response.iter_lines(),
+            model=model,
+            custom_llm_provider="triton",
+            logging_obj=logging_obj,
+        )
+        for chunk in streamwrapper:
+            yield (chunk)
+
+    async def _ahandle_stream(
+        self, handler, api_base, data_for_triton, model, logging_obj
+    ):
+        response = await handler.post(
+            url=api_base + "_stream", data=data_for_triton, stream=True
+        )
+        streamwrapper = litellm.CustomStreamWrapper(
+            response.aiter_lines(),
+            model=model,
+            custom_llm_provider="triton",
+            logging_obj=logging_obj,
+        )
+        async for chunk in streamwrapper:
+            yield (chunk)
+
+    def _handle_response(self, response, model_response, logging_obj, type_of_model):
         if logging_obj:
             logging_obj.post_call(original_response=response)

         if response.status_code != 200:
             raise TritonError(status_code=response.status_code, message=response.text)
-        _json_response = response.json()

+        _json_response = response.json()
         model_response.model = _json_response.get("model_name", "None")
         if type_of_model == "trtllm":
-            model_response.choices = [Choices(index=0, message=Message(content=_json_response['text_output']))]
+            model_response.choices = [
+                Choices(index=0, message=Message(content=_json_response["text_output"]))
+            ]
         elif type_of_model == "infer":
-            model_response.choices = [Choices(index=0, message=Message(content=_json_response['outputs'][0]['data']))]
+            model_response.choices = [
+                Choices(
+                    index=0,
+                    message=Message(content=_json_response["outputs"][0]["data"]),
+                )
+            ]
         else:
-            model_response.choices = [Choices(index=0, message=Message(content=_json_response['outputs']))]
+            model_response.choices = [
+                Choices(index=0, message=Message(content=_json_response["outputs"]))
+            ]
         return model_response
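
For orientation, the "generate" (TensorRT-LLM) branch above now serializes a body shaped like the sketch below and, when streaming, posts it to api_base + "_stream"; the values are illustrative, the field names come from the handler:

# Sketch of the trtllm request body built by completion(); values are examples.
data_for_triton = {
    "text_input": "<prompt produced by prompt_factory(model=..., messages=...)>",
    "parameters": {
        "max_tokens": 2000,  # new default when the caller passes none
        "bad_words": [""],
        "stop_words": [""],
        # remaining optional_params are merged in via .update(optional_params)
    },
    "stream": True,  # mirrors the stream flag
}
# Non-streaming: POST json.dumps(data_for_triton) to api_base.
# Streaming: POST the same body to api_base + "_stream" and wrap the response
# lines in litellm.CustomStreamWrapper(custom_llm_provider="triton").

The remaining hunks wire the provider into the shared routing and stream handling: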
@@ -333,6 +333,7 @@ async def acompletion(
             or custom_llm_provider == "predibase"
             or custom_llm_provider == "bedrock"
             or custom_llm_provider == "databricks"
+            or custom_llm_provider == "triton"
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
@@ -2267,6 +2268,8 @@ def completion(
                 model_response=model_response,
                 optional_params=optional_params,
                 logging_obj=logging,
+                stream=stream,
+                acompletion=acompletion
             )

             ## RESPONSE OBJECT
@@ -11013,6 +11013,42 @@ class CustomStreamWrapper:
         except Exception as e:
             raise e

+    def handle_triton_stream(self, chunk):
+        try:
+            if isinstance(chunk, dict):
+                parsed_response = chunk
+            elif isinstance(chunk, (str, bytes)):
+                if isinstance(chunk, bytes):
+                    chunk = chunk.decode("utf-8")
+                if "text_output" in chunk:
+                    response = chunk.replace("data: ", "").strip()
+                    parsed_response = json.loads(response)
+                else:
+                    return {
+                        "text": "",
+                        "is_finished": False,
+                        "prompt_tokens": 0,
+                        "completion_tokens": 0,
+                    }
+            else:
+                print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
+                raise ValueError(
+                    f"Unable to parse response. Original response: {chunk}"
+                )
+            text = parsed_response.get("text_output", "")
+            finish_reason = parsed_response.get("stop_reason")
+            is_finished = parsed_response.get("is_finished", False)
+            return {
+                "text": text,
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+                "prompt_tokens": parsed_response.get("input_token_count", 0),
+                "completion_tokens": parsed_response.get("generated_token_count", 0),
+            }
+            return {"text": "", "is_finished": False}
+        except Exception as e:
+            raise e
+
     def handle_clarifai_completion_chunk(self, chunk):
         try:
             if isinstance(chunk, dict):
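
To make the new parser concrete, here is the kind of server-sent-event line it expects from a Triton streaming endpoint and the steps it applies; only "text_output" is required, the other fields are assumptions that fall back to the parser's defaults:

import json

# Illustrative SSE line; field names other than "text_output" are assumed.
raw = 'data: {"model_name": "llama", "text_output": "Hello", "is_finished": false}'

# Same steps handle_triton_stream applies to a str/bytes chunk:
parsed = json.loads(raw.replace("data: ", "").strip())
print(parsed.get("text_output", ""))     # -> Hello
print(parsed.get("is_finished", False))  # -> False
print(parsed.get("stop_reason"))         # -> None (becomes finish_reason)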
@@ -11337,6 +11373,12 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
+            elif self.custom_llm_provider == "triton":
+                response_obj = self.handle_triton_stream(chunk)
+                completion_obj["content"] = response_obj["text"]
+                print_verbose(f"completion obj content: {completion_obj['content']}")
+                if response_obj["is_finished"]:
+                    self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "text-completion-openai":
                 response_obj = self.handle_openai_text_completion_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -11773,6 +11815,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "predibase"
                 or self.custom_llm_provider == "databricks"
                 or self.custom_llm_provider == "bedrock"
+                or self.custom_llm_provider == "triton"
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
                 async for chunk in self.completion_stream: