diff --git a/litellm/llms/azure/common_utils.py b/litellm/llms/azure/common_utils.py index 34cca8fc8a..4d9c35a5fb 100644 --- a/litellm/llms/azure/common_utils.py +++ b/litellm/llms/azure/common_utils.py @@ -257,7 +257,17 @@ class BaseAzureLLM(BaseOpenAILLM): model: Optional[str] = None, ) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]: openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None + client_initialization_params: dict = locals() if client is None: + cached_client = self.get_cached_openai_client( + client_initialization_params=client_initialization_params, + client_type="azure", + ) + if cached_client and isinstance( + cached_client, (AzureOpenAI, AsyncAzureOpenAI) + ): + return cached_client + azure_client_params = self.initialize_azure_sdk_client( litellm_params=litellm_params or {}, api_key=api_key, @@ -278,6 +288,12 @@ class BaseAzureLLM(BaseOpenAILLM): # set api_version to version passed by user openai_client._custom_query.setdefault("api-version", api_version) + # save client in-memory cache + self.set_cached_openai_client( + openai_client=openai_client, + client_initialization_params=client_initialization_params, + client_type="azure", + ) return openai_client def initialize_azure_sdk_client( diff --git a/litellm/llms/openai/common_utils.py b/litellm/llms/openai/common_utils.py index 649ce2e0f1..ac84fbacf1 100644 --- a/litellm/llms/openai/common_utils.py +++ b/litellm/llms/openai/common_utils.py @@ -2,14 +2,17 @@ Common helpers / utils across al OpenAI endpoints """ +import hashlib import json -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union import httpx import openai +from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI import litellm from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS class OpenAIError(BaseLLMException): @@ -100,6 +103,81 @@ class 
# Context (class ``BaseOpenAILLM``; its header lies outside this excerpt):
#     Base class for OpenAI LLMs for getting their httpx clients and SSL verification settings
#
# The four static methods below are members of that class, which is why they
# call each other through the ``BaseOpenAILLM.`` qualifier.
@staticmethod
def get_cached_openai_client(
    client_initialization_params: dict, client_type: Literal["openai", "azure"]
) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]]:
    """Look up a previously stored client in the in-memory client cache.

    Args:
        client_initialization_params: Mapping of the arguments the client
            would be built from; used only to derive the cache key.
        client_type: ``"openai"`` or ``"azure"``; selects which constructor
            signature contributes fields to the cache key.

    Returns:
        Whatever ``litellm.in_memory_llm_clients_cache.get_cache`` yields for
        the derived key (presumably ``None`` on a miss -- confirm against the
        cache implementation).
    """
    _cache_key = BaseOpenAILLM.get_openai_client_cache_key(
        client_initialization_params=client_initialization_params,
        client_type=client_type,
    )
    return litellm.in_memory_llm_clients_cache.get_cache(_cache_key)


@staticmethod
def set_cached_openai_client(
    openai_client: Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI],
    client_type: Literal["openai", "azure"],
    client_initialization_params: dict,
):
    """Store ``openai_client`` in the in-memory client cache.

    The entry is keyed by ``get_openai_client_cache_key`` and written with a
    TTL of ``_DEFAULT_TTL_FOR_HTTPX_CLIENTS`` seconds.

    Args:
        openai_client: The client instance to cache.
        client_type: ``"openai"`` or ``"azure"``; must match the value later
            passed to ``get_cached_openai_client`` for the lookup to hit.
        client_initialization_params: Mapping of the arguments the client was
            built from; used only to derive the cache key.
    """
    _cache_key = BaseOpenAILLM.get_openai_client_cache_key(
        client_initialization_params=client_initialization_params,
        client_type=client_type,
    )
    litellm.in_memory_llm_clients_cache.set_cache(
        key=_cache_key,
        value=openai_client,
        ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
    )


@staticmethod
def get_openai_client_cache_key(
    client_initialization_params: dict, client_type: Literal["openai", "azure"]
) -> str:
    """Build the cache key for a client from its initialization parameters.

    The key is a comma-joined list of ``name=value`` pairs: the SHA-256 hex
    digest of the API key, the ``is_async`` flag, and one pair for every
    constructor parameter of the selected client class.

    The raw ``api_key`` is deliberately left out of the per-parameter pairs:
    it is already represented by ``hashed_api_key``, and repeating it in
    plain text would put the secret into the cache key.

    Args:
        client_initialization_params: Mapping of parameter name to value.
            Missing names are rendered as ``None``.
        client_type: ``"openai"`` or ``"azure"``.

    Returns:
        The cache key string.
    """
    api_key = client_initialization_params.get("api_key")
    hashed_api_key = None
    if api_key is not None:
        # str() is a no-op for the usual string key and keeps a non-string
        # value from raising on .encode().
        hashed_api_key = hashlib.sha256(str(api_key).encode()).hexdigest()

    # Human-readable key assembled from key=value pairs.
    key_parts = [
        f"hashed_api_key={hashed_api_key}",
        f"is_async={client_initialization_params.get('is_async')}",
    ]
    key_parts.extend(
        f"{param}={client_initialization_params.get(param)}"
        for param in BaseOpenAILLM.get_openai_client_initialization_param_fields(
            client_type=client_type
        )
        if param != "api_key"  # secret: only its hash may appear in the key
    )

    return ",".join(key_parts)


@staticmethod
def get_openai_client_initialization_param_fields(
    client_type: Literal["openai", "azure"]
) -> List[str]:
    """Return the constructor parameter names of the selected client class.

    Args:
        client_type: ``"openai"`` inspects ``OpenAI.__init__``; any other
            value inspects ``AzureOpenAI.__init__``.

    Returns:
        Parameter names in signature order, excluding ``self``.
    """
    import inspect

    from openai import AzureOpenAI, OpenAI

    client_class = OpenAI if client_type == "openai" else AzureOpenAI
    signature = inspect.signature(client_class.__init__)

    # Drop 'self'; everything else is a field a caller may have supplied.
    return [name for name in signature.parameters if name != "self"]


# Context (truncated in this excerpt, kept verbatim; the rest of the method is
# not visible here):
#     @staticmethod
#     def _get_async_http_client() -> Optional[httpx.AsyncClient]:
#         if litellm.aclient_session is not None: