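"""
Triton Inference Server handler for litellm.

Implements embeddings and chat/text completions against a Triton HTTP endpoint.
Completion requests are routed by the shape of ``api_base``: endpoints ending in
``generate`` are treated as TensorRT-LLM models, endpoints ending in ``infer``
as custom Triton models using the per-tensor ``inputs`` payload, and anything
else falls through to a generic ``text_input`` passthrough.

Usage sketch (illustrative only; the exact model-name prefix and routing are
handled by litellm's provider dispatch, not by this module):

    response = await litellm.acompletion(
        model="triton/my-model",
        messages=[{"role": "user", "content": "Hello"}],
        api_base="http://localhost:8000/v2/models/my-model/generate",
    )
"""
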
import json
import os
import time
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import httpx  # type: ignore
import requests  # type: ignore

import litellm
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    get_async_httpx_client,
)
from litellm.utils import (
    Choices,
    CustomStreamWrapper,
    Delta,
    EmbeddingResponse,
    Message,
    ModelResponse,
    Usage,
    map_finish_reason,
)

from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory


class TritonError(Exception):
    def __init__(self, status_code: int, message: str) -> None:
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST",
            url="https://api.anthropic.com/v1/messages",  # using anthropic api base since httpx requires a url
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class TritonChatCompletion(BaseLLM):
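    """Handles embedding and completion calls against a Triton Inference Server endpoint."""
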
    def __init__(self) -> None:
        super().__init__()

    async def aembedding(
        self,
        data: dict,
        model_response: litellm.utils.EmbeddingResponse,
        api_base: str,
        logging_obj: Any,
        api_key: Optional[str] = None,
    ) -> EmbeddingResponse:
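        """
        POST the prepared payload to the Triton endpoint and map the returned
        ``outputs`` tensors (split row-wise by their ``shape``) into an
        OpenAI-style ``EmbeddingResponse``.
        """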
        async_handler = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
        )

        response = await async_handler.post(url=api_base, data=json.dumps(data))

        if response.status_code != 200:
            raise TritonError(status_code=response.status_code, message=response.text)

        _text_response = response.text

        logging_obj.post_call(original_response=_text_response)

        _json_response = response.json()
        _embedding_output = []

        _outputs = _json_response["outputs"]
        for output in _outputs:
            _shape = output["shape"]
            _data = output["data"]
            _split_output_data = self.split_embedding_by_shape(_data, _shape)

            for idx, embedding in enumerate(_split_output_data):
                _embedding_output.append(
                    {
                        "object": "embedding",
                        "index": idx,
                        "embedding": embedding,
                    }
                )

        model_response.model = _json_response.get("model_name", "None")
        model_response.data = _embedding_output

        return model_response

    async def embedding(
        self,
        model: str,
        input: List[str],
        timeout: float,
        api_base: str,
        model_response: litellm.utils.EmbeddingResponse,
        logging_obj: Any,
        optional_params: dict,
        api_key: Optional[str] = None,
        client=None,
        aembedding: bool = False,
    ) -> EmbeddingResponse:
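        """
        Build the Triton ``inputs`` payload (a single BYTES tensor named
        ``input_text`` holding the input strings) and dispatch it. Only the
        async path is implemented; synchronous calls raise.
        """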
        data_for_triton = {
            "inputs": [
                {
                    "name": "input_text",
                    "shape": [len(input)],
                    "datatype": "BYTES",
                    "data": input,
                }
            ]
        }

        curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"

        logging_obj.pre_call(
            input="",
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": curl_string,
            },
        )

        if aembedding:
            response = await self.aembedding(  # type: ignore
                data=data_for_triton,
                model_response=model_response,
                logging_obj=logging_obj,
                api_base=api_base,
                api_key=api_key,
            )
            return response
        else:
            raise Exception(
                "Only async embedding supported for triton, please use litellm.aembedding() for now"
            )

    def completion(
        self,
        model: str,
        messages: List[dict],
        timeout: float,
        api_base: str,
        logging_obj: Any,
        optional_params: dict,
        model_response: ModelResponse,
        api_key: Optional[str] = None,
        client=None,
        stream: Optional[bool] = False,
        acompletion: bool = False,
    ) -> ModelResponse:
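        """
        Build the request payload based on the endpoint type:

        - ``api_base`` ending in ``generate``: TensorRT-LLM style ``text_input``
          plus ``parameters`` (with ``optional_params`` merged in).
        - ``api_base`` ending in ``infer``: per-tensor ``inputs`` list, one
          tensor per optional param, defaulting ``max_tokens`` to 20 if unset.
        - anything else: passthrough with a single ``text_input`` tensor.

        Then either hands off to ``acompletion`` or performs a synchronous
        (optionally streaming) HTTP call.
        """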
        type_of_model = ""
        optional_params.pop("stream", False)
        if api_base.endswith("generate"):  ### This is a trtllm model
            text_input = messages[0]["content"]
            data_for_triton: Dict[str, Any] = {
                "text_input": prompt_factory(model=model, messages=messages),
                "parameters": {
                    "max_tokens": int(optional_params.get("max_tokens", 2000)),
                    "bad_words": [""],
                    "stop_words": [""],
                },
                "stream": bool(stream),
            }
            data_for_triton["parameters"].update(optional_params)
            type_of_model = "trtllm"

        elif api_base.endswith(
            "infer"
        ):  ### This is an infer model with a custom model on triton
            text_input = messages[0]["content"]
            data_for_triton = {
                "inputs": [
                    {
                        "name": "text_input",
                        "shape": [1],
                        "datatype": "BYTES",
                        "data": [text_input],
                    }
                ]
            }

            for k, v in optional_params.items():
                if not (k == "stream" or k == "max_retries"):
                    datatype = "INT32" if isinstance(v, int) else "BYTES"
                    datatype = "FP32" if isinstance(v, float) else datatype
                    data_for_triton["inputs"].append(
                        {"name": k, "shape": [1], "datatype": datatype, "data": [v]}
                    )

            if "max_tokens" not in optional_params:
                data_for_triton["inputs"].append(
                    {
                        "name": "max_tokens",
                        "shape": [1],
                        "datatype": "INT32",
                        "data": [20],
                    }
                )

            type_of_model = "infer"
        else:  ## Unknown model type passthrough
            data_for_triton = {
                "inputs": [
                    {
                        "name": "text_input",
                        "shape": [1],
                        "datatype": "BYTES",
                        "data": [messages[0]["content"]],
                    }
                ]
            }

        if logging_obj:
            logging_obj.pre_call(
                input=messages,
                api_key=api_key,
                additional_args={
                    "complete_input_dict": optional_params,
                    "api_base": api_base,
                    "http_client": client,
                },
            )

        headers = {"Content-Type": "application/json"}
        json_data_for_triton: str = json.dumps(data_for_triton)

        if acompletion:
            return self.acompletion(  # type: ignore
                model,
                json_data_for_triton,
                headers=headers,
                logging_obj=logging_obj,
                api_base=api_base,
                stream=stream,
                model_response=model_response,
                type_of_model=type_of_model,
            )
        else:
            handler = HTTPHandler()
            if stream:
                return self._handle_stream(  # type: ignore
                    handler, api_base, json_data_for_triton, model, logging_obj
                )
            else:
                response = handler.post(
                    url=api_base, data=json_data_for_triton, headers=headers
                )
                return self._handle_response(
                    response, model_response, logging_obj, type_of_model=type_of_model
                )

    async def acompletion(
        self,
        model: str,
        data_for_triton,
        api_base,
        stream,
        logging_obj,
        headers,
        model_response,
        type_of_model,
    ) -> ModelResponse:
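        """Async variant of ``completion``: posts via the shared async httpx client and either streams or returns a parsed ``ModelResponse``."""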
        handler = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
        )
        if stream:
            return self._ahandle_stream(  # type: ignore
                handler, api_base, data_for_triton, model, logging_obj
            )
        else:
            response = await handler.post(
                url=api_base, data=data_for_triton, headers=headers
            )

            return self._handle_response(
                response, model_response, logging_obj, type_of_model=type_of_model
            )

    def _handle_stream(self, handler, api_base, data_for_triton, model, logging_obj):
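        """Sync streaming: POST to ``<api_base>_stream`` and yield chunks through ``CustomStreamWrapper``."""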
        response = handler.post(
            url=api_base + "_stream", data=data_for_triton, stream=True
        )
        streamwrapper = litellm.CustomStreamWrapper(
            response.iter_lines(),
            model=model,
            custom_llm_provider="triton",
            logging_obj=logging_obj,
        )
        for chunk in streamwrapper:
            yield (chunk)

    async def _ahandle_stream(
        self, handler, api_base, data_for_triton, model, logging_obj
    ):
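        """Async streaming: POST to ``<api_base>_stream`` and yield chunks through ``CustomStreamWrapper``."""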
        response = await handler.post(
            url=api_base + "_stream", data=data_for_triton, stream=True
        )
        streamwrapper = litellm.CustomStreamWrapper(
            response.aiter_lines(),
            model=model,
            custom_llm_provider="triton",
            logging_obj=logging_obj,
        )
        async for chunk in streamwrapper:
            yield (chunk)

    def _handle_response(self, response, model_response, logging_obj, type_of_model):
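        """Map a non-streaming Triton response onto ``model_response.choices`` based on ``type_of_model`` (``trtllm`` -> ``text_output``, ``infer`` -> first output's ``data``, otherwise the raw ``outputs``)."""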
        if logging_obj:
            logging_obj.post_call(original_response=response)

        if response.status_code != 200:
            raise TritonError(status_code=response.status_code, message=response.text)

        _json_response = response.json()
        model_response.model = _json_response.get("model_name", "None")
        if type_of_model == "trtllm":
            model_response.choices = [
                Choices(index=0, message=Message(content=_json_response["text_output"]))
            ]
        elif type_of_model == "infer":
            model_response.choices = [
                Choices(
                    index=0,
                    message=Message(content=_json_response["outputs"][0]["data"]),
                )
            ]
        else:
            model_response.choices = [
                Choices(index=0, message=Message(content=_json_response["outputs"]))
            ]
        return model_response

    @staticmethod
    def split_embedding_by_shape(
        data: List[float], shape: List[int]
    ) -> List[List[float]]:
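        """
        Split a flat list of floats into rows according to a 2-D shape.

        Example: data of length 6 with shape [2, 3] becomes
        [[d0, d1, d2], [d3, d4, d5]].
        """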
        if len(shape) != 2:
            raise ValueError("Shape must be of length 2.")
        embedding_size = shape[1]
        return [
            data[i * embedding_size : (i + 1) * embedding_size] for i in range(shape[0])
        ]