diff --git a/litellm/router.py b/litellm/router.py index 0cad565b00..142a781bbe 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -63,7 +63,10 @@ from litellm.router_utils.batch_utils import ( _get_router_metadata_variable_name, replace_model_in_jsonl, ) -from litellm.router_utils.client_initalization_utils import InitalizeOpenAISDKClient +from litellm.router_utils.client_initalization_utils import ( + set_client, + should_initialize_sync_client, +) from litellm.router_utils.cooldown_cache import CooldownCache from litellm.router_utils.cooldown_callbacks import router_cooldown_event_callback from litellm.router_utils.cooldown_handlers import ( @@ -3948,7 +3951,7 @@ class Router: raise Exception(f"Unsupported provider - {custom_llm_provider}") # init OpenAI, Azure clients - InitalizeOpenAISDKClient.set_client( + set_client( litellm_router_instance=self, model=deployment.to_json(exclude_none=True) ) @@ -4658,9 +4661,7 @@ class Router: """ Re-initialize the client """ - InitalizeOpenAISDKClient.set_client( - litellm_router_instance=self, model=deployment - ) + set_client(litellm_router_instance=self, model=deployment) client = self.cache.get_cache(key=cache_key, local_only=True) return client else: @@ -4670,9 +4671,7 @@ class Router: """ Re-initialize the client """ - InitalizeOpenAISDKClient.set_client( - litellm_router_instance=self, model=deployment - ) + set_client(litellm_router_instance=self, model=deployment) client = self.cache.get_cache(key=cache_key, local_only=True) return client else: @@ -4683,9 +4682,7 @@ class Router: """ Re-initialize the client """ - InitalizeOpenAISDKClient.set_client( - litellm_router_instance=self, model=deployment - ) + set_client(litellm_router_instance=self, model=deployment) client = self.cache.get_cache(key=cache_key) return client else: @@ -4695,9 +4692,7 @@ class Router: """ Re-initialize the client """ - InitalizeOpenAISDKClient.set_client( - litellm_router_instance=self, model=deployment - ) + set_client(litellm_router_instance=self, model=deployment) client = self.cache.get_cache(key=cache_key) return client diff --git a/litellm/router_utils/client_initalization_utils.py b/litellm/router_utils/client_initalization_utils.py index 679cefadfa..6c845296a8 100644 --- a/litellm/router_utils/client_initalization_utils.py +++ b/litellm/router_utils/client_initalization_utils.py @@ -1,11 +1,10 @@ import asyncio import os import traceback -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional import httpx import openai -from pydantic import BaseModel import litellm from litellm import get_secret, get_secret_str @@ -17,511 +16,89 @@ from litellm.secret_managers.get_azure_ad_token_provider import ( from litellm.utils import calculate_max_parallel_requests if TYPE_CHECKING: - from httpx import Timeout as httpxTimeout - from litellm.router import Router as _Router LitellmRouter = _Router else: LitellmRouter = Any - httpxTimeout = Any -class OpenAISDKClientInitializationParams(BaseModel): - api_key: Optional[str] - api_base: Optional[str] - api_version: Optional[str] - azure_ad_token_provider: Optional[Callable[[], str]] - timeout: Optional[Union[float, httpxTimeout]] - stream_timeout: Optional[Union[float, httpxTimeout]] - max_retries: int - organization: Optional[str] - - # Internal LiteLLM specific params - custom_llm_provider: Optional[str] - model_name: str - - -class InitalizeOpenAISDKClient: +def should_initialize_sync_client( + litellm_router_instance: LitellmRouter, +) -> bool: """ - OpenAI Python SDK requires creating a OpenAI/AzureOpenAI client - this class is responsible for creating that client + Returns if Sync OpenAI, Azure Clients should be initialized. + + Do not init sync clients when router.router_general_settings.async_only_mode is True + """ + if litellm_router_instance is None: + return False - @staticmethod - def should_initialize_sync_client( - litellm_router_instance: LitellmRouter, - ) -> bool: - """ - Returns if Sync OpenAI, Azure Clients should be initialized. - - Do not init sync clients when router.router_general_settings.async_only_mode is True - - """ - if litellm_router_instance is None: + if litellm_router_instance.router_general_settings is not None: + if ( + hasattr(litellm_router_instance, "router_general_settings") + and hasattr( + litellm_router_instance.router_general_settings, "async_only_mode" + ) + and litellm_router_instance.router_general_settings.async_only_mode is True + ): return False - if litellm_router_instance.router_general_settings is not None: - if ( - hasattr(litellm_router_instance, "router_general_settings") - and hasattr( - litellm_router_instance.router_general_settings, "async_only_mode" - ) - and litellm_router_instance.router_general_settings.async_only_mode - is True - ): - return False + return True - return True - @staticmethod - def set_client( # noqa: PLR0915 - litellm_router_instance: LitellmRouter, model: dict - ): - """ - - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278 - - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994 - """ - client_ttl = litellm_router_instance.client_ttl - litellm_params = model.get("litellm_params", {}) - model_name = litellm_params.get("model") - model_id = model["model_info"]["id"] - # ### IF RPM SET - initialize a semaphore ### - rpm = litellm_params.get("rpm", None) - tpm = litellm_params.get("tpm", None) - max_parallel_requests = litellm_params.get("max_parallel_requests", None) - calculated_max_parallel_requests = calculate_max_parallel_requests( - rpm=rpm, - max_parallel_requests=max_parallel_requests, - tpm=tpm, - default_max_parallel_requests=litellm_router_instance.default_max_parallel_requests, +def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PLR0915 + """ + - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278 + - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994 + """ + client_ttl = litellm_router_instance.client_ttl + litellm_params = model.get("litellm_params", {}) + model_name = litellm_params.get("model") + model_id = model["model_info"]["id"] + # ### IF RPM SET - initialize a semaphore ### + rpm = litellm_params.get("rpm", None) + tpm = litellm_params.get("tpm", None) + max_parallel_requests = litellm_params.get("max_parallel_requests", None) + calculated_max_parallel_requests = calculate_max_parallel_requests( + rpm=rpm, + max_parallel_requests=max_parallel_requests, + tpm=tpm, + default_max_parallel_requests=litellm_router_instance.default_max_parallel_requests, + ) + if calculated_max_parallel_requests: + semaphore = asyncio.Semaphore(calculated_max_parallel_requests) + cache_key = f"{model_id}_max_parallel_requests_client" + litellm_router_instance.cache.set_cache( + key=cache_key, + value=semaphore, + local_only=True, ) - if calculated_max_parallel_requests: - semaphore = asyncio.Semaphore(calculated_max_parallel_requests) - cache_key = f"{model_id}_max_parallel_requests_client" - litellm_router_instance.cache.set_cache( - key=cache_key, - value=semaphore, - local_only=True, - ) - #### for OpenAI / Azure we need to initalize the Client for High Traffic ######## - custom_llm_provider = litellm_params.get("custom_llm_provider") - custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or "" - default_api_base = None - default_api_key = None - if custom_llm_provider in litellm.openai_compatible_providers: - _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider( - model=model_name - ) - default_api_base = api_base - default_api_key = api_key - - if InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model( - model_name=model_name, - custom_llm_provider=custom_llm_provider, - ): - client_initialization_params = ( - InitalizeOpenAISDKClient._get_client_initialization_params( - model=model, - model_name=model_name, - custom_llm_provider=custom_llm_provider, - litellm_params=litellm_params, - default_api_key=default_api_key, - default_api_base=default_api_base, - ) - ) - - ############### Unpack client initialization params ####################### - api_key = client_initialization_params.api_key - api_base = client_initialization_params.api_base - api_version: Optional[str] = client_initialization_params.api_version - timeout: Optional[Union[float, httpxTimeout]] = ( - client_initialization_params.timeout - ) - stream_timeout: Optional[Union[float, httpxTimeout]] = ( - client_initialization_params.stream_timeout - ) - max_retries: int = client_initialization_params.max_retries - organization: Optional[str] = client_initialization_params.organization - azure_ad_token_provider: Optional[Callable[[], str]] = ( - client_initialization_params.azure_ad_token_provider - ) - custom_llm_provider = client_initialization_params.custom_llm_provider - model_name = client_initialization_params.model_name - ########################################################################## - - if custom_llm_provider == "azure" or custom_llm_provider == "azure_text": - if api_base is None or not isinstance(api_base, str): - filtered_litellm_params = { - k: v - for k, v in model["litellm_params"].items() - if k != "api_key" - } - _filtered_model = { - "model_name": model["model_name"], - "litellm_params": filtered_litellm_params, - } - raise ValueError( - f"api_base is required for Azure OpenAI. Set it on your config. Model - {_filtered_model}" - ) - azure_ad_token = litellm_params.get("azure_ad_token") - if azure_ad_token is not None: - if azure_ad_token.startswith("oidc/"): - azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) - elif ( - azure_ad_token_provider is None - and litellm.enable_azure_ad_token_refresh is True - ): - try: - azure_ad_token_provider = get_azure_ad_token_provider() - except ValueError: - verbose_router_logger.debug( - "Azure AD Token Provider could not be used." - ) - if api_version is None: - api_version = os.getenv( - "AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION - ) - - if "gateway.ai.cloudflare.com" in api_base: - if not api_base.endswith("/"): - api_base += "/" - azure_model = model_name.replace("azure/", "") - api_base += f"{azure_model}" - cache_key = f"{model_id}_async_client" - _client = openai.AsyncAzureOpenAI( - api_key=api_key, - azure_ad_token=azure_ad_token, - azure_ad_token_provider=azure_ad_token_provider, - base_url=api_base, - api_version=api_version, - timeout=timeout, - max_retries=max_retries, - http_client=httpx.AsyncClient( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), # type: ignore - ) - litellm_router_instance.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - if InitalizeOpenAISDKClient.should_initialize_sync_client( - litellm_router_instance=litellm_router_instance - ): - cache_key = f"{model_id}_client" - _client = openai.AzureOpenAI( # type: ignore - api_key=api_key, - azure_ad_token=azure_ad_token, - azure_ad_token_provider=azure_ad_token_provider, - base_url=api_base, - api_version=api_version, - timeout=timeout, - max_retries=max_retries, - http_client=httpx.Client( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), # type: ignore - ) - litellm_router_instance.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - # streaming clients can have diff timeouts - cache_key = f"{model_id}_stream_async_client" - _client = openai.AsyncAzureOpenAI( # type: ignore - api_key=api_key, - azure_ad_token=azure_ad_token, - azure_ad_token_provider=azure_ad_token_provider, - base_url=api_base, - api_version=api_version, - timeout=stream_timeout, - max_retries=max_retries, - http_client=httpx.AsyncClient( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), # type: ignore - ) - litellm_router_instance.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - if InitalizeOpenAISDKClient.should_initialize_sync_client( - litellm_router_instance=litellm_router_instance - ): - cache_key = f"{model_id}_stream_client" - _client = openai.AzureOpenAI( # type: ignore - api_key=api_key, - azure_ad_token=azure_ad_token, - azure_ad_token_provider=azure_ad_token_provider, - base_url=api_base, - api_version=api_version, - timeout=stream_timeout, - max_retries=max_retries, - http_client=httpx.Client( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), # type: ignore - ) - litellm_router_instance.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - else: - _api_key = api_key - if _api_key is not None and isinstance(_api_key, str): - # only show first 5 chars of api_key - _api_key = _api_key[:8] + "*" * 15 - verbose_router_logger.debug( - f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}" - ) - azure_client_params = { - "api_key": api_key, - "azure_endpoint": api_base, - "api_version": api_version, - "azure_ad_token": azure_ad_token, - "azure_ad_token_provider": azure_ad_token_provider, - } - - if azure_ad_token_provider is not None: - azure_client_params["azure_ad_token_provider"] = ( - azure_ad_token_provider - ) - from litellm.llms.AzureOpenAI.azure import ( - select_azure_base_url_or_endpoint, - ) - - # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client - # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client - azure_client_params = select_azure_base_url_or_endpoint( - azure_client_params - ) - - cache_key = f"{model_id}_async_client" - _client = openai.AsyncAzureOpenAI( # type: ignore - **azure_client_params, - timeout=timeout, - max_retries=max_retries, - http_client=httpx.AsyncClient( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), # type: ignore - ) - litellm_router_instance.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - if InitalizeOpenAISDKClient.should_initialize_sync_client( - litellm_router_instance=litellm_router_instance - ): - cache_key = f"{model_id}_client" - _client = openai.AzureOpenAI( # type: ignore - **azure_client_params, - timeout=timeout, - max_retries=max_retries, - http_client=httpx.Client( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), # type: ignore - ) - litellm_router_instance.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - # streaming clients should have diff timeouts - cache_key = f"{model_id}_stream_async_client" - _client = openai.AsyncAzureOpenAI( # type: ignore - **azure_client_params, - timeout=stream_timeout, - max_retries=max_retries, - http_client=httpx.AsyncClient( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - ) - litellm_router_instance.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - if InitalizeOpenAISDKClient.should_initialize_sync_client( - litellm_router_instance=litellm_router_instance - ): - cache_key = f"{model_id}_stream_client" - _client = openai.AzureOpenAI( # type: ignore - **azure_client_params, - timeout=stream_timeout, - max_retries=max_retries, - http_client=httpx.Client( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - ) - litellm_router_instance.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - else: - _api_key = api_key # type: ignore - if _api_key is not None and isinstance(_api_key, str): - # only show first 5 chars of api_key - _api_key = _api_key[:8] + "*" * 15 - verbose_router_logger.debug( - f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{_api_key}" - ) - cache_key = f"{model_id}_async_client" - _client = openai.AsyncOpenAI( # type: ignore - api_key=api_key, - base_url=api_base, - timeout=timeout, - max_retries=max_retries, - organization=organization, - http_client=httpx.AsyncClient( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), # type: ignore - ) - litellm_router_instance.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - if InitalizeOpenAISDKClient.should_initialize_sync_client( - litellm_router_instance=litellm_router_instance - ): - cache_key = f"{model_id}_client" - _client = openai.OpenAI( # type: ignore - api_key=api_key, - base_url=api_base, - timeout=timeout, - max_retries=max_retries, - organization=organization, - http_client=httpx.Client( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), # type: ignore - ) - litellm_router_instance.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - # streaming clients should have diff timeouts - cache_key = f"{model_id}_stream_async_client" - _client = openai.AsyncOpenAI( # type: ignore - api_key=api_key, - base_url=api_base, - timeout=stream_timeout, - max_retries=max_retries, - organization=organization, - http_client=httpx.AsyncClient( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), # type: ignore - ) - litellm_router_instance.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - if InitalizeOpenAISDKClient.should_initialize_sync_client( - litellm_router_instance=litellm_router_instance - ): - # streaming clients should have diff timeouts - cache_key = f"{model_id}_stream_client" - _client = openai.OpenAI( # type: ignore - api_key=api_key, - base_url=api_base, - timeout=stream_timeout, - max_retries=max_retries, - organization=organization, - http_client=httpx.Client( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), # type: ignore - ) - litellm_router_instance.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - @staticmethod - def _get_client_initialization_params( - model: dict, - model_name: str, - custom_llm_provider: Optional[str], - litellm_params: dict, - default_api_key: Optional[str], - default_api_base: Optional[str], - ) -> OpenAISDKClientInitializationParams: - """ - Returns params to use for initializing OpenAI SDK clients (for OpenAI, Azure OpenAI, OpenAI Compatible Providers) - - Args: - model: model dict - model_name: model name - custom_llm_provider: custom llm provider - litellm_params: litellm params - default_api_key: default api key - default_api_base: default api base - - Returns: - OpenAISDKClientInitializationParams - """ + #### for OpenAI / Azure we need to initalize the Client for High Traffic ######## + custom_llm_provider = litellm_params.get("custom_llm_provider") + custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or "" + default_api_base = None + default_api_key = None + if custom_llm_provider in litellm.openai_compatible_providers: + _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider( + model=model_name + ) + default_api_base = api_base + default_api_key = api_key + if ( + model_name in litellm.open_ai_chat_completion_models + or custom_llm_provider in litellm.openai_compatible_providers + or custom_llm_provider == "azure" + or custom_llm_provider == "azure_text" + or custom_llm_provider == "custom_openai" + or custom_llm_provider == "openai" + or custom_llm_provider == "text-completion-openai" + or "ft:gpt-3.5-turbo" in model_name + or model_name in litellm.open_ai_embedding_models + ): is_azure_ai_studio_model: bool = False if custom_llm_provider == "azure": if litellm.utils._is_non_openai_azure_model(model_name): @@ -534,7 +111,8 @@ class InitalizeOpenAISDKClient: # we do this here because we init clients for Azure, OpenAI and we need to set the right key api_key = litellm_params.get("api_key") or default_api_key if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"): - api_key = get_secret_str(api_key) + api_key_env_name = api_key.replace("os.environ/", "") + api_key = get_secret_str(api_key_env_name) litellm_params["api_key"] = api_key api_base = litellm_params.get("api_base") @@ -543,7 +121,8 @@ class InitalizeOpenAISDKClient: api_base or base_url or default_api_base ) # allow users to pass in `api_base` or `base_url` for azure if api_base and api_base.startswith("os.environ/"): - api_base = get_secret_str(api_base) + api_base_env_name = api_base.replace("os.environ/", "") + api_base = get_secret_str(api_base_env_name) litellm_params["api_base"] = api_base ## AZURE AI STUDIO MISTRAL CHECK ## @@ -568,132 +147,436 @@ class InitalizeOpenAISDKClient: api_version = litellm_params.get("api_version") if api_version and api_version.startswith("os.environ/"): - api_version = get_secret_str(api_version) + api_version_env_name = api_version.replace("os.environ/", "") + api_version = get_secret_str(api_version_env_name) litellm_params["api_version"] = api_version - timeout: Optional[Union[float, str, httpxTimeout]] = ( + timeout: Optional[float] = ( litellm_params.pop("timeout", None) or litellm.request_timeout ) if isinstance(timeout, str) and timeout.startswith("os.environ/"): - timeout = float(get_secret(timeout)) # type: ignore + timeout_env_name = timeout.replace("os.environ/", "") + timeout = get_secret(timeout_env_name) # type: ignore litellm_params["timeout"] = timeout - stream_timeout: Optional[Union[float, str, httpxTimeout]] = litellm_params.pop( + stream_timeout: Optional[float] = litellm_params.pop( "stream_timeout", timeout ) # if no stream_timeout is set, default to timeout if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"): - stream_timeout = float(get_secret(stream_timeout)) # type: ignore + stream_timeout_env_name = stream_timeout.replace("os.environ/", "") + stream_timeout = get_secret(stream_timeout_env_name) # type: ignore litellm_params["stream_timeout"] = stream_timeout max_retries: Optional[int] = litellm_params.pop( "max_retries", 0 ) # router handles retry logic if isinstance(max_retries, str) and max_retries.startswith("os.environ/"): - max_retries = get_secret(max_retries) # type: ignore + max_retries_env_name = max_retries.replace("os.environ/", "") + max_retries = get_secret(max_retries_env_name) # type: ignore litellm_params["max_retries"] = max_retries organization = litellm_params.get("organization", None) if isinstance(organization, str) and organization.startswith("os.environ/"): - organization = get_secret_str(organization) + organization_env_name = organization.replace("os.environ/", "") + organization = get_secret_str(organization_env_name) litellm_params["organization"] = organization azure_ad_token_provider: Optional[Callable[[], str]] = None - tenant_id = litellm_params.get("tenant_id") - if tenant_id is not None: + if litellm_params.get("tenant_id"): verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth") - azure_ad_token_provider = ( - InitalizeOpenAISDKClient.get_azure_ad_token_from_entrata_id( - tenant_id=tenant_id, - client_id=litellm_params.get("client_id"), - client_secret=litellm_params.get("client_secret"), + azure_ad_token_provider = get_azure_ad_token_from_entrata_id( + tenant_id=litellm_params.get("tenant_id"), + client_id=litellm_params.get("client_id"), + client_secret=litellm_params.get("client_secret"), + ) + + if custom_llm_provider == "azure" or custom_llm_provider == "azure_text": + if api_base is None or not isinstance(api_base, str): + filtered_litellm_params = { + k: v for k, v in model["litellm_params"].items() if k != "api_key" + } + _filtered_model = { + "model_name": model["model_name"], + "litellm_params": filtered_litellm_params, + } + raise ValueError( + f"api_base is required for Azure OpenAI. Set it on your config. Model - {_filtered_model}" ) + azure_ad_token = litellm_params.get("azure_ad_token") + if azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) + elif ( + azure_ad_token_provider is None + and litellm.enable_azure_ad_token_refresh is True + ): + try: + azure_ad_token_provider = get_azure_ad_token_provider() + except ValueError: + verbose_router_logger.debug( + "Azure AD Token Provider could not be used." + ) + if api_version is None: + api_version = os.getenv( + "AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION + ) + + if "gateway.ai.cloudflare.com" in api_base: + if not api_base.endswith("/"): + api_base += "/" + azure_model = model_name.replace("azure/", "") + api_base += f"{azure_model}" + cache_key = f"{model_id}_async_client" + _client = openai.AsyncAzureOpenAI( + api_key=api_key, + azure_ad_token=azure_ad_token, + azure_ad_token_provider=azure_ad_token_provider, + base_url=api_base, + api_version=api_version, + timeout=timeout, # type: ignore + max_retries=max_retries, # type: ignore + http_client=httpx.AsyncClient( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + cache_key = f"{model_id}_client" + _client = openai.AzureOpenAI( # type: ignore + api_key=api_key, + azure_ad_token=azure_ad_token, + azure_ad_token_provider=azure_ad_token_provider, + base_url=api_base, + api_version=api_version, + timeout=timeout, # type: ignore + max_retries=max_retries, # type: ignore + http_client=httpx.Client( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + # streaming clients can have diff timeouts + cache_key = f"{model_id}_stream_async_client" + _client = openai.AsyncAzureOpenAI( # type: ignore + api_key=api_key, + azure_ad_token=azure_ad_token, + azure_ad_token_provider=azure_ad_token_provider, + base_url=api_base, + api_version=api_version, + timeout=stream_timeout, # type: ignore + max_retries=max_retries, # type: ignore + http_client=httpx.AsyncClient( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + cache_key = f"{model_id}_stream_client" + _client = openai.AzureOpenAI( # type: ignore + api_key=api_key, + azure_ad_token=azure_ad_token, + azure_ad_token_provider=azure_ad_token_provider, + base_url=api_base, + api_version=api_version, + timeout=stream_timeout, # type: ignore + max_retries=max_retries, # type: ignore + http_client=httpx.Client( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + else: + _api_key = api_key + if _api_key is not None and isinstance(_api_key, str): + # only show first 5 chars of api_key + _api_key = _api_key[:8] + "*" * 15 + verbose_router_logger.debug( + f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}" + ) + azure_client_params = { + "api_key": api_key, + "azure_endpoint": api_base, + "api_version": api_version, + "azure_ad_token": azure_ad_token, + "azure_ad_token_provider": azure_ad_token_provider, + } + + if azure_ad_token_provider is not None: + azure_client_params["azure_ad_token_provider"] = ( + azure_ad_token_provider + ) + from litellm.llms.AzureOpenAI.azure import ( + select_azure_base_url_or_endpoint, + ) + + # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client + # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client + azure_client_params = select_azure_base_url_or_endpoint( + azure_client_params + ) + + cache_key = f"{model_id}_async_client" + _client = openai.AsyncAzureOpenAI( # type: ignore + **azure_client_params, + timeout=timeout, # type: ignore + max_retries=max_retries, # type: ignore + http_client=httpx.AsyncClient( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + cache_key = f"{model_id}_client" + _client = openai.AzureOpenAI( # type: ignore + **azure_client_params, + timeout=timeout, # type: ignore + max_retries=max_retries, # type: ignore + http_client=httpx.Client( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + # streaming clients should have diff timeouts + cache_key = f"{model_id}_stream_async_client" + _client = openai.AsyncAzureOpenAI( # type: ignore + **azure_client_params, + timeout=stream_timeout, # type: ignore + max_retries=max_retries, # type: ignore + http_client=httpx.AsyncClient( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + cache_key = f"{model_id}_stream_client" + _client = openai.AzureOpenAI( # type: ignore + **azure_client_params, + timeout=stream_timeout, # type: ignore + max_retries=max_retries, # type: ignore + http_client=httpx.Client( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + else: + _api_key = api_key # type: ignore + if _api_key is not None and isinstance(_api_key, str): + # only show first 5 chars of api_key + _api_key = _api_key[:8] + "*" * 15 + verbose_router_logger.debug( + f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{_api_key}" ) - - return OpenAISDKClientInitializationParams( - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token_provider=azure_ad_token_provider, - timeout=timeout, # type: ignore - stream_timeout=stream_timeout, # type: ignore - max_retries=max_retries, # type: ignore - organization=organization, - custom_llm_provider=custom_llm_provider, - model_name=model_name, - ) - - @staticmethod - def _should_create_openai_sdk_client_for_model( - model_name: str, - custom_llm_provider: str, - ) -> bool: - """ - Returns True if a OpenAI SDK client should be created for a given model - - We need a OpenAI SDK client for models that are callsed using OpenAI Python SDK - Azure OpenAI, OpenAI, OpenAI Compatible Providers, OpenAI Embedding Models - """ - if ( - model_name in litellm.open_ai_chat_completion_models - or custom_llm_provider in litellm.openai_compatible_providers - or custom_llm_provider == "azure" - or custom_llm_provider == "azure_text" - or custom_llm_provider == "custom_openai" - or custom_llm_provider == "openai" - or custom_llm_provider == "text-completion-openai" - or "ft:gpt-3.5-turbo" in model_name - or model_name in litellm.open_ai_embedding_models - ): - return True - return False - - @staticmethod - def get_azure_ad_token_from_entrata_id( - tenant_id: str, client_id: Optional[str], client_secret: Optional[str] - ) -> Callable[[], str]: - from azure.identity import ( - ClientSecretCredential, - DefaultAzureCredential, - get_bearer_token_provider, - ) - - if client_id is None or client_secret is None: - raise ValueError( - "client_id and client_secret must be provided when using `tenant_id`" + cache_key = f"{model_id}_async_client" + _client = openai.AsyncOpenAI( # type: ignore + api_key=api_key, + base_url=api_base, + timeout=timeout, # type: ignore + max_retries=max_retries, # type: ignore + organization=organization, + http_client=httpx.AsyncClient( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), # type: ignore ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr - verbose_router_logger.debug("Getting Azure AD Token from Entrata ID") + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + cache_key = f"{model_id}_client" + _client = openai.OpenAI( # type: ignore + api_key=api_key, + base_url=api_base, + timeout=timeout, # type: ignore + max_retries=max_retries, # type: ignore + organization=organization, + http_client=httpx.Client( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr - if tenant_id.startswith("os.environ/"): - _tenant_id = get_secret_str(tenant_id) - else: - _tenant_id = tenant_id + # streaming clients should have diff timeouts + cache_key = f"{model_id}_stream_async_client" + _client = openai.AsyncOpenAI( # type: ignore + api_key=api_key, + base_url=api_base, + timeout=stream_timeout, # type: ignore + max_retries=max_retries, # type: ignore + organization=organization, + http_client=httpx.AsyncClient( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr - if client_id.startswith("os.environ/"): - _client_id = get_secret_str(client_id) - else: - _client_id = client_id + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + # streaming clients should have diff timeouts + cache_key = f"{model_id}_stream_client" + _client = openai.OpenAI( # type: ignore + api_key=api_key, + base_url=api_base, + timeout=stream_timeout, # type: ignore + max_retries=max_retries, # type: ignore + organization=organization, + http_client=httpx.Client( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr - if client_secret.startswith("os.environ/"): - _client_secret = get_secret_str(client_secret) - else: - _client_secret = client_secret - verbose_router_logger.debug( - "tenant_id %s, client_id %s, client_secret %s", - _tenant_id, - _client_id, - _client_secret, - ) - if _tenant_id is None or _client_id is None or _client_secret is None: - raise ValueError("tenant_id, client_id, and client_secret must be provided") - credential = ClientSecretCredential(_tenant_id, _client_id, _client_secret) +def get_azure_ad_token_from_entrata_id( + tenant_id: str, client_id: str, client_secret: str +) -> Callable[[], str]: + from azure.identity import ( + ClientSecretCredential, + DefaultAzureCredential, + get_bearer_token_provider, + ) - verbose_router_logger.debug("credential %s", credential) + verbose_router_logger.debug("Getting Azure AD Token from Entrata ID") - token_provider = get_bearer_token_provider( - credential, "https://cognitiveservices.azure.com/.default" - ) + if tenant_id.startswith("os.environ/"): + _tenant_id = get_secret_str(tenant_id) + else: + _tenant_id = tenant_id - verbose_router_logger.debug("token_provider %s", token_provider) + if client_id.startswith("os.environ/"): + _client_id = get_secret_str(client_id) + else: + _client_id = client_id - return token_provider + if client_secret.startswith("os.environ/"): + _client_secret = get_secret_str(client_secret) + else: + _client_secret = client_secret + + verbose_router_logger.debug( + "tenant_id %s, client_id %s, client_secret %s", + _tenant_id, + _client_id, + _client_secret, + ) + if _tenant_id is None or _client_id is None or _client_secret is None: + raise ValueError("tenant_id, client_id, and client_secret must be provided") + credential = ClientSecretCredential(_tenant_id, _client_id, _client_secret) + + verbose_router_logger.debug("credential %s", credential) + + token_provider = get_bearer_token_provider( + credential, "https://cognitiveservices.azure.com/.default" + ) + + verbose_router_logger.debug("token_provider %s", token_provider) + + return token_provider diff --git a/tests/local_testing/test_router_init.py b/tests/local_testing/test_router_init.py index 45695f3c3d..3733af252b 100644 --- a/tests/local_testing/test_router_init.py +++ b/tests/local_testing/test_router_init.py @@ -17,10 +17,6 @@ from dotenv import load_dotenv import litellm from litellm import Router -from litellm.router_utils.client_initalization_utils import ( - InitalizeOpenAISDKClient, - OpenAISDKClientInitializationParams, -) load_dotenv() @@ -700,283 +696,3 @@ def test_init_router_with_supported_environments(environment, expected_models): assert set(_model_list) == set(expected_models) os.environ.pop("LITELLM_ENVIRONMENT") - - -def test_should_initialize_sync_client(): - from litellm.types.router import RouterGeneralSettings - - # Test case 1: Router instance is None - assert InitalizeOpenAISDKClient.should_initialize_sync_client(None) is False - - # Test case 2: Router instance without router_general_settings - router = Router(model_list=[]) - assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is True - - # Test case 3: Router instance with async_only_mode = False - router = Router( - model_list=[], - router_general_settings=RouterGeneralSettings(async_only_mode=False), - ) - assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is True - - # Test case 4: Router instance with async_only_mode = True - router = Router( - model_list=[], - router_general_settings=RouterGeneralSettings(async_only_mode=True), - ) - assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is False - - # Test case 5: Router instance with router_general_settings but without async_only_mode - router = Router(model_list=[], router_general_settings=RouterGeneralSettings()) - assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is True - - print("All test cases passed!") - - -@pytest.mark.parametrize( - "model_name, custom_llm_provider, expected_result", - [ - ("gpt-3.5-turbo", None, True), # OpenAI chat completion model - ("text-embedding-ada-002", None, True), # OpenAI embedding model - ("claude-2", None, False), # Non-OpenAI model - ("gpt-3.5-turbo", "azure", True), # Azure OpenAI - ("text-davinci-003", "azure_text", True), # Azure OpenAI - ("gpt-3.5-turbo", "openai", True), # OpenAI - ("custom-model", "custom_openai", True), # Custom OpenAI compatible - ("text-davinci-003", "text-completion-openai", True), # OpenAI text completion - ( - "ft:gpt-3.5-turbo-0613:my-org:custom-model:7p4lURel", - None, - True, - ), # Fine-tuned GPT model - ("mistral-7b", "huggingface", False), # Non-OpenAI provider - ("custom-model", "anthropic", False), # Non-OpenAI compatible provider - ], -) -def test_should_create_openai_sdk_client_for_model( - model_name, custom_llm_provider, expected_result -): - result = InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model( - model_name, custom_llm_provider - ) - assert ( - result == expected_result - ), f"Failed for model: {model_name}, provider: {custom_llm_provider}" - - -def test_should_create_openai_sdk_client_for_model_openai_compatible_providers(): - # Test with a known OpenAI compatible provider - assert InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model( - "custom-model", "groq" - ), "Should return True for OpenAI compatible provider" - - # Add a new compatible provider and test - litellm.openai_compatible_providers.append("new_provider") - assert InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model( - "custom-model", "new_provider" - ), "Should return True for newly added OpenAI compatible provider" - - # Clean up - litellm.openai_compatible_providers.remove("new_provider") - - -def test_get_client_initialization_params_openai(): - """Test basic OpenAI configuration with direct parameter passing.""" - model = {} - model_name = "gpt-3.5-turbo" - custom_llm_provider = None - litellm_params = {"api_key": "sk-openai-key", "timeout": 30, "max_retries": 3} - default_api_key = None - default_api_base = None - - result = InitalizeOpenAISDKClient._get_client_initialization_params( - model=model, - model_name=model_name, - custom_llm_provider=custom_llm_provider, - litellm_params=litellm_params, - default_api_key=default_api_key, - default_api_base=default_api_base, - ) - - assert isinstance(result, OpenAISDKClientInitializationParams) - assert result.api_key == "sk-openai-key" - assert result.timeout == 30 - assert result.max_retries == 3 - assert result.model_name == "gpt-3.5-turbo" - - -def test_get_client_initialization_params_azure(): - """Test Azure OpenAI configuration with specific Azure parameters.""" - model = {} - model_name = "azure/gpt-4" - custom_llm_provider = "azure" - litellm_params = { - "api_key": "azure-key", - "api_base": "https://example.azure.openai.com", - "api_version": "2023-05-15", - } - default_api_key = None - default_api_base = None - - result = InitalizeOpenAISDKClient._get_client_initialization_params( - model=model, - model_name=model_name, - custom_llm_provider=custom_llm_provider, - litellm_params=litellm_params, - default_api_key=default_api_key, - default_api_base=default_api_base, - ) - - assert result.api_key == "azure-key" - assert result.api_base == "https://example.azure.openai.com" - assert result.api_version == "2023-05-15" - assert result.custom_llm_provider == "azure" - - -def test_get_client_initialization_params_environment_variable_parsing(): - """Test parsing of environment variables for configuration.""" - os.environ["UNIQUE_OPENAI_API_KEY"] = "env-openai-key" - os.environ["UNIQUE_TIMEOUT"] = "45" - - model = {} - model_name = "gpt-4" - custom_llm_provider = None - litellm_params = { - "api_key": "os.environ/UNIQUE_OPENAI_API_KEY", - "timeout": "os.environ/UNIQUE_TIMEOUT", - "organization": "os.environ/UNIQUE_ORG_ID", - } - default_api_key = None - default_api_base = None - - result = InitalizeOpenAISDKClient._get_client_initialization_params( - model=model, - model_name=model_name, - custom_llm_provider=custom_llm_provider, - litellm_params=litellm_params, - default_api_key=default_api_key, - default_api_base=default_api_base, - ) - - assert result.api_key == "env-openai-key" - assert result.timeout == 45.0 - assert result.organization is None # Since ORG_ID is not set in the environment - - -def test_get_client_initialization_params_azure_ai_studio_mistral(): - """ - Test configuration for Azure AI Studio Mistral model. - - - /v1/ is added to the api_base if it is not present - - custom_llm_provider is set to openai (Azure AI Studio Mistral models need to use OpenAI route) - """ - - model = {} - model_name = "azure/mistral-large-latest" - custom_llm_provider = "azure" - litellm_params = { - "api_key": "azure-key", - "api_base": "https://example.azure.openai.com", - } - default_api_key = None - default_api_base = None - - result = InitalizeOpenAISDKClient._get_client_initialization_params( - model, - model_name, - custom_llm_provider, - litellm_params, - default_api_key, - default_api_base, - ) - - assert result.custom_llm_provider == "openai" - assert result.model_name == "mistral-large-latest" - assert result.api_base == "https://example.azure.openai.com/v1/" - - -def test_get_client_initialization_params_default_values(): - """ - Test use of default values when specific parameters are not provided. - - This is used typically for OpenAI compatible providers - example Together AI - - """ - model = {} - model_name = "together/meta-llama-3.1-8b-instruct" - custom_llm_provider = None - litellm_params = {} - default_api_key = "together-api-key" - default_api_base = "https://together.xyz/api.openai.com" - - result = InitalizeOpenAISDKClient._get_client_initialization_params( - model=model, - model_name=model_name, - custom_llm_provider=custom_llm_provider, - litellm_params=litellm_params, - default_api_key=default_api_key, - default_api_base=default_api_base, - ) - - assert result.api_key == "together-api-key" - assert result.api_base == "https://together.xyz/api.openai.com" - assert result.timeout == litellm.request_timeout - assert result.max_retries == 0 - - -def test_get_client_initialization_params_all_env_vars(): - # Set up environment variables - os.environ["TEST_API_KEY"] = "test-api-key" - os.environ["TEST_API_BASE"] = "https://test.openai.com" - os.environ["TEST_API_VERSION"] = "2023-05-15" - os.environ["TEST_TIMEOUT"] = "30" - os.environ["TEST_STREAM_TIMEOUT"] = "60" - os.environ["TEST_MAX_RETRIES"] = "3" - os.environ["TEST_ORGANIZATION"] = "test-org" - - model = {} - model_name = "gpt-4" - custom_llm_provider = None - litellm_params = { - "api_key": "os.environ/TEST_API_KEY", - "api_base": "os.environ/TEST_API_BASE", - "api_version": "os.environ/TEST_API_VERSION", - "timeout": "os.environ/TEST_TIMEOUT", - "stream_timeout": "os.environ/TEST_STREAM_TIMEOUT", - "max_retries": "os.environ/TEST_MAX_RETRIES", - "organization": "os.environ/TEST_ORGANIZATION", - } - default_api_key = None - default_api_base = None - - result = InitalizeOpenAISDKClient._get_client_initialization_params( - model=model, - model_name=model_name, - custom_llm_provider=custom_llm_provider, - litellm_params=litellm_params, - default_api_key=default_api_key, - default_api_base=default_api_base, - ) - - assert isinstance(result, OpenAISDKClientInitializationParams) - assert result.api_key == "test-api-key" - assert result.api_base == "https://test.openai.com" - assert result.api_version == "2023-05-15" - assert result.timeout == 30.0 - assert result.stream_timeout == 60.0 - assert result.max_retries == 3 - assert result.organization == "test-org" - assert result.model_name == "gpt-4" - assert result.custom_llm_provider is None - - # Clean up environment variables - for key in [ - "TEST_API_KEY", - "TEST_API_BASE", - "TEST_API_VERSION", - "TEST_TIMEOUT", - "TEST_STREAM_TIMEOUT", - "TEST_MAX_RETRIES", - "TEST_ORGANIZATION", - ]: - os.environ.pop(key)