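"""Triton Inference Server handlers for litellm.

`TritonChatCompletion.completion` targets either the TRT-LLM `generate`
endpoint or the KServe v2 `infer` endpoint (selected by the suffix of
`api_base`); `embedding` currently supports the async path only.
"""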
import os
import json
from enum import Enum
import requests  # type: ignore
import time
from typing import Callable, Optional, List, Sequence, Any, Union, Dict
from litellm.utils import (
    ModelResponse,
    Choices,
    Usage,
    map_finish_reason,
    CustomStreamWrapper,
    Message,
    EmbeddingResponse,
)
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx  # type: ignore


class TritonError(Exception):
    def __init__(self, status_code: int, message: str) -> None:
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST",
            url="https://api.anthropic.com/v1/messages",  # using anthropic api base since httpx requires a url
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class TritonChatCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    async def aembedding(
        self,
        data: dict,
        model_response: litellm.utils.EmbeddingResponse,
        api_base: str,
        logging_obj=None,
        api_key: Optional[str] = None,
    ) -> EmbeddingResponse:
        async_handler = AsyncHTTPHandler(
            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
        )

        response = await async_handler.post(url=api_base, data=json.dumps(data))

        if response.status_code != 200:
            raise TritonError(status_code=response.status_code, message=response.text)

        _text_response = response.text

        # `logging_obj` defaults to None, so guard before using it.
        if logging_obj:
            logging_obj.post_call(original_response=_text_response)

        _json_response = response.json()

        _outputs = _json_response["outputs"]
        _output_data = [output["data"] for output in _outputs]
        _embedding_output = {
            "object": "embedding",
            "index": 0,
            "embedding": _output_data,
        }

        model_response.model = _json_response.get("model_name", "None")
        model_response.data = [_embedding_output]

        return model_response
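
    # Example Triton (KServe v2) response shape that `aembedding` above parses;
    # the output names, shapes, and dtypes depend on the deployed model config,
    # and the values here are purely illustrative:
    #
    #   {
    #       "model_name": "my_embedding_model",
    #       "outputs": [
    #           {"name": "embedding", "datatype": "FP32", "shape": [1, 3],
    #            "data": [0.12, -0.04, 0.33]}
    #       ]
    #   }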

    async def embedding(
        self,
        model: str,
        input: List[str],
        timeout: float,
        api_base: str,
        model_response: litellm.utils.EmbeddingResponse,
        api_key: Optional[str] = None,
        logging_obj=None,
        optional_params=None,
        client=None,
        aembedding: bool = False,
    ) -> EmbeddingResponse:
        data_for_triton = {
            "inputs": [
                {
                    "name": "input_text",
                    "shape": [len(input)],  # size of the input data
                    "datatype": "BYTES",
                    "data": input,
                }
            ]
        }

        curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"

        if logging_obj:
            logging_obj.pre_call(
                input="",
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": curl_string,
                },
            )

        if aembedding:
            response = await self.aembedding(
                data=data_for_triton,
                model_response=model_response,
                logging_obj=logging_obj,
                api_base=api_base,
                api_key=api_key,
            )
            return response
        else:
            raise Exception(
                "Only async embedding is supported for triton; please use litellm.aembedding() for now"
            )
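
    # Illustrative call (endpoint and model names are placeholders). The sync
    # path is not implemented, so callers must set `aembedding=True` and await
    # the coroutine from an event loop:
    #
    #   resp = await TritonChatCompletion().embedding(
    #       model="my-embedding-model",
    #       input=["hello", "world"],
    #       timeout=600.0,
    #       api_base="http://localhost:8000/v2/models/my-embedding-model/infer",
    #       model_response=litellm.utils.EmbeddingResponse(),
    #       optional_params={},
    #       aembedding=True,
    #   )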

    def completion(
        self,
        model: str,
        messages: List[dict],
        timeout: float,
        api_base: str,
        model_response: ModelResponse,
        api_key: Optional[str] = None,
        logging_obj=None,
        optional_params=None,
        client=None,
        stream: bool = False,
    ) -> ModelResponse:
type_of_model = ""
|
|
if api_base.endswith("generate"): ### This is a trtllm model
|
|
text_input = messages[0]["content"]
|
|
data_for_triton: Dict[str, Any] = {
|
|
"text_input": str(text_input),
|
|
"parameters": {
|
|
"max_tokens": int(optional_params.get("max_tokens", 20)),
|
|
"bad_words": [""],
|
|
"stop_words": [""]
|
|
}
|
|
}
|
|
data_for_triton["parameters"].update( optional_params)
|
|
type_of_model = "trtllm"
|
|
|
|
elif api_base.endswith("infer"): ### This is an infer model with a custom model on triton
|
|
text_input = messages[0]["content"]
|
|
data_for_triton = {
|
|
"inputs": [{"name": "text_input", "shape": [1], "datatype": "BYTES", "data": [text_input]}]
|
|
}
|
|
|
|
for k, v in optional_params.items():
|
|
if not (k == "stream" or k == "max_retries"):
|
|
datatype = "INT32" if isinstance(v, int) else "BYTES"
|
|
datatype = "FP32" if isinstance(v, float) else datatype
|
|
data_for_triton['inputs'].append({"name": k, "shape": [1], "datatype": datatype, "data": [v]})
|
|
|
|
if "max_tokens" not in optional_params:
|
|
data_for_triton['inputs'].append({"name": "max_tokens", "shape": [1], "datatype": "INT32", "data": [20]})
|
|
|
|
type_of_model = "infer"
|
|
else: ## Unknown model type passthrough
|
|
data_for_triton = {
|
|
"inputs": [{"name": "text_input", "shape": [1], "datatype": "BYTES", "data": [messages[0]["content"]]}]
|
|
}
|
|
|
|

        if logging_obj:
            logging_obj.pre_call(
                input=messages,
                api_key=api_key,
                additional_args={
                    "complete_input_dict": optional_params,
                    "api_base": api_base,
                    "http_client": client,
                },
            )

        handler = requests.Session()

        # requests ignores a `timeout` attribute set on the Session; it must be
        # passed per request as a (connect, read) tuple. Use connect=5s,
        # read=600s to match the async handler above.
        response = handler.post(url=api_base, json=data_for_triton, timeout=(5.0, 600.0))

        if logging_obj:
            logging_obj.post_call(original_response=response)

        if response.status_code != 200:
            raise TritonError(status_code=response.status_code, message=response.text)

        _json_response = response.json()

        model_response.model = _json_response.get("model_name", "None")
        if type_of_model == "trtllm":
            model_response.choices = [
                Choices(index=0, message=Message(content=_json_response["text_output"]))
            ]
        elif type_of_model == "infer":
            model_response.choices = [
                Choices(index=0, message=Message(content=_json_response["outputs"][0]["data"]))
            ]
        else:
            model_response.choices = [
                Choices(index=0, message=Message(content=_json_response["outputs"]))
            ]

        return model_response
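

# Usage sketch for the completion path (illustrative; the endpoint, model name,
# and parameters are placeholders, not values prescribed by this module). The
# `generate` suffix routes the request through the trtllm branch above:
#
#   triton = TritonChatCompletion()
#   resp = triton.completion(
#       model="llama",
#       messages=[{"role": "user", "content": "Hello"}],
#       timeout=600.0,
#       api_base="http://localhost:8000/v2/models/llama/generate",
#       model_response=litellm.utils.ModelResponse(),
#       optional_params={"max_tokens": 64},
#   )
#   print(resp.choices[0].message.content)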