use common caching logic for openai/azure clients

Ishaan Jaff 2025-03-18 17:57:03 -07:00
parent f73e9047dc
commit a45830dac3

@@ -33,7 +33,6 @@ from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
 from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
 from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
-from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
 from litellm.types.utils import (
     EmbeddingResponse,
     ImageResponse,
@@ -348,6 +347,7 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM):
         organization: Optional[str] = None,
         client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
     ):
+        client_initialization_params: Dict = locals()
         if client is None:
             if not isinstance(max_retries, int):
                 raise OpenAIError(
@@ -356,20 +356,12 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM):
                         max_retries
                     ),
                 )
-            # Creating a new OpenAI Client
-            # check in memory cache before creating a new one
-            # Convert the API key to bytes
-            hashed_api_key = None
-            if api_key is not None:
-                hash_object = hashlib.sha256(api_key.encode())
-                # Hexadecimal representation of the hash
-                hashed_api_key = hash_object.hexdigest()
-            _cache_key = f"hashed_api_key={hashed_api_key},api_base={api_base},timeout={timeout},max_retries={max_retries},organization={organization},is_async={is_async}"
-            _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key)
-            if _cached_client:
-                return _cached_client
+            cached_client = self.get_cached_openai_client(
+                client_initialization_params=client_initialization_params,
+                client_type="openai",
+            )
+            if cached_client:
+                return cached_client
             if is_async:
                 _new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
                     api_key=api_key,
@@ -390,10 +382,10 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM):
                 )
             ## SAVE CACHE KEY
-            litellm.in_memory_llm_clients_cache.set_cache(
-                key=_cache_key,
-                value=_new_client,
-                ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
-            )
+            self.set_cached_openai_client(
+                openai_client=_new_client,
+                client_initialization_params=client_initialization_params,
+                client_type="openai",
+            )
             return _new_client
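
The shared helpers this commit switches to (get_cached_openai_client / set_cached_openai_client, presumably provided by BaseOpenAILLM) are not part of this diff. A minimal sketch of the idea, assuming the helpers reduce the captured locals() to a deterministic, hashed cache key per client type; the module-level _CLIENT_CACHE and the exact parameter filtering below are illustrative stand-ins, not litellm's implementation:

import hashlib
import json
from typing import Any, Dict, Optional

# Illustrative stand-in for litellm.in_memory_llm_clients_cache (assumption:
# the real cache also applies a TTL, as the removed set_cache call did).
_CLIENT_CACHE: Dict[str, Any] = {}


def _client_cache_key(client_initialization_params: Dict, client_type: str) -> str:
    # Keep only simple identity fields (drops `self`, the passed-in `client`,
    # and anything that is not trivially serializable).
    identity = {
        k: v
        for k, v in client_initialization_params.items()
        if k not in ("self", "client")
        and isinstance(v, (str, int, float, bool, type(None)))
    }
    serialized = json.dumps(identity, sort_keys=True, default=str)
    return f"{client_type}:{hashlib.sha256(serialized.encode()).hexdigest()}"


def get_cached_openai_client(
    client_initialization_params: Dict, client_type: str
) -> Optional[Any]:
    # Return a previously constructed client for identical init params, if any.
    return _CLIENT_CACHE.get(
        _client_cache_key(client_initialization_params, client_type)
    )


def set_cached_openai_client(
    openai_client: Any, client_initialization_params: Dict, client_type: str
) -> None:
    # Store a freshly constructed client under the hashed init params.
    _CLIENT_CACHE[
        _client_cache_key(client_initialization_params, client_type)
    ] = openai_client

Compared with the removed inline logic, which hashed only the API key and string-formatted a handful of parameters into the key, keying off the full set of initialization parameters avoids returning a cached client when any constructor argument differs, and the same helpers can be reused by the Azure client path named in the commit title.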