mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 10:44:24 +00:00
"""
|
|
Common helpers / utils across al OpenAI endpoints
|
|
"""
import hashlib
import json
from typing import Any, Dict, List, Literal, Optional, Union

import httpx
import openai
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI

import litellm
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS

class OpenAIError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
        headers: Optional[Union[dict, httpx.Headers]] = None,
        body: Optional[dict] = None,
    ):
        self.status_code = status_code
        self.message = message
        self.headers = headers
        # Fall back to a synthetic request/response so downstream error
        # handlers always have both objects to inspect.
        if request:
            self.request = request
        else:
            self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
        if response:
            self.response = response
        else:
            self.response = httpx.Response(
                status_code=status_code, request=self.request
            )
        super().__init__(
            status_code=status_code,
            message=self.message,
            headers=self.headers,
            request=self.request,
            response=self.response,
            body=body,
        )
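
# Illustrative sketch (not part of the upstream module): wrapping a raw httpx
# failure in an OpenAIError. The helper name is hypothetical; it shows how the
# optional request/response/headers arguments are meant to be threaded through.
def _example_wrap_httpx_error(err: httpx.HTTPStatusError) -> OpenAIError:
    # Reuse the failed request/response so downstream handlers can inspect them.
    return OpenAIError(
        status_code=err.response.status_code,
        message=str(err),
        request=err.request,
        response=err.response,
        headers=err.response.headers,
    )
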
####### Error Handling Utils for OpenAI API #######################
###################################################################
def drop_params_from_unprocessable_entity_error(
    e: Union[openai.UnprocessableEntityError, httpx.HTTPStatusError],
    data: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Helper function to read an OpenAI UnprocessableEntityError and drop the params that raised the error from the request data.

    Args:
        e (UnprocessableEntityError): The UnprocessableEntityError exception
        data (Dict[str, Any]): The original data dictionary containing all parameters

    Returns:
        Dict[str, Any]: A new dictionary with invalid parameters removed
    """
    invalid_params: List[str] = []
    if isinstance(e, httpx.HTTPStatusError):
        error_json = e.response.json()
        error_body = error_json.get("error", {})
    else:
        error_body = e.body
    if (
        error_body is not None
        and isinstance(error_body, dict)
        and error_body.get("message")
    ):
        message = error_body.get("message", {})
        if isinstance(message, str):
            # The message may itself be a JSON-encoded validation payload.
            try:
                message = json.loads(message)
            except json.JSONDecodeError:
                message = {"detail": message}
        detail = message.get("detail")

        # Each entry in `detail` is expected to look like
        # {"loc": ["body", "<param_name>"], ...}; the second element of "loc"
        # names the offending parameter.
        if isinstance(detail, List) and len(detail) > 0 and isinstance(detail[0], dict):
            for error_dict in detail:
                if (
                    error_dict.get("loc")
                    and isinstance(error_dict.get("loc"), list)
                    and len(error_dict.get("loc")) == 2
                ):
                    invalid_params.append(error_dict["loc"][1])

    new_data = {k: v for k, v in data.items() if k not in invalid_params}

    return new_data
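
# Illustrative sketch (hypothetical `send_fn`): the retry pattern this helper
# supports. On a 422, the offending params reported in the error body are
# stripped from the payload and the request is retried once without them.
def _example_retry_without_invalid_params(send_fn, data: Dict[str, Any]) -> Any:
    try:
        return send_fn(**data)
    except openai.UnprocessableEntityError as e:
        cleaned_data = drop_params_from_unprocessable_entity_error(e=e, data=data)
        return send_fn(**cleaned_data)
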
class BaseOpenAILLM:
    """
    Base class for OpenAI LLMs for getting their httpx clients and SSL verification settings
    """

    @staticmethod
    def get_cached_openai_client(
        client_initialization_params: dict, client_type: Literal["openai", "azure"]
    ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]]:
        """Retrieves the OpenAI client from the in-memory cache based on the client initialization parameters"""
        _cache_key = BaseOpenAILLM.get_openai_client_cache_key(
            client_initialization_params=client_initialization_params,
            client_type=client_type,
        )
        _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key)
        return _cached_client

    @staticmethod
    def set_cached_openai_client(
        openai_client: Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI],
        client_type: Literal["openai", "azure"],
        client_initialization_params: dict,
    ):
        """Stores the OpenAI client in the in-memory cache for _DEFAULT_TTL_FOR_HTTPX_CLIENTS seconds"""
        _cache_key = BaseOpenAILLM.get_openai_client_cache_key(
            client_initialization_params=client_initialization_params,
            client_type=client_type,
        )
        litellm.in_memory_llm_clients_cache.set_cache(
            key=_cache_key,
            value=openai_client,
            ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
        )

    @staticmethod
    def get_openai_client_cache_key(
        client_initialization_params: dict, client_type: Literal["openai", "azure"]
    ) -> str:
        """Creates a cache key for the OpenAI client based on the client initialization parameters"""
        hashed_api_key = None
        if client_initialization_params.get("api_key") is not None:
            hash_object = hashlib.sha256(
                client_initialization_params.get("api_key", "").encode()
            )
            # Hexadecimal representation of the hash
            hashed_api_key = hash_object.hexdigest()

        # Create a more readable cache key using a list of key-value pairs
        key_parts = [
            f"hashed_api_key={hashed_api_key}",
            f"is_async={client_initialization_params.get('is_async')}",
        ]

        for param in BaseOpenAILLM.get_openai_client_initialization_param_fields(
            client_type=client_type
        ):
            key_parts.append(f"{param}={client_initialization_params.get(param)}")

        _cache_key = ",".join(key_parts)

        return _cache_key
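    # Example key shape (illustrative values only; the exact fields depend on
    # the installed openai SDK version and the params supplied):
    #   "hashed_api_key=<sha256 hex>,is_async=True,organization=None,timeout=600.0,..."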
    @staticmethod
    def get_openai_client_initialization_param_fields(
        client_type: Literal["openai", "azure"]
    ) -> list[str]:
        """Returns a list of fields that are used to initialize the OpenAI client"""
        import inspect

        from openai import AzureOpenAI, OpenAI

        if client_type == "openai":
            signature = inspect.signature(OpenAI.__init__)
        else:
            signature = inspect.signature(AzureOpenAI.__init__)

        # Extract parameter names, excluding 'self'
        param_names = [param for param in signature.parameters if param != "self"]
        return param_names
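    # Note: the field list is introspected from the installed SDK, so it varies
    # by openai version; typical entries include api_key, organization,
    # base_url, timeout, max_retries, default_headers, and http_client.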
    @staticmethod
    def _get_async_http_client() -> Optional[httpx.AsyncClient]:
        # Prefer a user-supplied session; otherwise build a pooled client that
        # honors litellm.ssl_verify.
        if litellm.aclient_session is not None:
            return litellm.aclient_session

        return httpx.AsyncClient(
            limits=httpx.Limits(max_connections=1000, max_keepalive_connections=100),
            verify=litellm.ssl_verify,
        )

    @staticmethod
    def _get_sync_http_client() -> Optional[httpx.Client]:
        if litellm.client_session is not None:
            return litellm.client_session
        return httpx.Client(
            limits=httpx.Limits(max_connections=1000, max_keepalive_connections=100),
            verify=litellm.ssl_verify,
        )
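
# Illustrative sketch (hypothetical flow, not part of the upstream module):
# the intended get-or-create round trip through the client cache.
def _example_get_or_create_openai_client(
    params: dict,
) -> Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]:
    cached = BaseOpenAILLM.get_cached_openai_client(
        client_initialization_params=params, client_type="openai"
    )
    if cached is not None:
        return cached
    # `is_async` only discriminates the cache key; it is not an OpenAI() kwarg.
    client = OpenAI(**{k: v for k, v in params.items() if k != "is_async"})
    BaseOpenAILLM.set_cached_openai_client(
        openai_client=client,
        client_type="openai",
        client_initialization_params=params,
    )
    return client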