diff --git a/litellm/router.py b/litellm/router.py
index 142a781bb..0cad565b0 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -63,10 +63,7 @@ from litellm.router_utils.batch_utils import (
     _get_router_metadata_variable_name,
     replace_model_in_jsonl,
 )
-from litellm.router_utils.client_initalization_utils import (
-    set_client,
-    should_initialize_sync_client,
-)
+from litellm.router_utils.client_initalization_utils import InitalizeOpenAISDKClient
 from litellm.router_utils.cooldown_cache import CooldownCache
 from litellm.router_utils.cooldown_callbacks import router_cooldown_event_callback
 from litellm.router_utils.cooldown_handlers import (
@@ -3951,7 +3948,7 @@ class Router:
             raise Exception(f"Unsupported provider - {custom_llm_provider}")
 
         # init OpenAI, Azure clients
-        set_client(
+        InitalizeOpenAISDKClient.set_client(
             litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
         )
 
@@ -4661,7 +4658,9 @@ class Router:
             """
             Re-initialize the client
             """
-            set_client(litellm_router_instance=self, model=deployment)
+            InitalizeOpenAISDKClient.set_client(
+                litellm_router_instance=self, model=deployment
+            )
             client = self.cache.get_cache(key=cache_key, local_only=True)
             return client
         else:
@@ -4671,7 +4670,9 @@ class Router:
             """
             Re-initialize the client
             """
-            set_client(litellm_router_instance=self, model=deployment)
+            InitalizeOpenAISDKClient.set_client(
+                litellm_router_instance=self, model=deployment
+            )
             client = self.cache.get_cache(key=cache_key, local_only=True)
             return client
         else:
@@ -4682,7 +4683,9 @@ class Router:
             """
             Re-initialize the client
             """
-            set_client(litellm_router_instance=self, model=deployment)
+            InitalizeOpenAISDKClient.set_client(
+                litellm_router_instance=self, model=deployment
+            )
             client = self.cache.get_cache(key=cache_key)
             return client
         else:
@@ -4692,7 +4695,9 @@ class Router:
             """
             Re-initialize the client
             """
-            set_client(litellm_router_instance=self, model=deployment)
+            InitalizeOpenAISDKClient.set_client(
+                litellm_router_instance=self, model=deployment
+            )
             client = self.cache.get_cache(key=cache_key)
             return client
 
diff --git a/litellm/router_utils/client_initalization_utils.py b/litellm/router_utils/client_initalization_utils.py
index 6c845296a..679cefadf 100644
--- a/litellm/router_utils/client_initalization_utils.py
+++ b/litellm/router_utils/client_initalization_utils.py
@@ -1,10 +1,11 @@
 import asyncio
 import os
 import traceback
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import httpx
 import openai
+from pydantic import BaseModel
 
 import litellm
 from litellm import get_secret, get_secret_str
@@ -16,89 +17,511 @@ from litellm.secret_managers.get_azure_ad_token_provider import (
 )
 from litellm.utils import calculate_max_parallel_requests
 
 if TYPE_CHECKING:
+    from httpx import Timeout as httpxTimeout
+
     from litellm.router import Router as _Router
 
     LitellmRouter = _Router
 else:
     LitellmRouter = Any
+    httpxTimeout = Any
 
 
-def should_initialize_sync_client(
-    litellm_router_instance: LitellmRouter,
-) -> bool:
+class OpenAISDKClientInitializationParams(BaseModel):
+    api_key: Optional[str]
+    api_base: Optional[str]
+    api_version: Optional[str]
+    azure_ad_token_provider: Optional[Callable[[], str]]
+    timeout: Optional[Union[float, httpxTimeout]]
+    stream_timeout: Optional[Union[float, httpxTimeout]]
+    max_retries: int
+    organization: Optional[str]
+
+    # Internal LiteLLM specific params
+    custom_llm_provider: Optional[str]
+    model_name: str
+
+
+class InitalizeOpenAISDKClient:
""" - Returns if Sync OpenAI, Azure Clients should be initialized. - - Do not init sync clients when router.router_general_settings.async_only_mode is True - + OpenAI Python SDK requires creating a OpenAI/AzureOpenAI client + this class is responsible for creating that client """ - if litellm_router_instance is None: - return False - if litellm_router_instance.router_general_settings is not None: - if ( - hasattr(litellm_router_instance, "router_general_settings") - and hasattr( - litellm_router_instance.router_general_settings, "async_only_mode" - ) - and litellm_router_instance.router_general_settings.async_only_mode is True - ): + @staticmethod + def should_initialize_sync_client( + litellm_router_instance: LitellmRouter, + ) -> bool: + """ + Returns if Sync OpenAI, Azure Clients should be initialized. + + Do not init sync clients when router.router_general_settings.async_only_mode is True + + """ + if litellm_router_instance is None: return False - return True + if litellm_router_instance.router_general_settings is not None: + if ( + hasattr(litellm_router_instance, "router_general_settings") + and hasattr( + litellm_router_instance.router_general_settings, "async_only_mode" + ) + and litellm_router_instance.router_general_settings.async_only_mode + is True + ): + return False + return True -def set_client(litellm_router_instance: LitellmRouter, model: dict): # noqa: PLR0915 - """ - - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278 - - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994 - """ - client_ttl = litellm_router_instance.client_ttl - litellm_params = model.get("litellm_params", {}) - model_name = litellm_params.get("model") - model_id = model["model_info"]["id"] - # ### IF RPM SET - initialize a semaphore ### - rpm = litellm_params.get("rpm", None) - tpm = litellm_params.get("tpm", None) - max_parallel_requests = litellm_params.get("max_parallel_requests", None) - calculated_max_parallel_requests = calculate_max_parallel_requests( - rpm=rpm, - max_parallel_requests=max_parallel_requests, - tpm=tpm, - default_max_parallel_requests=litellm_router_instance.default_max_parallel_requests, - ) - if calculated_max_parallel_requests: - semaphore = asyncio.Semaphore(calculated_max_parallel_requests) - cache_key = f"{model_id}_max_parallel_requests_client" - litellm_router_instance.cache.set_cache( - key=cache_key, - value=semaphore, - local_only=True, - ) - - #### for OpenAI / Azure we need to initalize the Client for High Traffic ######## - custom_llm_provider = litellm_params.get("custom_llm_provider") - custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or "" - default_api_base = None - default_api_key = None - if custom_llm_provider in litellm.openai_compatible_providers: - _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider( - model=model_name - ) - default_api_base = api_base - default_api_key = api_key - - if ( - model_name in litellm.open_ai_chat_completion_models - or custom_llm_provider in litellm.openai_compatible_providers - or custom_llm_provider == "azure" - or custom_llm_provider == "azure_text" - or custom_llm_provider == "custom_openai" - or custom_llm_provider == "openai" - or custom_llm_provider == "text-completion-openai" - or "ft:gpt-3.5-turbo" in model_name - or model_name in litellm.open_ai_embedding_models + @staticmethod + def set_client( # noqa: PLR0915 + litellm_router_instance: 
     ):
+        """
+        - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
+        """
+        client_ttl = litellm_router_instance.client_ttl
+        litellm_params = model.get("litellm_params", {})
+        model_name = litellm_params.get("model")
+        model_id = model["model_info"]["id"]
+        # ### IF RPM SET - initialize a semaphore ###
+        rpm = litellm_params.get("rpm", None)
+        tpm = litellm_params.get("tpm", None)
+        max_parallel_requests = litellm_params.get("max_parallel_requests", None)
+        calculated_max_parallel_requests = calculate_max_parallel_requests(
+            rpm=rpm,
+            max_parallel_requests=max_parallel_requests,
+            tpm=tpm,
+            default_max_parallel_requests=litellm_router_instance.default_max_parallel_requests,
+        )
+        if calculated_max_parallel_requests:
+            semaphore = asyncio.Semaphore(calculated_max_parallel_requests)
+            cache_key = f"{model_id}_max_parallel_requests_client"
+            litellm_router_instance.cache.set_cache(
+                key=cache_key,
+                value=semaphore,
+                local_only=True,
+            )
+
+        #### for OpenAI / Azure we need to initialize the Client for High Traffic ########
+        custom_llm_provider = litellm_params.get("custom_llm_provider")
+        custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
+        default_api_base = None
+        default_api_key = None
+        if custom_llm_provider in litellm.openai_compatible_providers:
+            _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(
+                model=model_name
+            )
+            default_api_base = api_base
+            default_api_key = api_key
+
+        if InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model(
+            model_name=model_name,
+            custom_llm_provider=custom_llm_provider,
+        ):
+            client_initialization_params = (
+                InitalizeOpenAISDKClient._get_client_initialization_params(
+                    model=model,
+                    model_name=model_name,
+                    custom_llm_provider=custom_llm_provider,
+                    litellm_params=litellm_params,
+                    default_api_key=default_api_key,
+                    default_api_base=default_api_base,
+                )
+            )
+
+            ############### Unpack client initialization params #######################
+            api_key = client_initialization_params.api_key
+            api_base = client_initialization_params.api_base
+            api_version: Optional[str] = client_initialization_params.api_version
+            timeout: Optional[Union[float, httpxTimeout]] = (
+                client_initialization_params.timeout
+            )
+            stream_timeout: Optional[Union[float, httpxTimeout]] = (
+                client_initialization_params.stream_timeout
+            )
+            max_retries: int = client_initialization_params.max_retries
+            organization: Optional[str] = client_initialization_params.organization
+            azure_ad_token_provider: Optional[Callable[[], str]] = (
+                client_initialization_params.azure_ad_token_provider
+            )
+            custom_llm_provider = client_initialization_params.custom_llm_provider
+            model_name = client_initialization_params.model_name
+            ##########################################################################
+
+            if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
+                if api_base is None or not isinstance(api_base, str):
+                    filtered_litellm_params = {
+                        k: v
+                        for k, v in model["litellm_params"].items()
+                        if k != "api_key"
+                    }
+                    _filtered_model = {
+                        "model_name": model["model_name"],
+                        "litellm_params": filtered_litellm_params,
+                    }
+                    raise ValueError(
+                        f"api_base is required for Azure OpenAI. Set it on your config. Model - {_filtered_model}"
+                    )
+                azure_ad_token = litellm_params.get("azure_ad_token")
+                if azure_ad_token is not None:
+                    if azure_ad_token.startswith("oidc/"):
+                        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
+                elif (
+                    azure_ad_token_provider is None
+                    and litellm.enable_azure_ad_token_refresh is True
+                ):
+                    try:
+                        azure_ad_token_provider = get_azure_ad_token_provider()
+                    except ValueError:
+                        verbose_router_logger.debug(
+                            "Azure AD Token Provider could not be used."
+                        )
+                if api_version is None:
+                    api_version = os.getenv(
+                        "AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION
+                    )
+
+                if "gateway.ai.cloudflare.com" in api_base:
+                    if not api_base.endswith("/"):
+                        api_base += "/"
+                    azure_model = model_name.replace("azure/", "")
+                    api_base += f"{azure_model}"
+                    cache_key = f"{model_id}_async_client"
+                    _client = openai.AsyncAzureOpenAI(
+                        api_key=api_key,
+                        azure_ad_token=azure_ad_token,
+                        azure_ad_token_provider=azure_ad_token_provider,
+                        base_url=api_base,
+                        api_version=api_version,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                        http_client=httpx.AsyncClient(
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
+                            verify=litellm.ssl_verify,
+                        ),  # type: ignore
+                    )
+                    litellm_router_instance.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+
+                    if InitalizeOpenAISDKClient.should_initialize_sync_client(
+                        litellm_router_instance=litellm_router_instance
+                    ):
+                        cache_key = f"{model_id}_client"
+                        _client = openai.AzureOpenAI(  # type: ignore
+                            api_key=api_key,
+                            azure_ad_token=azure_ad_token,
+                            azure_ad_token_provider=azure_ad_token_provider,
+                            base_url=api_base,
+                            api_version=api_version,
+                            timeout=timeout,
+                            max_retries=max_retries,
+                            http_client=httpx.Client(
+                                limits=httpx.Limits(
+                                    max_connections=1000, max_keepalive_connections=100
+                                ),
+                                verify=litellm.ssl_verify,
+                            ),  # type: ignore
+                        )
+                        litellm_router_instance.cache.set_cache(
+                            key=cache_key,
+                            value=_client,
+                            ttl=client_ttl,
+                            local_only=True,
+                        )  # cache for 1 hr
+                    # streaming clients can have diff timeouts
+                    cache_key = f"{model_id}_stream_async_client"
+                    _client = openai.AsyncAzureOpenAI(  # type: ignore
+                        api_key=api_key,
+                        azure_ad_token=azure_ad_token,
+                        azure_ad_token_provider=azure_ad_token_provider,
+                        base_url=api_base,
+                        api_version=api_version,
+                        timeout=stream_timeout,
+                        max_retries=max_retries,
+                        http_client=httpx.AsyncClient(
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
+                            verify=litellm.ssl_verify,
+                        ),  # type: ignore
+                    )
+                    litellm_router_instance.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+
+                    if InitalizeOpenAISDKClient.should_initialize_sync_client(
+                        litellm_router_instance=litellm_router_instance
+                    ):
+                        cache_key = f"{model_id}_stream_client"
+                        _client = openai.AzureOpenAI(  # type: ignore
+                            api_key=api_key,
+                            azure_ad_token=azure_ad_token,
+                            azure_ad_token_provider=azure_ad_token_provider,
+                            base_url=api_base,
+                            api_version=api_version,
+                            timeout=stream_timeout,
+                            max_retries=max_retries,
+                            http_client=httpx.Client(
+                                limits=httpx.Limits(
+                                    max_connections=1000, max_keepalive_connections=100
+                                ),
+                                verify=litellm.ssl_verify,
+                            ),  # type: ignore
+                        )
+                        litellm_router_instance.cache.set_cache(
+                            key=cache_key,
+                            value=_client,
+                            ttl=client_ttl,
+                            local_only=True,
+                        )  # cache for 1 hr
+                else:
+                    _api_key = api_key
+                    if _api_key is not None and isinstance(_api_key, str):
+                        # only show first 5 chars of api_key
+                        _api_key = _api_key[:8] + "*" * 15
+                    verbose_router_logger.debug(
+                        f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
+                    )
+                    azure_client_params = {
+                        "api_key": api_key,
+                        "azure_endpoint": api_base,
+                        "api_version": api_version,
+                        "azure_ad_token": azure_ad_token,
+                        "azure_ad_token_provider": azure_ad_token_provider,
+                    }
+
+                    if azure_ad_token_provider is not None:
+                        azure_client_params["azure_ad_token_provider"] = (
+                            azure_ad_token_provider
+                        )
+                    from litellm.llms.AzureOpenAI.azure import (
+                        select_azure_base_url_or_endpoint,
+                    )
+
+                    # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
+                    # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
+                    azure_client_params = select_azure_base_url_or_endpoint(
+                        azure_client_params
+                    )
+
+                    cache_key = f"{model_id}_async_client"
+                    _client = openai.AsyncAzureOpenAI(  # type: ignore
+                        **azure_client_params,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                        http_client=httpx.AsyncClient(
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
+                            verify=litellm.ssl_verify,
+                        ),  # type: ignore
+                    )
+                    litellm_router_instance.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+                    if InitalizeOpenAISDKClient.should_initialize_sync_client(
+                        litellm_router_instance=litellm_router_instance
+                    ):
+                        cache_key = f"{model_id}_client"
+                        _client = openai.AzureOpenAI(  # type: ignore
+                            **azure_client_params,
+                            timeout=timeout,
+                            max_retries=max_retries,
+                            http_client=httpx.Client(
+                                limits=httpx.Limits(
+                                    max_connections=1000, max_keepalive_connections=100
+                                ),
+                                verify=litellm.ssl_verify,
+                            ),  # type: ignore
+                        )
+                        litellm_router_instance.cache.set_cache(
+                            key=cache_key,
+                            value=_client,
+                            ttl=client_ttl,
+                            local_only=True,
+                        )  # cache for 1 hr
+
+                    # streaming clients should have diff timeouts
+                    cache_key = f"{model_id}_stream_async_client"
+                    _client = openai.AsyncAzureOpenAI(  # type: ignore
+                        **azure_client_params,
+                        timeout=stream_timeout,
+                        max_retries=max_retries,
+                        http_client=httpx.AsyncClient(
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
+                            verify=litellm.ssl_verify,
+                        ),
+                    )
+                    litellm_router_instance.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+
+                    if InitalizeOpenAISDKClient.should_initialize_sync_client(
+                        litellm_router_instance=litellm_router_instance
+                    ):
+                        cache_key = f"{model_id}_stream_client"
+                        _client = openai.AzureOpenAI(  # type: ignore
+                            **azure_client_params,
+                            timeout=stream_timeout,
+                            max_retries=max_retries,
+                            http_client=httpx.Client(
+                                limits=httpx.Limits(
+                                    max_connections=1000, max_keepalive_connections=100
+                                ),
+                                verify=litellm.ssl_verify,
+                            ),
+                        )
+                        litellm_router_instance.cache.set_cache(
+                            key=cache_key,
+                            value=_client,
+                            ttl=client_ttl,
+                            local_only=True,
+                        )  # cache for 1 hr
+
+            else:
+                _api_key = api_key  # type: ignore
+                if _api_key is not None and isinstance(_api_key, str):
+                    # only show first 5 chars of api_key
+                    _api_key = _api_key[:8] + "*" * 15
+                verbose_router_logger.debug(
+                    f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{_api_key}"
+                )
+                cache_key = f"{model_id}_async_client"
+                _client = openai.AsyncOpenAI(  # type: ignore
+                    api_key=api_key,
+                    base_url=api_base,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                    http_client=httpx.AsyncClient(
+                        limits=httpx.Limits(
+                            max_connections=1000, max_keepalive_connections=100
+                        ),
+                        verify=litellm.ssl_verify,
+                    ),  # type: ignore
+                )
+                litellm_router_instance.cache.set_cache(
+                    key=cache_key,
+                    value=_client,
+                    ttl=client_ttl,
+                    local_only=True,
+                )  # cache for 1 hr
+
+                if InitalizeOpenAISDKClient.should_initialize_sync_client(
+                    litellm_router_instance=litellm_router_instance
+                ):
+                    cache_key = f"{model_id}_client"
+                    _client = openai.OpenAI(  # type: ignore
+                        api_key=api_key,
+                        base_url=api_base,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                        organization=organization,
+                        http_client=httpx.Client(
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
+                            verify=litellm.ssl_verify,
+                        ),  # type: ignore
+                    )
+                    litellm_router_instance.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+
+                # streaming clients should have diff timeouts
+                cache_key = f"{model_id}_stream_async_client"
+                _client = openai.AsyncOpenAI(  # type: ignore
+                    api_key=api_key,
+                    base_url=api_base,
+                    timeout=stream_timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                    http_client=httpx.AsyncClient(
+                        limits=httpx.Limits(
+                            max_connections=1000, max_keepalive_connections=100
+                        ),
+                        verify=litellm.ssl_verify,
+                    ),  # type: ignore
+                )
+                litellm_router_instance.cache.set_cache(
+                    key=cache_key,
+                    value=_client,
+                    ttl=client_ttl,
+                    local_only=True,
+                )  # cache for 1 hr
+
+                if InitalizeOpenAISDKClient.should_initialize_sync_client(
+                    litellm_router_instance=litellm_router_instance
+                ):
+                    # streaming clients should have diff timeouts
+                    cache_key = f"{model_id}_stream_client"
+                    _client = openai.OpenAI(  # type: ignore
+                        api_key=api_key,
+                        base_url=api_base,
+                        timeout=stream_timeout,
+                        max_retries=max_retries,
+                        organization=organization,
+                        http_client=httpx.Client(
+                            limits=httpx.Limits(
+                                max_connections=1000, max_keepalive_connections=100
+                            ),
+                            verify=litellm.ssl_verify,
+                        ),  # type: ignore
+                    )
+                    litellm_router_instance.cache.set_cache(
+                        key=cache_key,
+                        value=_client,
+                        ttl=client_ttl,
+                        local_only=True,
+                    )  # cache for 1 hr
+
+    @staticmethod
+    def _get_client_initialization_params(
+        model: dict,
+        model_name: str,
+        custom_llm_provider: Optional[str],
+        litellm_params: dict,
+        default_api_key: Optional[str],
+        default_api_base: Optional[str],
+    ) -> OpenAISDKClientInitializationParams:
+        """
+        Returns params to use for initializing OpenAI SDK clients (for OpenAI, Azure OpenAI, OpenAI Compatible Providers)
+
+        Args:
+            model: model dict
+            model_name: model name
+            custom_llm_provider: custom llm provider
+            litellm_params: litellm params
+            default_api_key: default api key
+            default_api_base: default api base
+
+        Returns:
+            OpenAISDKClientInitializationParams
+        """
         is_azure_ai_studio_model: bool = False
         if custom_llm_provider == "azure":
             if litellm.utils._is_non_openai_azure_model(model_name):
@@ -111,8 +534,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):  # noqa: PL
         # we do this here because we init clients for Azure, OpenAI and we need to set the right key
         api_key = litellm_params.get("api_key") or default_api_key
         if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"):
-            api_key_env_name = api_key.replace("os.environ/", "")
-            api_key = get_secret_str(api_key_env_name)
+            api_key = get_secret_str(api_key)
             litellm_params["api_key"] = api_key
 
         api_base = litellm_params.get("api_base")
@@ -121,8 +543,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):  # noqa: PL
         api_base = (
             api_base or base_url or default_api_base
         )  # allow users to pass in `api_base` or `base_url` for azure
         if api_base and api_base.startswith("os.environ/"):
-            api_base_env_name = api_base.replace("os.environ/", "")
-            api_base = get_secret_str(api_base_env_name)
+            api_base = get_secret_str(api_base)
             litellm_params["api_base"] = api_base
 
         ## AZURE AI STUDIO MISTRAL CHECK ##
@@ -147,436 +568,132 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):  # noqa: PL
         api_version = litellm_params.get("api_version")
         if api_version and api_version.startswith("os.environ/"):
-            api_version_env_name = api_version.replace("os.environ/", "")
-            api_version = get_secret_str(api_version_env_name)
+            api_version = get_secret_str(api_version)
             litellm_params["api_version"] = api_version
 
-        timeout: Optional[float] = (
+        timeout: Optional[Union[float, str, httpxTimeout]] = (
             litellm_params.pop("timeout", None) or litellm.request_timeout
         )
         if isinstance(timeout, str) and timeout.startswith("os.environ/"):
-            timeout_env_name = timeout.replace("os.environ/", "")
-            timeout = get_secret(timeout_env_name)  # type: ignore
+            timeout = float(get_secret(timeout))  # type: ignore
             litellm_params["timeout"] = timeout
 
-        stream_timeout: Optional[float] = litellm_params.pop(
+        stream_timeout: Optional[Union[float, str, httpxTimeout]] = litellm_params.pop(
             "stream_timeout", timeout
         )  # if no stream_timeout is set, default to timeout
         if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
-            stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
-            stream_timeout = get_secret(stream_timeout_env_name)  # type: ignore
+            stream_timeout = float(get_secret(stream_timeout))  # type: ignore
            litellm_params["stream_timeout"] = stream_timeout
 
         max_retries: Optional[int] = litellm_params.pop(
             "max_retries", 0
         )  # router handles retry logic
         if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
-            max_retries_env_name = max_retries.replace("os.environ/", "")
-            max_retries = get_secret(max_retries_env_name)  # type: ignore
+            max_retries = get_secret(max_retries)  # type: ignore
             litellm_params["max_retries"] = max_retries
 
         organization = litellm_params.get("organization", None)
         if isinstance(organization, str) and organization.startswith("os.environ/"):
-            organization_env_name = organization.replace("os.environ/", "")
-            organization = get_secret_str(organization_env_name)
+            organization = get_secret_str(organization)
             litellm_params["organization"] = organization
 
         azure_ad_token_provider: Optional[Callable[[], str]] = None
-        if litellm_params.get("tenant_id"):
+        tenant_id = litellm_params.get("tenant_id")
+        if tenant_id is not None:
             verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth")
-            azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
-                tenant_id=litellm_params.get("tenant_id"),
-                client_id=litellm_params.get("client_id"),
-                client_secret=litellm_params.get("client_secret"),
+            azure_ad_token_provider = (
+                InitalizeOpenAISDKClient.get_azure_ad_token_from_entrata_id(
+                    tenant_id=tenant_id,
+                    client_id=litellm_params.get("client_id"),
+                    client_secret=litellm_params.get("client_secret"),
+                )
             )
 
-        if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
-            if api_base is None or not isinstance(api_base, str):
-                filtered_litellm_params = {
-                    k: v for k, v in model["litellm_params"].items() if k != "api_key"
-                }
-                _filtered_model = {
-                    "model_name": model["model_name"],
-                    "litellm_params": filtered_litellm_params,
-                }
-                raise ValueError(
-                    f"api_base is required for Azure OpenAI. Set it on your config. Model - {_filtered_model}"
-                )
-            azure_ad_token = litellm_params.get("azure_ad_token")
-            if azure_ad_token is not None:
-                if azure_ad_token.startswith("oidc/"):
-                    azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
-            elif (
-                azure_ad_token_provider is None
-                and litellm.enable_azure_ad_token_refresh is True
-            ):
-                try:
-                    azure_ad_token_provider = get_azure_ad_token_provider()
-                except ValueError:
-                    verbose_router_logger.debug(
-                        "Azure AD Token Provider could not be used."
-                    )
-            if api_version is None:
-                api_version = os.getenv(
-                    "AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION
-                )
+        return OpenAISDKClientInitializationParams(
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+            azure_ad_token_provider=azure_ad_token_provider,
+            timeout=timeout,  # type: ignore
+            stream_timeout=stream_timeout,  # type: ignore
+            max_retries=max_retries,  # type: ignore
+            organization=organization,
+            custom_llm_provider=custom_llm_provider,
+            model_name=model_name,
+        )
 
-            if "gateway.ai.cloudflare.com" in api_base:
-                if not api_base.endswith("/"):
-                    api_base += "/"
-                azure_model = model_name.replace("azure/", "")
-                api_base += f"{azure_model}"
-                cache_key = f"{model_id}_async_client"
-                _client = openai.AsyncAzureOpenAI(
-                    api_key=api_key,
-                    azure_ad_token=azure_ad_token,
-                    azure_ad_token_provider=azure_ad_token_provider,
-                    base_url=api_base,
-                    api_version=api_version,
-                    timeout=timeout,  # type: ignore
-                    max_retries=max_retries,  # type: ignore
-                    http_client=httpx.AsyncClient(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),  # type: ignore
-                )
-                litellm_router_instance.cache.set_cache(
-                    key=cache_key,
-                    value=_client,
-                    ttl=client_ttl,
-                    local_only=True,
-                )  # cache for 1 hr
+    @staticmethod
+    def _should_create_openai_sdk_client_for_model(
+        model_name: str,
+        custom_llm_provider: str,
+    ) -> bool:
+        """
+        Returns True if an OpenAI SDK client should be created for a given model
 
-                if should_initialize_sync_client(
-                    litellm_router_instance=litellm_router_instance
-                ):
-                    cache_key = f"{model_id}_client"
-                    _client = openai.AzureOpenAI(  # type: ignore
-                        api_key=api_key,
-                        azure_ad_token=azure_ad_token,
-                        azure_ad_token_provider=azure_ad_token_provider,
-                        base_url=api_base,
-                        api_version=api_version,
-                        timeout=timeout,  # type: ignore
-                        max_retries=max_retries,  # type: ignore
-                        http_client=httpx.Client(
-                            limits=httpx.Limits(
-                                max_connections=1000, max_keepalive_connections=100
-                            ),
-                            verify=litellm.ssl_verify,
-                        ),  # type: ignore
-                    )
-                    litellm_router_instance.cache.set_cache(
-                        key=cache_key,
-                        value=_client,
-                        ttl=client_ttl,
-                        local_only=True,
-                    )  # cache for 1 hr
-                # streaming clients can have diff timeouts
-                cache_key = f"{model_id}_stream_async_client"
-                _client = openai.AsyncAzureOpenAI(  # type: ignore
-                    api_key=api_key,
-                    azure_ad_token=azure_ad_token,
-                    azure_ad_token_provider=azure_ad_token_provider,
-                    base_url=api_base,
-                    api_version=api_version,
-                    timeout=stream_timeout,  # type: ignore
-                    max_retries=max_retries,  # type: ignore
-                    http_client=httpx.AsyncClient(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),  # type: ignore
-                )
-                litellm_router_instance.cache.set_cache(
-                    key=cache_key,
-                    value=_client,
-                    ttl=client_ttl,
-                    local_only=True,
-                )  # cache for 1 hr
+        We need an OpenAI SDK client for models that are called using OpenAI Python SDK
+        Azure OpenAI, OpenAI, OpenAI Compatible Providers, OpenAI Embedding Models
+        """
+        if (
+            model_name in litellm.open_ai_chat_completion_models
+            or custom_llm_provider in litellm.openai_compatible_providers
+            or custom_llm_provider == "azure"
+            or custom_llm_provider == "azure_text"
+            or custom_llm_provider == "custom_openai"
+            or custom_llm_provider == "openai"
+            or custom_llm_provider == "text-completion-openai"
+            or "ft:gpt-3.5-turbo" in model_name
+            or model_name in litellm.open_ai_embedding_models
+        ):
+            return True
+        return False
 
-                if should_initialize_sync_client(
-                    litellm_router_instance=litellm_router_instance
-                ):
-                    cache_key = f"{model_id}_stream_client"
-                    _client = openai.AzureOpenAI(  # type: ignore
-                        api_key=api_key,
-                        azure_ad_token=azure_ad_token,
-                        azure_ad_token_provider=azure_ad_token_provider,
-                        base_url=api_base,
-                        api_version=api_version,
-                        timeout=stream_timeout,  # type: ignore
-                        max_retries=max_retries,  # type: ignore
-                        http_client=httpx.Client(
-                            limits=httpx.Limits(
-                                max_connections=1000, max_keepalive_connections=100
-                            ),
-                            verify=litellm.ssl_verify,
-                        ),  # type: ignore
-                    )
-                    litellm_router_instance.cache.set_cache(
-                        key=cache_key,
-                        value=_client,
-                        ttl=client_ttl,
-                        local_only=True,
-                    )  # cache for 1 hr
-            else:
-                _api_key = api_key
-                if _api_key is not None and isinstance(_api_key, str):
-                    # only show first 5 chars of api_key
-                    _api_key = _api_key[:8] + "*" * 15
-                verbose_router_logger.debug(
-                    f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
-                )
-                azure_client_params = {
-                    "api_key": api_key,
-                    "azure_endpoint": api_base,
-                    "api_version": api_version,
-                    "azure_ad_token": azure_ad_token,
-                    "azure_ad_token_provider": azure_ad_token_provider,
-                }
+    @staticmethod
+    def get_azure_ad_token_from_entrata_id(
+        tenant_id: str, client_id: Optional[str], client_secret: Optional[str]
+    ) -> Callable[[], str]:
+        from azure.identity import (
+            ClientSecretCredential,
+            DefaultAzureCredential,
+            get_bearer_token_provider,
+        )
 
-                if azure_ad_token_provider is not None:
-                    azure_client_params["azure_ad_token_provider"] = (
-                        azure_ad_token_provider
-                    )
-                from litellm.llms.AzureOpenAI.azure import (
-                    select_azure_base_url_or_endpoint,
-                )
+        if client_id is None or client_secret is None:
+            raise ValueError(
+                "client_id and client_secret must be provided when using `tenant_id`"
+            )
 
-                # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
-                # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
-                azure_client_params = select_azure_base_url_or_endpoint(
-                    azure_client_params
-                )
-
-                cache_key = f"{model_id}_async_client"
-                _client = openai.AsyncAzureOpenAI(  # type: ignore
-                    **azure_client_params,
-                    timeout=timeout,  # type: ignore
-                    max_retries=max_retries,  # type: ignore
-                    http_client=httpx.AsyncClient(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),  # type: ignore
-                )
-                litellm_router_instance.cache.set_cache(
-                    key=cache_key,
-                    value=_client,
-                    ttl=client_ttl,
-                    local_only=True,
-                )  # cache for 1 hr
-                if should_initialize_sync_client(
-                    litellm_router_instance=litellm_router_instance
-                ):
-                    cache_key = f"{model_id}_client"
-                    _client = openai.AzureOpenAI(  # type: ignore
-                        **azure_client_params,
-                        timeout=timeout,  # type: ignore
-                        max_retries=max_retries,  # type: ignore
-                        http_client=httpx.Client(
-                            limits=httpx.Limits(
-                                max_connections=1000, max_keepalive_connections=100
-                            ),
-                            verify=litellm.ssl_verify,
-                        ),  # type: ignore
-                    )
-                    litellm_router_instance.cache.set_cache(
-                        key=cache_key,
-                        value=_client,
-                        ttl=client_ttl,
-                        local_only=True,
-                    )  # cache for 1 hr
-
-                # streaming clients should have diff timeouts
-                cache_key = f"{model_id}_stream_async_client"
-                _client = openai.AsyncAzureOpenAI(  # type: ignore
-                    **azure_client_params,
-                    timeout=stream_timeout,  # type: ignore
-                    max_retries=max_retries,  # type: ignore
-                    http_client=httpx.AsyncClient(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),
-                )
-                litellm_router_instance.cache.set_cache(
-                    key=cache_key,
-                    value=_client,
-                    ttl=client_ttl,
-                    local_only=True,
-                )  # cache for 1 hr
-
-                if should_initialize_sync_client(
-                    litellm_router_instance=litellm_router_instance
-                ):
-                    cache_key = f"{model_id}_stream_client"
-                    _client = openai.AzureOpenAI(  # type: ignore
-                        **azure_client_params,
-                        timeout=stream_timeout,  # type: ignore
-                        max_retries=max_retries,  # type: ignore
-                        http_client=httpx.Client(
-                            limits=httpx.Limits(
-                                max_connections=1000, max_keepalive_connections=100
-                            ),
-                            verify=litellm.ssl_verify,
-                        ),
-                    )
-                    litellm_router_instance.cache.set_cache(
-                        key=cache_key,
-                        value=_client,
-                        ttl=client_ttl,
-                        local_only=True,
-                    )  # cache for 1 hr
+        verbose_router_logger.debug("Getting Azure AD Token from Entrata ID")
+        if tenant_id.startswith("os.environ/"):
+            _tenant_id = get_secret_str(tenant_id)
         else:
-            _api_key = api_key  # type: ignore
-            if _api_key is not None and isinstance(_api_key, str):
-                # only show first 5 chars of api_key
-                _api_key = _api_key[:8] + "*" * 15
-            verbose_router_logger.debug(
-                f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{_api_key}"
-            )
-            cache_key = f"{model_id}_async_client"
-            _client = openai.AsyncOpenAI(  # type: ignore
-                api_key=api_key,
-                base_url=api_base,
-                timeout=timeout,  # type: ignore
-                max_retries=max_retries,  # type: ignore
-                organization=organization,
-                http_client=httpx.AsyncClient(
-                    limits=httpx.Limits(
-                        max_connections=1000, max_keepalive_connections=100
-                    ),
-                    verify=litellm.ssl_verify,
-                ),  # type: ignore
-            )
-            litellm_router_instance.cache.set_cache(
-                key=cache_key,
-                value=_client,
-                ttl=client_ttl,
-                local_only=True,
-            )  # cache for 1 hr
+            _tenant_id = tenant_id
 
-            if should_initialize_sync_client(
-                litellm_router_instance=litellm_router_instance
-            ):
-                cache_key = f"{model_id}_client"
-                _client = openai.OpenAI(  # type: ignore
-                    api_key=api_key,
-                    base_url=api_base,
-                    timeout=timeout,  # type: ignore
-                    max_retries=max_retries,  # type: ignore
-                    organization=organization,
-                    http_client=httpx.Client(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),  # type: ignore
-                )
-                litellm_router_instance.cache.set_cache(
-                    key=cache_key,
-                    value=_client,
-                    ttl=client_ttl,
-                    local_only=True,
-                )  # cache for 1 hr
+        if client_id.startswith("os.environ/"):
+            _client_id = get_secret_str(client_id)
+        else:
+            _client_id = client_id
 
-            # streaming clients should have diff timeouts
-            cache_key = f"{model_id}_stream_async_client"
-            _client = openai.AsyncOpenAI(  # type: ignore
-                api_key=api_key,
-                base_url=api_base,
-                timeout=stream_timeout,  # type: ignore
-                max_retries=max_retries,  # type: ignore
-                organization=organization,
-                http_client=httpx.AsyncClient(
-                    limits=httpx.Limits(
-                        max_connections=1000, max_keepalive_connections=100
-                    ),
-                    verify=litellm.ssl_verify,
-                ),  # type: ignore
-            )
-            litellm_router_instance.cache.set_cache(
-                key=cache_key,
-                value=_client,
-                ttl=client_ttl,
-                local_only=True,
-            )  # cache for 1 hr
+        if client_secret.startswith("os.environ/"):
+            _client_secret = get_secret_str(client_secret)
+        else:
+            _client_secret = client_secret
 
-            if should_initialize_sync_client(
-                litellm_router_instance=litellm_router_instance
-            ):
-                # streaming clients should have diff timeouts
-                cache_key = f"{model_id}_stream_client"
-                _client = openai.OpenAI(  # type: ignore
-                    api_key=api_key,
-                    base_url=api_base,
-                    timeout=stream_timeout,  # type: ignore
-                    max_retries=max_retries,  # type: ignore
-                    organization=organization,
-                    http_client=httpx.Client(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),  # type: ignore
-                )
-                litellm_router_instance.cache.set_cache(
-                    key=cache_key,
-                    value=_client,
-                    ttl=client_ttl,
-                    local_only=True,
-                )  # cache for 1 hr
+        verbose_router_logger.debug(
+            "tenant_id %s, client_id %s, client_secret %s",
+            _tenant_id,
+            _client_id,
+            _client_secret,
+        )
+        if _tenant_id is None or _client_id is None or _client_secret is None:
+            raise ValueError("tenant_id, client_id, and client_secret must be provided")
+        credential = ClientSecretCredential(_tenant_id, _client_id, _client_secret)
 
+        verbose_router_logger.debug("credential %s", credential)
 
-def get_azure_ad_token_from_entrata_id(
-    tenant_id: str, client_id: str, client_secret: str
-) -> Callable[[], str]:
-    from azure.identity import (
-        ClientSecretCredential,
-        DefaultAzureCredential,
-        get_bearer_token_provider,
-    )
+        token_provider = get_bearer_token_provider(
+            credential, "https://cognitiveservices.azure.com/.default"
+        )
 
-    verbose_router_logger.debug("Getting Azure AD Token from Entrata ID")
+        verbose_router_logger.debug("token_provider %s", token_provider)
 
-    if tenant_id.startswith("os.environ/"):
-        _tenant_id = get_secret_str(tenant_id)
-    else:
-        _tenant_id = tenant_id
-
-    if client_id.startswith("os.environ/"):
-        _client_id = get_secret_str(client_id)
-    else:
-        _client_id = client_id
-
-    if client_secret.startswith("os.environ/"):
-        _client_secret = get_secret_str(client_secret)
-    else:
-        _client_secret = client_secret
-
-    verbose_router_logger.debug(
-        "tenant_id %s, client_id %s, client_secret %s",
-        _tenant_id,
-        _client_id,
-        _client_secret,
-    )
-    if _tenant_id is None or _client_id is None or _client_secret is None:
-        raise ValueError("tenant_id, client_id, and client_secret must be provided")
-    credential = ClientSecretCredential(_tenant_id, _client_id, _client_secret)
-
-    verbose_router_logger.debug("credential %s", credential)
-
-    token_provider = get_bearer_token_provider(
-        credential, "https://cognitiveservices.azure.com/.default"
-    )
-
-    verbose_router_logger.debug("token_provider %s", token_provider)
-
-    return token_provider
+        return token_provider
diff --git a/tests/local_testing/test_router_init.py b/tests/local_testing/test_router_init.py
index 3733af252..45695f3c3 100644
--- a/tests/local_testing/test_router_init.py
+++ b/tests/local_testing/test_router_init.py
@@ -17,6 +17,10 @@ from dotenv import load_dotenv
 
 import litellm
 from litellm import Router
+from litellm.router_utils.client_initalization_utils import (
+    InitalizeOpenAISDKClient,
+    OpenAISDKClientInitializationParams,
+)
 
 load_dotenv()
 
@@ -696,3 +700,283 @@ def test_init_router_with_supported_environments(environment, expected_models):
     assert set(_model_list) == set(expected_models)
 
     os.environ.pop("LITELLM_ENVIRONMENT")
+
+
+def test_should_initialize_sync_client():
+    from litellm.types.router import RouterGeneralSettings
+
+    # Test case 1: Router instance is None
+    assert InitalizeOpenAISDKClient.should_initialize_sync_client(None) is False
+
+    # Test case 2: Router instance without router_general_settings
+    router = Router(model_list=[])
+    assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is True
+
+    # Test case 3: Router instance with async_only_mode = False
+    router = Router(
+        model_list=[],
+        router_general_settings=RouterGeneralSettings(async_only_mode=False),
+    )
+    assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is True
+
+    # Test case 4: Router instance with async_only_mode = True
+    router = Router(
+        model_list=[],
+        router_general_settings=RouterGeneralSettings(async_only_mode=True),
+    )
+    assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is False
+
+    # Test case 5: Router instance with router_general_settings but without async_only_mode
+    router = Router(model_list=[], router_general_settings=RouterGeneralSettings())
+    assert InitalizeOpenAISDKClient.should_initialize_sync_client(router) is True
+
+    print("All test cases passed!")
+
+
+@pytest.mark.parametrize(
+    "model_name, custom_llm_provider, expected_result",
+    [
+        ("gpt-3.5-turbo", None, True),  # OpenAI chat completion model
+        ("text-embedding-ada-002", None, True),  # OpenAI embedding model
+        ("claude-2", None, False),  # Non-OpenAI model
+        ("gpt-3.5-turbo", "azure", True),  # Azure OpenAI
+        ("text-davinci-003", "azure_text", True),  # Azure OpenAI
+        ("gpt-3.5-turbo", "openai", True),  # OpenAI
+        ("custom-model", "custom_openai", True),  # Custom OpenAI compatible
+        ("text-davinci-003", "text-completion-openai", True),  # OpenAI text completion
+        (
+            "ft:gpt-3.5-turbo-0613:my-org:custom-model:7p4lURel",
+            None,
+            True,
+        ),  # Fine-tuned GPT model
+        ("mistral-7b", "huggingface", False),  # Non-OpenAI provider
+        ("custom-model", "anthropic", False),  # Non-OpenAI compatible provider
+    ],
+)
+def test_should_create_openai_sdk_client_for_model(
+    model_name, custom_llm_provider, expected_result
+):
+    result = InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model(
+        model_name, custom_llm_provider
+    )
+    assert (
+        result == expected_result
+    ), f"Failed for model: {model_name}, provider: {custom_llm_provider}"
+
+
+def test_should_create_openai_sdk_client_for_model_openai_compatible_providers():
+    # Test with a known OpenAI compatible provider
+    assert InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model(
+        "custom-model", "groq"
+    ), "Should return True for OpenAI compatible provider"
+
+    # Add a new compatible provider and test
+    litellm.openai_compatible_providers.append("new_provider")
+    assert InitalizeOpenAISDKClient._should_create_openai_sdk_client_for_model(
+        "custom-model", "new_provider"
+    ), "Should return True for newly added OpenAI compatible provider"
+
+    # Clean up
+    litellm.openai_compatible_providers.remove("new_provider")
+
+
+def test_get_client_initialization_params_openai():
+    """Test basic OpenAI configuration with direct parameter passing."""
+    model = {}
+    model_name = "gpt-3.5-turbo"
+    custom_llm_provider = None
+    litellm_params = {"api_key": "sk-openai-key", "timeout": 30, "max_retries": 3}
+    default_api_key = None
+    default_api_base = None
+
+    result = InitalizeOpenAISDKClient._get_client_initialization_params(
+        model=model,
+        model_name=model_name,
+        custom_llm_provider=custom_llm_provider,
+        litellm_params=litellm_params,
+        default_api_key=default_api_key,
+        default_api_base=default_api_base,
+    )
+
+    assert isinstance(result, OpenAISDKClientInitializationParams)
+    assert result.api_key == "sk-openai-key"
+    assert result.timeout == 30
+    assert result.max_retries == 3
+    assert result.model_name == "gpt-3.5-turbo"
+
+
+def test_get_client_initialization_params_azure():
+    """Test Azure OpenAI configuration with specific Azure parameters."""
+    model = {}
+    model_name = "azure/gpt-4"
+    custom_llm_provider = "azure"
+    litellm_params = {
+        "api_key": "azure-key",
+        "api_base": "https://example.azure.openai.com",
+        "api_version": "2023-05-15",
+    }
+    default_api_key = None
+    default_api_base = None
+
+    result = InitalizeOpenAISDKClient._get_client_initialization_params(
+        model=model,
+        model_name=model_name,
+        custom_llm_provider=custom_llm_provider,
+        litellm_params=litellm_params,
+        default_api_key=default_api_key,
+        default_api_base=default_api_base,
+    )
+
+    assert result.api_key == "azure-key"
+    assert result.api_base == "https://example.azure.openai.com"
+    assert result.api_version == "2023-05-15"
+    assert result.custom_llm_provider == "azure"
+
+
+def test_get_client_initialization_params_environment_variable_parsing():
+    """Test parsing of environment variables for configuration."""
+    os.environ["UNIQUE_OPENAI_API_KEY"] = "env-openai-key"
+    os.environ["UNIQUE_TIMEOUT"] = "45"
+
+    model = {}
+    model_name = "gpt-4"
+    custom_llm_provider = None
+    litellm_params = {
+        "api_key": "os.environ/UNIQUE_OPENAI_API_KEY",
+        "timeout": "os.environ/UNIQUE_TIMEOUT",
+        "organization": "os.environ/UNIQUE_ORG_ID",
+    }
+    default_api_key = None
+    default_api_base = None
+
+    result = InitalizeOpenAISDKClient._get_client_initialization_params(
+        model=model,
+        model_name=model_name,
+        custom_llm_provider=custom_llm_provider,
+        litellm_params=litellm_params,
+        default_api_key=default_api_key,
+        default_api_base=default_api_base,
+    )
+
+    assert result.api_key == "env-openai-key"
+    assert result.timeout == 45.0
+    assert result.organization is None  # Since ORG_ID is not set in the environment
+
+
+def test_get_client_initialization_params_azure_ai_studio_mistral():
+    """
+    Test configuration for Azure AI Studio Mistral model.
+
+    - /v1/ is added to the api_base if it is not present
+    - custom_llm_provider is set to openai (Azure AI Studio Mistral models need to use OpenAI route)
+    """
+
+    model = {}
+    model_name = "azure/mistral-large-latest"
+    custom_llm_provider = "azure"
+    litellm_params = {
+        "api_key": "azure-key",
+        "api_base": "https://example.azure.openai.com",
+    }
+    default_api_key = None
+    default_api_base = None
+
+    result = InitalizeOpenAISDKClient._get_client_initialization_params(
+        model,
+        model_name,
+        custom_llm_provider,
+        litellm_params,
+        default_api_key,
+        default_api_base,
+    )
+
+    assert result.custom_llm_provider == "openai"
+    assert result.model_name == "mistral-large-latest"
+    assert result.api_base == "https://example.azure.openai.com/v1/"
+
+
+def test_get_client_initialization_params_default_values():
+    """
+    Test use of default values when specific parameters are not provided.
+
+    This is used typically for OpenAI compatible providers - example Together AI
+
+    """
+    model = {}
+    model_name = "together/meta-llama-3.1-8b-instruct"
+    custom_llm_provider = None
+    litellm_params = {}
+    default_api_key = "together-api-key"
+    default_api_base = "https://together.xyz/api.openai.com"
+
+    result = InitalizeOpenAISDKClient._get_client_initialization_params(
+        model=model,
+        model_name=model_name,
+        custom_llm_provider=custom_llm_provider,
+        litellm_params=litellm_params,
+        default_api_key=default_api_key,
+        default_api_base=default_api_base,
+    )
+
+    assert result.api_key == "together-api-key"
+    assert result.api_base == "https://together.xyz/api.openai.com"
+    assert result.timeout == litellm.request_timeout
+    assert result.max_retries == 0
+
+
+def test_get_client_initialization_params_all_env_vars():
+    # Set up environment variables
+    os.environ["TEST_API_KEY"] = "test-api-key"
+    os.environ["TEST_API_BASE"] = "https://test.openai.com"
+    os.environ["TEST_API_VERSION"] = "2023-05-15"
+    os.environ["TEST_TIMEOUT"] = "30"
+    os.environ["TEST_STREAM_TIMEOUT"] = "60"
+    os.environ["TEST_MAX_RETRIES"] = "3"
+    os.environ["TEST_ORGANIZATION"] = "test-org"
+
+    model = {}
+    model_name = "gpt-4"
+    custom_llm_provider = None
+    litellm_params = {
+        "api_key": "os.environ/TEST_API_KEY",
+        "api_base": "os.environ/TEST_API_BASE",
+        "api_version": "os.environ/TEST_API_VERSION",
+        "timeout": "os.environ/TEST_TIMEOUT",
+        "stream_timeout": "os.environ/TEST_STREAM_TIMEOUT",
+        "max_retries": "os.environ/TEST_MAX_RETRIES",
+        "organization": "os.environ/TEST_ORGANIZATION",
+    }
+    default_api_key = None
+    default_api_base = None
+
+    result = InitalizeOpenAISDKClient._get_client_initialization_params(
+        model=model,
+        model_name=model_name,
+        custom_llm_provider=custom_llm_provider,
+        litellm_params=litellm_params,
+        default_api_key=default_api_key,
+        default_api_base=default_api_base,
+    )
+
+    assert isinstance(result, OpenAISDKClientInitializationParams)
+    assert result.api_key == "test-api-key"
+    assert result.api_base == "https://test.openai.com"
+    assert result.api_version == "2023-05-15"
+    assert result.timeout == 30.0
+    assert result.stream_timeout == 60.0
+    assert result.max_retries == 3
+    assert result.organization == "test-org"
+    assert result.model_name == "gpt-4"
+    assert result.custom_llm_provider is None
+
+    # Clean up environment variables
+    for key in [
+        "TEST_API_KEY",
+        "TEST_API_BASE",
+        "TEST_API_VERSION",
+        "TEST_TIMEOUT",
+        "TEST_STREAM_TIMEOUT",
+        "TEST_MAX_RETRIES",
+        "TEST_ORGANIZATION",
+    ]:
+        os.environ.pop(key)
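
Reviewer notes — usage sketches (not part of the patch).

A minimal sketch of how the refactored entry point is exercised. The deployment dict shape (`litellm_params` plus `model_info["id"]`) mirrors what `set_client` reads above; the key and id values are placeholders. Router normally invokes this internally when a deployment is added (see the router.py hunks), so the direct call is purely illustrative:

```python
from litellm import Router
from litellm.router_utils.client_initalization_utils import InitalizeOpenAISDKClient

router = Router(model_list=[])

# Shape mirrors what set_client() reads: litellm_params + model_info["id"].
deployment = {
    "model_name": "gpt-3.5-turbo",
    "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"},
    "model_info": {"id": "deployment-1"},
}

# Previously the module-level set_client(); now a static method on the class.
InitalizeOpenAISDKClient.set_client(litellm_router_instance=router, model=deployment)

# Clients are cached per deployment id, e.g. the async client lives under
# "{model_id}_async_client" with a TTL of router.client_ttl.
async_client = router.cache.get_cache(key="deployment-1_async_client", local_only=True)
```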
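On the `os.environ/` convention that `_get_client_initialization_params` relies on throughout: the patch drops the manual `.replace("os.environ/", "")` calls because `get_secret_str` / `get_secret` accept the full reference. A standalone illustration of the convention, using a hypothetical `resolve_env_reference` helper (not litellm's implementation — the real helper also consults configured secret managers):

```python
import os
from typing import Optional


def resolve_env_reference(value: Optional[str]) -> Optional[str]:
    # Hypothetical stand-in for litellm's get_secret_str, shown only to
    # illustrate the "os.environ/<VAR>" config convention.
    if value is not None and value.startswith("os.environ/"):
        return os.environ.get(value.replace("os.environ/", ""))
    return value


os.environ["MY_API_KEY"] = "example-key"
assert resolve_env_reference("os.environ/MY_API_KEY") == "example-key"
assert resolve_env_reference("plain-string") == "plain-string"
```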
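What `get_azure_ad_token_from_entrata_id` produces: it wires a `ClientSecretCredential` into `azure.identity.get_bearer_token_provider` with the Cognitive Services scope, exactly as in the method body above. A sketch assuming the `azure-identity` package is installed; tenant/client values are placeholders:

```python
from azure.identity import ClientSecretCredential, get_bearer_token_provider

# Placeholder credentials; in the router these come from litellm_params
# (tenant_id / client_id / client_secret), each optionally an "os.environ/..." reference.
credential = ClientSecretCredential(
    tenant_id="00000000-0000-0000-0000-000000000000",
    client_id="placeholder-client-id",
    client_secret="placeholder-client-secret",
)
token_provider = get_bearer_token_provider(
    credential, "https://cognitiveservices.azure.com/.default"
)
# token_provider is a zero-argument callable returning a bearer token; the
# OpenAI SDK invokes it per request when passed as azure_ad_token_provider.
```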
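Finally, on the rpm/tpm semaphore that `set_client` caches under `"{model_id}_max_parallel_requests_client"`: a sketch of how such a cached `asyncio.Semaphore` bounds in-flight requests per deployment. The limit value here is arbitrary; the actual precedence between `max_parallel_requests`, `rpm`, `tpm`, and the router default lives in `calculate_max_parallel_requests`:

```python
import asyncio

max_parallel_requests = 4  # e.g. what calculate_max_parallel_requests resolved to
semaphore = asyncio.Semaphore(max_parallel_requests)


async def call_deployment(i: int) -> int:
    async with semaphore:  # at most 4 requests in flight for this deployment
        await asyncio.sleep(0.1)  # stand-in for the actual LLM call
        return i


async def main() -> None:
    results = await asyncio.gather(*(call_deployment(i) for i in range(10)))
    print(results)


asyncio.run(main())
```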