forked from phoenix/litellm-mirror
Fixed mypy errors. The requests package and its type stubs need to be imported; waiting to hear from Ishaan/Krrish before changing requirements.txt
parent a58dc68418
commit 51b9178630
2 changed files with 60 additions and 123 deletions
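The diff below is mostly mechanical mypy cleanup: annotated signatures, explicit return types, and isinstance() checks instead of type() comparisons, plus removal of an unused async path. A minimal sketch of the pattern this commit moves toward (class and function names here are illustrative, not from the diff):

from typing import Any

class ExampleError(Exception):
    # Annotated parameters and an explicit None return type keep mypy happy.
    def __init__(self, status_code: int, message: str) -> None:
        self.status_code = status_code
        self.message = message

def pick_datatype(v: Any) -> str:
    # isinstance() narrows types in a way mypy understands; type(v) == int does not.
    datatype = "INT32" if isinstance(v, int) else "BYTES"
    datatype = "FP32" if isinstance(v, float) else datatype
    return datatype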
@@ -1,19 +1,20 @@
-import os, types
+import os
 import json
 from enum import Enum
-import requests, copy # type: ignore
+import requests
 import time
-from typing import Callable, Optional, List
-from litellm.utils import ModelResponse, Choices, Usage, map_finish_reason, CustomStreamWrapper, Message
+from typing import Callable, Optional, List, Sequence, Any, Union, Dict
+from litellm.utils import ModelResponse, Choices, Usage, map_finish_reason, CustomStreamWrapper, Message, EmbeddingResponse
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from .base import BaseLLM
-import httpx # type: ignore
+import requests
+import httpx
+from typing import Union, Collection


 class TritonError(Exception):
-    def __init__(self, status_code, message):
+    def __init__(self, status_code: int, message: str) -> None:
         self.status_code = status_code
         self.message = message
         self.request = httpx.Request(
@@ -25,54 +26,10 @@ class TritonError(Exception):
            self.message
        )  # Call the base class constructor with the parameters it needs


class TritonChatCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

-    async def acompletion(
-        self,
-        data: dict,
-        model_response: ModelResponse,
-        api_base: str,
-        logging_obj=None,
-        api_key: Optional[str] = None,
-    ):
-
-        async_handler = httpx.AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
-
-        if api_base.endswith("generate") : ### This is a trtllm model
-
-            async with httpx.AsyncClient() as client:
-                response = await client.post(url=api_base, json=data)
-
-            if response.status_code != 200:
-                raise TritonError(status_code=response.status_code, message=response.text)
-
-            _text_response = response.text
-
-            if logging_obj:
-                logging_obj.post_call(original_response=_text_response)
-
-            _json_response = response.json()
-
-            _output_text = _json_response["outputs"][0]["data"][0]
-            # decode the byte string
-            _output_text = _output_text.encode("latin-1").decode("unicode_escape").encode(
-                "latin-1"
-            ).decode("utf-8")
-
-            model_response.model = _json_response.get("model_name", "None")
-            model_response.choices[0].message.content = _output_text
-
-        return model_response
-
    async def aembedding(
        self,
        data: dict,
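A note on the byte-string decode chain in the removed acompletion body: it takes text in which UTF-8 bytes leaked through as literal \xNN escape sequences and turns it back into real text. A small illustration of that round-trip (the input string is made up):

raw = "caf\\xc3\\xa9"                   # UTF-8 bytes for "é" appear as literal \xNN characters
step1 = raw.encode("latin-1")           # back to bytes, one byte per character
step2 = step1.decode("unicode_escape")  # resolves the \xNN escapes into single code points
text = step2.encode("latin-1").decode("utf-8")  # reinterpret those bytes as UTF-8
print(text)                             # "café"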
@@ -80,8 +37,7 @@ class TritonChatCompletion(BaseLLM):
        api_base: str,
        logging_obj=None,
        api_key: Optional[str] = None,
-    ):
-
+    ) -> EmbeddingResponse:
        async_handler = AsyncHTTPHandler(
            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
        )
@@ -110,10 +66,10 @@ class TritonChatCompletion(BaseLLM):

        return model_response

-    def embedding(
+    async def embedding(
        self,
        model: str,
-        input: list,
+        input: List[str],
        timeout: float,
        api_base: str,
        model_response: litellm.utils.EmbeddingResponse,
@@ -121,8 +77,8 @@ class TritonChatCompletion(BaseLLM):
        logging_obj=None,
        optional_params=None,
        client=None,
-        aembedding=None,
-    ):
+        aembedding: bool = False,
+    ) -> EmbeddingResponse:
        data_for_triton = {
            "inputs": [
                {
@@ -134,8 +90,6 @@ class TritonChatCompletion(BaseLLM):
            ]
        }

        ## LOGGING
-
        curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"
-
        logging_obj.pre_call(
@@ -147,8 +101,8 @@ class TritonChatCompletion(BaseLLM):
            },
        )

-        if aembedding == True:
-            response = self.aembedding(
+        if aembedding:
+            response = await self.aembedding(
                data=data_for_triton,
                model_response=model_response,
                logging_obj=logging_obj,
@@ -160,11 +114,11 @@ class TritonChatCompletion(BaseLLM):
            raise Exception(
                "Only async embedding supported for triton, please use litellm.aembedding() for now"
            )
    ## Using Sync completion for now - Async completion not supported yet.

    def completion(
        self,
        model: str,
-        messages: list,
+        messages: List[dict],
        timeout: float,
        api_base: str,
        model_response: ModelResponse,
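Since the synchronous path still raises "Only async embedding supported for triton", the way to reach the embedding code is litellm.aembedding(). A hedged usage sketch; the "triton/" model prefix, URL, and model name are placeholders and assumptions, not taken from this diff:

import asyncio
import litellm

async def main() -> None:
    # Assumes the triton embedding provider is selected via a "triton/" model prefix;
    # adjust model and api_base to your own deployment.
    resp = await litellm.aembedding(
        model="triton/my-embedding-model",
        api_base="http://localhost:8000/v2/models/my-embedding-model/infer",
        input=["hello, world"],
    )
    print(resp)

asyncio.run(main())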
@@ -172,46 +126,42 @@ class TritonChatCompletion(BaseLLM):
        logging_obj=None,
        optional_params=None,
        client=None,
-        stream=False,
-    ):
-        # check if model is llama
-        data_for_triton = {}
-        type_of_model = "" ""
+        stream: bool = False,
+    ) -> ModelResponse:
+
+        type_of_model = ""
        if api_base.endswith("generate"): ### This is a trtllm model
-            # this is a llama model
            text_input = messages[0]["content"]
-            data_for_triton = {
-                "text_input":f"{text_input}",
+            data_for_triton: Dict[str, Any] = {
+                "text_input": str(text_input),
                "parameters": {
-                    "max_tokens": optional_params.get("max_tokens", 20),
+                    "max_tokens": int(optional_params.get("max_tokens", 20)),
                    "bad_words": [""],
                    "stop_words": [""]
-            }}
-            for k,v in optional_params.items():
-                data_for_triton["parameters"][k] = v
+                }
+            }
+            data_for_triton["parameters"].update( optional_params)
            type_of_model = "trtllm"

-        elif api_base.endswith("infer"): ### This is a infer model with a custom model on triton
-            # this is a custom model
+        elif api_base.endswith("infer"): ### This is an infer model with a custom model on triton
            text_input = messages[0]["content"]
            data_for_triton = {
                "inputs": [{"name": "text_input", "shape": [1], "datatype": "BYTES", "data": [text_input]}]
            }

            for k, v in optional_params.items():
-                if not (k=="stream" or k=="max_retries"): ## skip these as they are added by litellm
-                    datatype = "INT32" if type(v) == int else "BYTES"
-                    datatype = "FP32" if type(v) == float else datatype
+                if not (k == "stream" or k == "max_retries"):
+                    datatype = "INT32" if isinstance(v, int) else "BYTES"
+                    datatype = "FP32" if isinstance(v, float) else datatype
                    data_for_triton['inputs'].append({"name": k, "shape": [1], "datatype": datatype, "data": [v]})

            # check for max_tokens which is required
            if "max_tokens" not in optional_params:
                data_for_triton['inputs'].append({"name": "max_tokens", "shape": [1], "datatype": "INT32", "data": [20]})

            type_of_model = "infer"
        else: ## Unknown model type passthrough
            data_for_triton = {
-                messages[0]["content"]
+                "inputs": [{"name": "text_input", "shape": [1], "datatype": "BYTES", "data": [messages[0]["content"]]}]
            }

        if logging_obj:
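For reference, the two request bodies built in this branch differ in shape: the TRT-LLM generate endpoint takes a flat text_input/parameters object, while the KServe-style infer endpoint takes a list of named tensors. Illustrative payloads matching the construction above (prompt text and values are made-up examples):

# TRT-LLM "generate" endpoint payload, as built in the trtllm branch above
generate_payload = {
    "text_input": "What is the capital of France?",
    "parameters": {"max_tokens": 20, "bad_words": [""], "stop_words": [""]},
}

# KServe v2 "infer" endpoint payload, as built in the infer branch above
infer_payload = {
    "inputs": [
        {"name": "text_input", "shape": [1], "datatype": "BYTES",
         "data": ["What is the capital of France?"]},
        {"name": "max_tokens", "shape": [1], "datatype": "INT32", "data": [20]},
    ]
}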
@@ -229,7 +179,6 @@ class TritonChatCompletion(BaseLLM):

        response = handler.post(url=api_base, json=data_for_triton)
-
        if logging_obj:
            logging_obj.post_call(original_response=response)

@@ -239,22 +188,10 @@ class TritonChatCompletion(BaseLLM):

        model_response.model = _json_response.get("model_name", "None")
        if type_of_model == "trtllm":
            # The actual response is part of the text_output key in the response
-            model_response['choices'] = [ Choices(index=0, message= Message(content=_json_response['text_output']))]
+            model_response.choices = [Choices(index=0, message=Message(content=_json_response['text_output']))]
        elif type_of_model == "infer":
            # The actual response is part of the outputs key in the response
-            model_response['choices'] = [ Choices(index=0, message= Message(content=_json_response['outputs'][0]['data']))]
+            model_response.choices = [Choices(index=0, message=Message(content=_json_response['outputs'][0]['data']))]
        else:
            ## just passthrough the response
-            model_response['choices'] = [ Choices(index=0, message= Message(content=_json_response['outputs']))]
+            model_response.choices = [Choices(index=0, message=Message(content=_json_response['outputs']))]
-
-        """
-        response = self.acompletion(
-            data=data_for_triton,
-            model_response=model_response,
-            logging_obj=logging_obj,
-            api_base=api_base,
-            api_key=api_key,
-        )
-        """
        return model_response
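The choices parsing above keys off the endpoint type: generate responses carry text_output at the top level, while infer responses follow the KServe v2 shape with an outputs list. A sketch of the two response bodies being parsed (field values are illustrative, model names assumed):

# "generate" style response -> _json_response["text_output"]
generate_response = {"model_name": "ensemble", "text_output": "Paris is the capital of France."}

# "infer" (KServe v2) style response -> _json_response["outputs"][0]["data"]
infer_response = {
    "model_name": "my_custom_model",
    "outputs": [
        {"name": "text_output", "datatype": "BYTES", "shape": [1],
         "data": ["Paris is the capital of France."]},
    ],
}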
@@ -2261,7 +2261,7 @@ def completion(
            )
            model_response = triton_chat_completions.completion(
                api_base=api_base,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
                model=model,
                messages=messages,
                model_response=model_response,
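The second hunk is the caller side, where litellm's provider dispatch hands off to triton_chat_completions.completion. A hedged end-to-end sketch of invoking that path; the "triton/" model prefix, model name, and URL are assumptions, not taken from this diff:

import litellm

# Assumes the triton provider is selected via a "triton/" model prefix and a
# TRT-LLM "generate" endpoint; adjust to your own Triton deployment.
response = litellm.completion(
    model="triton/llama-2-7b",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    api_base="http://localhost:8000/v2/models/ensemble/generate",
    max_tokens=20,
)
print(response.choices[0].message.content)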