mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
use common logic for re-using openai clients
This commit is contained in:
parent
bdf77f6f4b
commit
f73e9047dc
2 changed files with 95 additions and 1 deletions
|
@ -257,7 +257,17 @@ class BaseAzureLLM(BaseOpenAILLM):
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
|
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
|
||||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
|
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
|
||||||
|
client_initialization_params: dict = locals()
|
||||||
if client is None:
|
if client is None:
|
||||||
|
cached_client = self.get_cached_openai_client(
|
||||||
|
client_initialization_params=client_initialization_params,
|
||||||
|
client_type="azure",
|
||||||
|
)
|
||||||
|
if cached_client and isinstance(
|
||||||
|
cached_client, (AzureOpenAI, AsyncAzureOpenAI)
|
||||||
|
):
|
||||||
|
return cached_client
|
||||||
|
|
||||||
azure_client_params = self.initialize_azure_sdk_client(
|
azure_client_params = self.initialize_azure_sdk_client(
|
||||||
litellm_params=litellm_params or {},
|
litellm_params=litellm_params or {},
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
@ -278,6 +288,12 @@ class BaseAzureLLM(BaseOpenAILLM):
|
||||||
# set api_version to version passed by user
|
# set api_version to version passed by user
|
||||||
openai_client._custom_query.setdefault("api-version", api_version)
|
openai_client._custom_query.setdefault("api-version", api_version)
|
||||||
|
|
||||||
|
# save client in-memory cache
|
||||||
|
self.set_cached_openai_client(
|
||||||
|
openai_client=openai_client,
|
||||||
|
client_initialization_params=client_initialization_params,
|
||||||
|
client_type="azure",
|
||||||
|
)
|
||||||
return openai_client
|
return openai_client
|
||||||
|
|
||||||
def initialize_azure_sdk_client(
|
def initialize_azure_sdk_client(
|
||||||
|
|
|
@ -2,14 +2,17 @@
|
||||||
Common helpers / utils across all OpenAI endpoints
|
Common helpers / utils across all OpenAI endpoints
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import openai
|
import openai
|
||||||
|
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||||
|
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
|
||||||
|
|
||||||
|
|
||||||
class OpenAIError(BaseLLMException):
|
class OpenAIError(BaseLLMException):
|
||||||
|
@ -100,6 +103,81 @@ class BaseOpenAILLM:
|
||||||
Base class for OpenAI LLMs for getting their httpx clients and SSL verification settings
|
Base class for OpenAI LLMs for getting their httpx clients and SSL verification settings
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@staticmethod
def get_cached_openai_client(
    client_initialization_params: dict, client_type: Literal["openai", "azure"]
) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]]:
    """
    Look up a previously-constructed OpenAI/Azure client in the in-memory cache.

    The lookup key is derived from ``client_initialization_params`` via
    ``get_openai_client_cache_key``, so two calls with equivalent init params
    share one client. Returns ``None`` on a cache miss (or expired entry).
    """
    cache_key = BaseOpenAILLM.get_openai_client_cache_key(
        client_initialization_params=client_initialization_params,
        client_type=client_type,
    )
    return litellm.in_memory_llm_clients_cache.get_cache(cache_key)
||||||
|
|
||||||
|
@staticmethod
def set_cached_openai_client(
    openai_client: Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI],
    client_type: Literal["openai", "azure"],
    client_initialization_params: dict,
):
    """
    Store ``openai_client`` in the in-memory cache under a key derived from
    ``client_initialization_params``, with a TTL of
    ``_DEFAULT_TTL_FOR_HTTPX_CLIENTS`` seconds.

    Counterpart of ``get_cached_openai_client`` — both build the key with
    ``get_openai_client_cache_key`` so lookups and stores stay consistent.
    """
    cache_key = BaseOpenAILLM.get_openai_client_cache_key(
        client_initialization_params=client_initialization_params,
        client_type=client_type,
    )
    litellm.in_memory_llm_clients_cache.set_cache(
        key=cache_key,
        value=openai_client,
        ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
    )
||||||
|
|
||||||
|
@staticmethod
def get_openai_client_cache_key(
    client_initialization_params: dict, client_type: Literal["openai", "azure"]
) -> str:
    """
    Build a deterministic in-memory cache key from the client init params.

    The key is a comma-joined list of ``name=value`` pairs: a SHA-256 hash of
    the API key, the ``is_async`` flag, and every constructor parameter of the
    selected client class (per ``get_openai_client_initialization_param_fields``).

    The raw API key is deliberately NEVER embedded in the key — only its
    SHA-256 hex digest — so secrets don't leak into cache-key strings.
    """
    hashed_api_key = None
    api_key = client_initialization_params.get("api_key")
    if api_key is not None:
        # Hexadecimal SHA-256 digest stands in for the secret itself
        hashed_api_key = hashlib.sha256(api_key.encode()).hexdigest()

    # Create a more readable cache key using a list of key-value pairs
    key_parts = [
        f"hashed_api_key={hashed_api_key}",
        f"is_async={client_initialization_params.get('is_async')}",
    ]

    for param in BaseOpenAILLM.get_openai_client_initialization_param_fields(
        client_type=client_type
    ):
        # BUGFIX: "api_key" is one of the client __init__ params, so the old
        # loop appended the RAW api key here, defeating the hashing above.
        # It is already represented by `hashed_api_key` — skip it.
        if param == "api_key":
            continue
        key_parts.append(f"{param}={client_initialization_params.get(param)}")

    return ",".join(key_parts)
||||||
|
|
||||||
|
@staticmethod
def get_openai_client_initialization_param_fields(
    client_type: Literal["openai", "azure"]
) -> List[str]:
    """
    Return the constructor parameter names (excluding ``self``) of the
    OpenAI or AzureOpenAI client class, selected by ``client_type``.

    Used by ``get_openai_client_cache_key`` to enumerate every field that
    can influence client construction.
    """
    import inspect

    from openai import AzureOpenAI, OpenAI

    # FIX: annotate with typing.List (imported at module top) instead of
    # `list[str]` — the PEP 585 builtin-generic form is evaluated at def-time
    # and raises TypeError on Python 3.8, which this file otherwise supports
    # (it imports List/Optional/Union from typing).
    if client_type == "openai":
        signature = inspect.signature(OpenAI.__init__)
    else:
        signature = inspect.signature(AzureOpenAI.__init__)

    # Every __init__ parameter except the implicit `self` participates in
    # client construction and therefore belongs in the cache key.
    return [param for param in signature.parameters if param != "self"]
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_async_http_client() -> Optional[httpx.AsyncClient]:
|
def _get_async_http_client() -> Optional[httpx.AsyncClient]:
|
||||||
if litellm.aclient_session is not None:
|
if litellm.aclient_session is not None:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue