#Fixed mypy errors. The requests package and stubs need to be imported - waiting to hear from Ishaan/Krrish before changing requirements.txt
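For context, mypy resolves types for requests from a separate stub package, so both the library and its stubs matter here. A minimal sketch of the requirements.txt change the message above refers to, assuming the maintainers approve it (not part of this commit):

requests
types-requests  # stub-only package so mypy can type-check `import requests` without `# type: ignore`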

Giri Tatavarty 2024-05-29 15:08:56 -07:00
parent a58dc68418
commit 51b9178630
2 changed files with 60 additions and 123 deletions

View file

@@ -1,19 +1,20 @@
-import os, types
+import os
 import json
 from enum import Enum
-import requests, copy # type: ignore
+import requests
 import time
-from typing import Callable, Optional, List
+from typing import Callable, Optional, List, Sequence, Any, Union, Dict
-from litellm.utils import ModelResponse, Choices,Usage, map_finish_reason, CustomStreamWrapper, Message
+from litellm.utils import ModelResponse, Choices, Usage, map_finish_reason, CustomStreamWrapper, Message, EmbeddingResponse
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from .base import BaseLLM
-import httpx # type: ignore
-import requests
+import httpx
+from typing import Union,Collection
 class TritonError(Exception):
-    def __init__(self, status_code, message):
+    def __init__(self, status_code: int, message: str) -> None:
         self.status_code = status_code
         self.message = message
         self.request = httpx.Request(
@@ -25,54 +26,10 @@ class TritonError(Exception):
             self.message
         ) # Call the base class constructor with the parameters it needs
 class TritonChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
-    async def acompletion(
-        self,
-        data: dict,
-        model_response: ModelResponse,
-        api_base: str,
-        logging_obj=None,
-        api_key: Optional[str] = None,
-    ):
-        async_handler = httpx.AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
-        if api_base.endswith("generate") : ### This is a trtllm model
-            async with httpx.AsyncClient() as client:
-                response = await client.post(url=api_base, json=data)
-        if response.status_code != 200:
-            raise TritonError(status_code=response.status_code, message=response.text)
-        _text_response = response.text
-        if logging_obj:
-            logging_obj.post_call(original_response=_text_response)
-        _json_response = response.json()
-        _output_text = _json_response["outputs"][0]["data"][0]
-        # decode the byte string
-        _output_text = _output_text.encode("latin-1").decode("unicode_escape").encode(
-            "latin-1"
-        ).decode("utf-8")
-        model_response.model = _json_response.get("model_name", "None")
-        model_response.choices[0].message.content = _output_text
-        return model_response
     async def aembedding(
         self,
         data: dict,
@@ -80,8 +37,7 @@ class TritonChatCompletion(BaseLLM):
         api_base: str,
         logging_obj=None,
         api_key: Optional[str] = None,
-    ):
+    ) -> EmbeddingResponse:
         async_handler = AsyncHTTPHandler(
             timeout=httpx.Timeout(timeout=600.0, connect=5.0)
         )
@@ -98,7 +54,7 @@ class TritonChatCompletion(BaseLLM):
         _json_response = response.json()
         _outputs = _json_response["outputs"]
-        _output_data = [ output["data"] for output in _outputs ]
+        _output_data = [output["data"] for output in _outputs]
         _embedding_output = {
             "object": "embedding",
             "index": 0,
@@ -110,10 +66,10 @@ class TritonChatCompletion(BaseLLM):
         return model_response
-    def embedding(
+    async def embedding(
         self,
         model: str,
-        input: list,
+        input: List[str],
         timeout: float,
         api_base: str,
         model_response: litellm.utils.EmbeddingResponse,
@@ -121,21 +77,19 @@ class TritonChatCompletion(BaseLLM):
         logging_obj=None,
         optional_params=None,
         client=None,
-        aembedding=None,
-    ):
+        aembedding: bool = False,
+    ) -> EmbeddingResponse:
         data_for_triton = {
             "inputs": [
                 {
                     "name": "input_text",
-                    "shape": [len(input)], #size of the input data
+                    "shape": [len(input)],  # size of the input data
                     "datatype": "BYTES",
                     "data": input,
                 }
             ]
         }
-        ## LOGGING
         curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"
         logging_obj.pre_call(
@@ -147,8 +101,8 @@ class TritonChatCompletion(BaseLLM):
             },
         )
-        if aembedding == True:
-            response = self.aembedding(
+        if aembedding:
+            response = await self.aembedding(
                 data=data_for_triton,
                 model_response=model_response,
                 logging_obj=logging_obj,
@@ -160,11 +114,11 @@ class TritonChatCompletion(BaseLLM):
             raise Exception(
                 "Only async embedding supported for triton, please use litellm.aembedding() for now"
             )
-    ## Using Sync completion for now - Async completion not supported yet.
     def completion(
         self,
         model: str,
-        messages: list,
+        messages: List[dict],
         timeout: float,
         api_base: str,
         model_response: ModelResponse,
@@ -172,48 +126,44 @@ class TritonChatCompletion(BaseLLM):
         logging_obj=None,
         optional_params=None,
         client=None,
-        stream=False,
-    ):
+        stream: bool = False,
+    ) -> ModelResponse:
-        # check if model is llama
-        data_for_triton = {}
         type_of_model = ""
-        if api_base.endswith("generate") : ### This is a trtllm model
-            # this is a llama model
-            text_input = messages[0]["content"]
-            data_for_triton = {
-                "text_input":f"{text_input}",
-                "parameters": {
-                    "max_tokens": optional_params.get("max_tokens", 20),
-                    "bad_words":[""],
-                    "stop_words":[""]
-                }}
-            for k,v in optional_params.items():
-                data_for_triton["parameters"][k] = v
+        if api_base.endswith("generate"): ### This is a trtllm model
+            text_input = messages[0]["content"]
+            data_for_triton: Dict[str, Any] = {
+                "text_input": str(text_input),
+                "parameters": {
+                    "max_tokens": int(optional_params.get("max_tokens", 20)),
+                    "bad_words": [""],
+                    "stop_words": [""]
+                }
+            }
+            data_for_triton["parameters"].update( optional_params)
             type_of_model = "trtllm"
-        elif api_base.endswith("infer"): ### This is a infer model with a custom model on triton
-            # this is a custom model
-            text_input = messages[0]["content"]
-            data_for_triton = {
-                "inputs": [{"name": "text_input","shape": [1],"datatype": "BYTES","data": [text_input] }]
-            }
-            for k,v in optional_params.items():
-                if not (k=="stream" or k=="max_retries"): ## skip these as they are added by litellm
-                    datatype = "INT32" if type(v) == int else "BYTES"
-                    datatype = "FP32" if type(v) == float else datatype
-                    data_for_triton['inputs'].append({"name": k,"shape": [1],"datatype": datatype,"data": [v]})
-            # check for max_tokens which is required
-            if "max_tokens" not in optional_params:
-                data_for_triton['inputs'].append({"name": "max_tokens","shape": [1],"datatype": "INT32","data": [20]})
-            type_of_model = "infer"
-        else: ## Unknown model type passthrough
-            data_for_triton = {
-                messages[0]["content"]
-            }
+        elif api_base.endswith("infer"): ### This is an infer model with a custom model on triton
+            text_input = messages[0]["content"]
+            data_for_triton = {
+                "inputs": [{"name": "text_input", "shape": [1], "datatype": "BYTES", "data": [text_input]}]
+            }
+            for k, v in optional_params.items():
+                if not (k == "stream" or k == "max_retries"):
+                    datatype = "INT32" if isinstance(v, int) else "BYTES"
+                    datatype = "FP32" if isinstance(v, float) else datatype
+                    data_for_triton['inputs'].append({"name": k, "shape": [1], "datatype": datatype, "data": [v]})
+            if "max_tokens" not in optional_params:
+                data_for_triton['inputs'].append({"name": "max_tokens", "shape": [1], "datatype": "INT32", "data": [20]})
+            type_of_model = "infer"
+        else: ## Unknown model type passthrough
+            data_for_triton = {
+                "inputs": [{"name": "text_input", "shape": [1], "datatype": "BYTES", "data": [messages[0]["content"]]}]
+            }
         if logging_obj:
             logging_obj.pre_call(
                 input=messages,
@@ -226,35 +176,22 @@ class TritonChatCompletion(BaseLLM):
             )
         handler = requests.Session()
         handler.timeout = (600.0, 5.0)
         response = handler.post(url=api_base, json=data_for_triton)
         if logging_obj:
             logging_obj.post_call(original_response=response)
         if response.status_code != 200:
             raise TritonError(status_code=response.status_code, message=response.text)
-        _json_response=response.json()
+        _json_response = response.json()
         model_response.model = _json_response.get("model_name", "None")
         if type_of_model == "trtllm":
-            # The actual response is part of the text_output key in the response
-            model_response['choices'] = [ Choices(index=0, message= Message(content=_json_response['text_output']))]
+            model_response.choices = [Choices(index=0, message=Message(content=_json_response['text_output']))]
         elif type_of_model == "infer":
-            # The actual response is part of the outputs key in the response
-            model_response['choices'] = [ Choices(index=0, message= Message(content=_json_response['outputs'][0]['data']))]
+            model_response.choices = [Choices(index=0, message=Message(content=_json_response['outputs'][0]['data']))]
         else:
-            ## just passthrough the response
-            model_response['choices'] = [ Choices(index=0, message= Message(content=_json_response['outputs']))]
-        return model_response
-        """
-        response = self.acompletion(
-            data=data_for_triton,
-            model_response=model_response,
-            logging_obj=logging_obj,
-            api_base=api_base,
-            api_key=api_key,
-        )
-        """
+            model_response.choices = [Choices(index=0, message=Message(content=_json_response['outputs']))]
+        return model_response
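Since the embedding path above raises unless it is awaited, here is a hedged usage sketch of the async entry point; the model name, the "triton/" provider prefix, and the endpoint URL are placeholders for illustration, not taken from this commit:

import asyncio
import litellm

async def main() -> None:
    # Triton embeddings are async-only in this integration, so go through litellm.aembedding().
    response = await litellm.aembedding(
        model="triton/my-embedding-model",  # assumed "triton/" provider prefix
        input=["hello world", "good morning"],
        api_base="http://localhost:8000/v2/models/my-embedding-model/infer",  # assumed Triton endpoint
    )
    print(response)

asyncio.run(main())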

View file

@@ -2261,7 +2261,7 @@ def completion(
             )
             model_response = triton_chat_completions.completion(
                 api_base=api_base,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
                 model=model,
                 messages=messages,
                 model_response=model_response,
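A hedged sketch of how the synchronous completion path wired up above might be exercised; the "triton/" prefix, model name, and URL are assumptions for illustration, not taken from this commit:

import litellm

# An api_base ending in "generate" routes through the trtllm branch of completion() above.
response = litellm.completion(
    model="triton/llama-trtllm",  # assumed "triton/" provider prefix
    messages=[{"role": "user", "content": "What is Triton Inference Server?"}],
    api_base="http://localhost:8000/v2/models/llama/generate",  # assumed Triton generate endpoint
    max_tokens=64,
)
print(response.choices[0].message.content)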