mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
use common logic for re-using openai clients
This commit is contained in:
parent
bdf77f6f4b
commit
f73e9047dc
2 changed files with 95 additions and 1 deletions
|
@ -2,14 +2,17 @@
|
|||
Common helpers / utils across al OpenAI endpoints
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
|
||||
|
||||
|
||||
class OpenAIError(BaseLLMException):
|
||||
|
@ -100,6 +103,81 @@ class BaseOpenAILLM:
|
|||
Base class for OpenAI LLMs for getting their httpx clients and SSL verification settings
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_cached_openai_client(
|
||||
client_initialization_params: dict, client_type: Literal["openai", "azure"]
|
||||
) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]]:
|
||||
"""Retrieves the OpenAI client from the in-memory cache based on the client initialization parameters"""
|
||||
_cache_key = BaseOpenAILLM.get_openai_client_cache_key(
|
||||
client_initialization_params=client_initialization_params,
|
||||
client_type=client_type,
|
||||
)
|
||||
_cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key)
|
||||
return _cached_client
|
||||
|
||||
@staticmethod
|
||||
def set_cached_openai_client(
|
||||
openai_client: Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI],
|
||||
client_type: Literal["openai", "azure"],
|
||||
client_initialization_params: dict,
|
||||
):
|
||||
"""Stores the OpenAI client in the in-memory cache for _DEFAULT_TTL_FOR_HTTPX_CLIENTS SECONDS"""
|
||||
_cache_key = BaseOpenAILLM.get_openai_client_cache_key(
|
||||
client_initialization_params=client_initialization_params,
|
||||
client_type=client_type,
|
||||
)
|
||||
litellm.in_memory_llm_clients_cache.set_cache(
|
||||
key=_cache_key,
|
||||
value=openai_client,
|
||||
ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_openai_client_cache_key(
|
||||
client_initialization_params: dict, client_type: Literal["openai", "azure"]
|
||||
) -> str:
|
||||
"""Creates a cache key for the OpenAI client based on the client initialization parameters"""
|
||||
hashed_api_key = None
|
||||
if client_initialization_params.get("api_key") is not None:
|
||||
hash_object = hashlib.sha256(
|
||||
client_initialization_params.get("api_key", "").encode()
|
||||
)
|
||||
# Hexadecimal representation of the hash
|
||||
hashed_api_key = hash_object.hexdigest()
|
||||
|
||||
# Create a more readable cache key using a list of key-value pairs
|
||||
key_parts = [
|
||||
f"hashed_api_key={hashed_api_key}",
|
||||
f"is_async={client_initialization_params.get('is_async')}",
|
||||
]
|
||||
|
||||
for param in BaseOpenAILLM.get_openai_client_initialization_param_fields(
|
||||
client_type=client_type
|
||||
):
|
||||
key_parts.append(f"{param}={client_initialization_params.get(param)}")
|
||||
|
||||
_cache_key = ",".join(key_parts)
|
||||
|
||||
return _cache_key
|
||||
|
||||
@staticmethod
|
||||
def get_openai_client_initialization_param_fields(
|
||||
client_type: Literal["openai", "azure"]
|
||||
) -> list[str]:
|
||||
"""Returns a list of fields that are used to initialize the OpenAI client"""
|
||||
import inspect
|
||||
|
||||
from openai import AzureOpenAI, OpenAI
|
||||
|
||||
if client_type == "openai":
|
||||
signature = inspect.signature(OpenAI.__init__)
|
||||
else:
|
||||
signature = inspect.signature(AzureOpenAI.__init__)
|
||||
|
||||
# Extract parameter names, excluding 'self'
|
||||
param_names = [param for param in signature.parameters if param != "self"]
|
||||
return param_names
|
||||
|
||||
@staticmethod
|
||||
def _get_async_http_client() -> Optional[httpx.AsyncClient]:
|
||||
if litellm.aclient_session is not None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue