diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index d633075b7..521ded627 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -207,6 +207,7 @@ from litellm.router import ModelInfo as RouterModelInfo from litellm.router import updateDeployment from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler from litellm.types.llms.openai import HttpxBinaryResponseContent +from litellm.types.router import RouterGeneralSettings try: from litellm._version import version @@ -1765,7 +1766,11 @@ class ProxyConfig: if k in available_args: router_params[k] = v router = litellm.Router( - **router_params, assistants_config=assistants_config + **router_params, + assistants_config=assistants_config, + router_general_settings=RouterGeneralSettings( + async_only_mode=True # only init async clients + ), ) # type:ignore return router, router.get_model_list(), general_settings @@ -1957,7 +1962,12 @@ class ProxyConfig: ) if len(_model_list) > 0: verbose_proxy_logger.debug(f"_model_list: {_model_list}") - llm_router = litellm.Router(model_list=_model_list) + llm_router = litellm.Router( + model_list=_model_list, + router_general_settings=RouterGeneralSettings( + async_only_mode=True # only init async clients + ), + ) verbose_proxy_logger.debug(f"updated llm_router: {llm_router}") else: verbose_proxy_logger.debug(f"len new_models: {len(new_models)}") diff --git a/litellm/router.py b/litellm/router.py index b54d70dbb..db68197a4 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -55,6 +55,10 @@ from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2 +from litellm.router_utils.client_initalization_utils import ( + set_client, + should_initialize_sync_client, +) from litellm.router_utils.handle_error import send_llm_exception_alert from litellm.scheduler import FlowItem, Scheduler from litellm.types.llms.openai import ( @@ -79,6 +83,7 @@ from litellm.types.router import ( ModelInfo, RetryPolicy, RouterErrors, + RouterGeneralSettings, updateDeployment, updateLiteLLMParams, ) @@ -169,6 +174,7 @@ class Router: routing_strategy_args: dict = {}, # just for latency-based routing semaphore: Optional[asyncio.Semaphore] = None, alerting_config: Optional[AlertingConfig] = None, + router_general_settings: Optional[RouterGeneralSettings] = None, ) -> None: """ Initialize the Router class with the given parameters for caching, reliability, and routing strategy. @@ -246,6 +252,9 @@ class Router: verbose_router_logger.setLevel(logging.INFO) elif debug_level == "DEBUG": verbose_router_logger.setLevel(logging.DEBUG) + self.router_general_settings: Optional[RouterGeneralSettings] = ( + router_general_settings + ) self.assistants_config = assistants_config self.deployment_names: List = ( @@ -3247,520 +3256,6 @@ class Router: except Exception as e: raise e - def set_client(self, model: dict): - """ - - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278 - - Initializes Semaphore for client w/ rpm. Stores them in cache. 
b/c of this - https://github.com/BerriAI/litellm/issues/2994 - """ - client_ttl = self.client_ttl - litellm_params = model.get("litellm_params", {}) - model_name = litellm_params.get("model") - model_id = model["model_info"]["id"] - # ### IF RPM SET - initialize a semaphore ### - rpm = litellm_params.get("rpm", None) - tpm = litellm_params.get("tpm", None) - max_parallel_requests = litellm_params.get("max_parallel_requests", None) - calculated_max_parallel_requests = calculate_max_parallel_requests( - rpm=rpm, - max_parallel_requests=max_parallel_requests, - tpm=tpm, - default_max_parallel_requests=self.default_max_parallel_requests, - ) - if calculated_max_parallel_requests: - semaphore = asyncio.Semaphore(calculated_max_parallel_requests) - cache_key = f"{model_id}_max_parallel_requests_client" - self.cache.set_cache( - key=cache_key, - value=semaphore, - local_only=True, - ) - - #### for OpenAI / Azure we need to initalize the Client for High Traffic ######## - custom_llm_provider = litellm_params.get("custom_llm_provider") - custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or "" - default_api_base = None - default_api_key = None - if custom_llm_provider in litellm.openai_compatible_providers: - _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider( - model=model_name - ) - default_api_base = api_base - default_api_key = api_key - - if ( - model_name in litellm.open_ai_chat_completion_models - or custom_llm_provider in litellm.openai_compatible_providers - or custom_llm_provider == "azure" - or custom_llm_provider == "azure_text" - or custom_llm_provider == "custom_openai" - or custom_llm_provider == "openai" - or custom_llm_provider == "text-completion-openai" - or "ft:gpt-3.5-turbo" in model_name - or model_name in litellm.open_ai_embedding_models - ): - is_azure_ai_studio_model: bool = False - if custom_llm_provider == "azure": - if litellm.utils._is_non_openai_azure_model(model_name): - is_azure_ai_studio_model = True - custom_llm_provider = "openai" - # remove azure prefx from model_name - model_name = model_name.replace("azure/", "") - # glorified / complicated reading of configs - # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env - # we do this here because we init clients for Azure, OpenAI and we need to set the right key - api_key = litellm_params.get("api_key") or default_api_key - if ( - api_key - and isinstance(api_key, str) - and api_key.startswith("os.environ/") - ): - api_key_env_name = api_key.replace("os.environ/", "") - api_key = litellm.get_secret(api_key_env_name) - litellm_params["api_key"] = api_key - - api_base = litellm_params.get("api_base") - base_url = litellm_params.get("base_url") - api_base = ( - api_base or base_url or default_api_base - ) # allow users to pass in `api_base` or `base_url` for azure - if api_base and api_base.startswith("os.environ/"): - api_base_env_name = api_base.replace("os.environ/", "") - api_base = litellm.get_secret(api_base_env_name) - litellm_params["api_base"] = api_base - - ## AZURE AI STUDIO MISTRAL CHECK ## - """ - Make sure api base ends in /v1/ - - if not, add it - https://github.com/BerriAI/litellm/issues/2279 - """ - if ( - is_azure_ai_studio_model is True - and api_base is not None - and isinstance(api_base, str) - and not api_base.endswith("/v1/") - ): - # check if it ends with a trailing slash - if api_base.endswith("/"): - api_base += "v1/" - elif api_base.endswith("/v1"): - api_base += "/" - else: - api_base += "/v1/" - - 
api_version = litellm_params.get("api_version") - if api_version and api_version.startswith("os.environ/"): - api_version_env_name = api_version.replace("os.environ/", "") - api_version = litellm.get_secret(api_version_env_name) - litellm_params["api_version"] = api_version - - timeout = litellm_params.pop("timeout", None) or litellm.request_timeout - if isinstance(timeout, str) and timeout.startswith("os.environ/"): - timeout_env_name = timeout.replace("os.environ/", "") - timeout = litellm.get_secret(timeout_env_name) - litellm_params["timeout"] = timeout - - stream_timeout = litellm_params.pop( - "stream_timeout", timeout - ) # if no stream_timeout is set, default to timeout - if isinstance(stream_timeout, str) and stream_timeout.startswith( - "os.environ/" - ): - stream_timeout_env_name = stream_timeout.replace("os.environ/", "") - stream_timeout = litellm.get_secret(stream_timeout_env_name) - litellm_params["stream_timeout"] = stream_timeout - - max_retries = litellm_params.pop( - "max_retries", 0 - ) # router handles retry logic - if isinstance(max_retries, str) and max_retries.startswith("os.environ/"): - max_retries_env_name = max_retries.replace("os.environ/", "") - max_retries = litellm.get_secret(max_retries_env_name) - litellm_params["max_retries"] = max_retries - - # proxy support - import os - - import httpx - - # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly. - http_proxy = os.getenv("HTTP_PROXY", None) - https_proxy = os.getenv("HTTPS_PROXY", None) - no_proxy = os.getenv("NO_PROXY", None) - - # Create the proxies dictionary only if the environment variables are set. - sync_proxy_mounts = None - async_proxy_mounts = None - if http_proxy is not None and https_proxy is not None: - sync_proxy_mounts = { - "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)), - "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)), - } - async_proxy_mounts = { - "http://": httpx.AsyncHTTPTransport( - proxy=httpx.Proxy(url=http_proxy) - ), - "https://": httpx.AsyncHTTPTransport( - proxy=httpx.Proxy(url=https_proxy) - ), - } - - # assume no_proxy is a list of comma separated urls - if no_proxy is not None and isinstance(no_proxy, str): - no_proxy_urls = no_proxy.split(",") - - for url in no_proxy_urls: # set no-proxy support for specific urls - sync_proxy_mounts[url] = None # type: ignore - async_proxy_mounts[url] = None # type: ignore - - organization = litellm_params.get("organization", None) - if isinstance(organization, str) and organization.startswith("os.environ/"): - organization_env_name = organization.replace("os.environ/", "") - organization = litellm.get_secret(organization_env_name) - litellm_params["organization"] = organization - - if custom_llm_provider == "azure" or custom_llm_provider == "azure_text": - if api_base is None or not isinstance(api_base, str): - filtered_litellm_params = { - k: v - for k, v in model["litellm_params"].items() - if k != "api_key" - } - _filtered_model = { - "model_name": model["model_name"], - "litellm_params": filtered_litellm_params, - } - raise ValueError( - f"api_base is required for Azure OpenAI. Set it on your config. 
Model - {_filtered_model}" - ) - azure_ad_token = litellm_params.get("azure_ad_token") - if azure_ad_token is not None: - if azure_ad_token.startswith("oidc/"): - azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) - if api_version is None: - api_version = litellm.AZURE_DEFAULT_API_VERSION - - if "gateway.ai.cloudflare.com" in api_base: - if not api_base.endswith("/"): - api_base += "/" - azure_model = model_name.replace("azure/", "") - api_base += f"{azure_model}" - cache_key = f"{model_id}_async_client" - _client = openai.AsyncAzureOpenAI( - api_key=api_key, - azure_ad_token=azure_ad_token, - base_url=api_base, - api_version=api_version, - timeout=timeout, - max_retries=max_retries, - http_client=httpx.AsyncClient( - transport=AsyncCustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=async_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - cache_key = f"{model_id}_client" - _client = openai.AzureOpenAI( # type: ignore - api_key=api_key, - azure_ad_token=azure_ad_token, - base_url=api_base, - api_version=api_version, - timeout=timeout, - max_retries=max_retries, - http_client=httpx.Client( - transport=CustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=sync_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - # streaming clients can have diff timeouts - cache_key = f"{model_id}_stream_async_client" - _client = openai.AsyncAzureOpenAI( # type: ignore - api_key=api_key, - azure_ad_token=azure_ad_token, - base_url=api_base, - api_version=api_version, - timeout=stream_timeout, - max_retries=max_retries, - http_client=httpx.AsyncClient( - transport=AsyncCustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=async_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - cache_key = f"{model_id}_stream_client" - _client = openai.AzureOpenAI( # type: ignore - api_key=api_key, - azure_ad_token=azure_ad_token, - base_url=api_base, - api_version=api_version, - timeout=stream_timeout, - max_retries=max_retries, - http_client=httpx.Client( - transport=CustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=sync_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - else: - _api_key = api_key - if _api_key is not None and isinstance(_api_key, str): - # only show first 5 chars of api_key - _api_key = _api_key[:8] + "*" * 15 - verbose_router_logger.debug( - f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}" - ) - azure_client_params = { - "api_key": api_key, - "azure_endpoint": api_base, - "api_version": api_version, - "azure_ad_token": azure_ad_token, - } - from litellm.llms.azure import select_azure_base_url_or_endpoint - - # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client - # required to support GPT-4 vision enhancements, since base_url needs to 
be set on Azure OpenAI Client - azure_client_params = select_azure_base_url_or_endpoint( - azure_client_params - ) - - cache_key = f"{model_id}_async_client" - _client = openai.AsyncAzureOpenAI( # type: ignore - **azure_client_params, - timeout=timeout, - max_retries=max_retries, - http_client=httpx.AsyncClient( - transport=AsyncCustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=async_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - cache_key = f"{model_id}_client" - _client = openai.AzureOpenAI( # type: ignore - **azure_client_params, - timeout=timeout, - max_retries=max_retries, - http_client=httpx.Client( - transport=CustomHTTPTransport( - verify=litellm.ssl_verify, - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - ), - mounts=sync_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - # streaming clients should have diff timeouts - cache_key = f"{model_id}_stream_async_client" - _client = openai.AsyncAzureOpenAI( # type: ignore - **azure_client_params, - timeout=stream_timeout, - max_retries=max_retries, - http_client=httpx.AsyncClient( - transport=AsyncCustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=async_proxy_mounts, - ), - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - cache_key = f"{model_id}_stream_client" - _client = openai.AzureOpenAI( # type: ignore - **azure_client_params, - timeout=stream_timeout, - max_retries=max_retries, - http_client=httpx.Client( - transport=CustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=sync_proxy_mounts, - ), - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - else: - _api_key = api_key # type: ignore - if _api_key is not None and isinstance(_api_key, str): - # only show first 5 chars of api_key - _api_key = _api_key[:8] + "*" * 15 - verbose_router_logger.debug( - f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{_api_key}" - ) - cache_key = f"{model_id}_async_client" - _client = openai.AsyncOpenAI( # type: ignore - api_key=api_key, - base_url=api_base, - timeout=timeout, - max_retries=max_retries, - organization=organization, - http_client=httpx.AsyncClient( - transport=AsyncCustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=async_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - cache_key = f"{model_id}_client" - _client = openai.OpenAI( # type: ignore - api_key=api_key, - base_url=api_base, - timeout=timeout, - max_retries=max_retries, - organization=organization, - http_client=httpx.Client( - transport=CustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=sync_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, 
- ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - # streaming clients should have diff timeouts - cache_key = f"{model_id}_stream_async_client" - _client = openai.AsyncOpenAI( # type: ignore - api_key=api_key, - base_url=api_base, - timeout=stream_timeout, - max_retries=max_retries, - organization=organization, - http_client=httpx.AsyncClient( - transport=AsyncCustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=async_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - # streaming clients should have diff timeouts - cache_key = f"{model_id}_stream_client" - _client = openai.OpenAI( # type: ignore - api_key=api_key, - base_url=api_base, - timeout=stream_timeout, - max_retries=max_retries, - organization=organization, - http_client=httpx.Client( - transport=CustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=sync_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - def _generate_model_id(self, model_group: str, litellm_params: dict): """ Helper function to consistently generate the same id for a deployment @@ -3904,7 +3399,9 @@ class Router: raise Exception(f"Unsupported provider - {custom_llm_provider}") # init OpenAI, Azure clients - self.set_client(model=deployment.to_json(exclude_none=True)) + set_client( + litellm_router_instance=self, model=deployment.to_json(exclude_none=True) + ) # set region (if azure model) ## PREVIEW FEATURE ## if litellm.enable_preview_features == True: @@ -4432,7 +3929,7 @@ class Router: """ Re-initialize the client """ - self.set_client(model=deployment) + set_client(litellm_router_instance=self, model=deployment) client = self.cache.get_cache(key=cache_key, local_only=True) return client else: @@ -4442,7 +3939,7 @@ class Router: """ Re-initialize the client """ - self.set_client(model=deployment) + set_client(litellm_router_instance=self, model=deployment) client = self.cache.get_cache(key=cache_key, local_only=True) return client else: @@ -4453,7 +3950,7 @@ class Router: """ Re-initialize the client """ - self.set_client(model=deployment) + set_client(litellm_router_instance=self, model=deployment) client = self.cache.get_cache(key=cache_key) return client else: @@ -4463,7 +3960,7 @@ class Router: """ Re-initialize the client """ - self.set_client(model=deployment) + set_client(litellm_router_instance=self, model=deployment) client = self.cache.get_cache(key=cache_key) return client diff --git a/litellm/router_utils/client_initalization_utils.py b/litellm/router_utils/client_initalization_utils.py new file mode 100644 index 000000000..0160ffda1 --- /dev/null +++ b/litellm/router_utils/client_initalization_utils.py @@ -0,0 +1,566 @@ +import asyncio +import traceback +from typing import TYPE_CHECKING, Any + +import openai + +import litellm +from litellm._logging import verbose_router_logger +from litellm.llms.azure import get_azure_ad_token_from_oidc +from litellm.llms.custom_httpx.azure_dall_e_2 import ( + AsyncCustomHTTPTransport, + CustomHTTPTransport, +) +from litellm.utils import calculate_max_parallel_requests + +if TYPE_CHECKING: + from litellm.router import Router as _Router + + LitellmRouter = _Router +else: + LitellmRouter = Any + + +def 
should_initialize_sync_client( + litellm_router_instance: LitellmRouter, +) -> bool: + """ + Returns if Sync OpenAI, Azure Clients should be initialized. + + Do not init sync clients when router.router_general_settings.async_only_mode is True + + """ + if litellm_router_instance is None: + return False + + if litellm_router_instance.router_general_settings is not None: + if ( + hasattr(litellm_router_instance, "router_general_settings") + and hasattr( + litellm_router_instance.router_general_settings, "async_only_mode" + ) + and litellm_router_instance.router_general_settings.async_only_mode is True + ): + return False + + return True + + +def set_client(litellm_router_instance: LitellmRouter, model: dict): + """ + - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278 + - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994 + """ + client_ttl = litellm_router_instance.client_ttl + litellm_params = model.get("litellm_params", {}) + model_name = litellm_params.get("model") + model_id = model["model_info"]["id"] + # ### IF RPM SET - initialize a semaphore ### + rpm = litellm_params.get("rpm", None) + tpm = litellm_params.get("tpm", None) + max_parallel_requests = litellm_params.get("max_parallel_requests", None) + calculated_max_parallel_requests = calculate_max_parallel_requests( + rpm=rpm, + max_parallel_requests=max_parallel_requests, + tpm=tpm, + default_max_parallel_requests=litellm_router_instance.default_max_parallel_requests, + ) + if calculated_max_parallel_requests: + semaphore = asyncio.Semaphore(calculated_max_parallel_requests) + cache_key = f"{model_id}_max_parallel_requests_client" + litellm_router_instance.cache.set_cache( + key=cache_key, + value=semaphore, + local_only=True, + ) + + #### for OpenAI / Azure we need to initalize the Client for High Traffic ######## + custom_llm_provider = litellm_params.get("custom_llm_provider") + custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or "" + default_api_base = None + default_api_key = None + if custom_llm_provider in litellm.openai_compatible_providers: + _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider( + model=model_name + ) + default_api_base = api_base + default_api_key = api_key + + if ( + model_name in litellm.open_ai_chat_completion_models + or custom_llm_provider in litellm.openai_compatible_providers + or custom_llm_provider == "azure" + or custom_llm_provider == "azure_text" + or custom_llm_provider == "custom_openai" + or custom_llm_provider == "openai" + or custom_llm_provider == "text-completion-openai" + or "ft:gpt-3.5-turbo" in model_name + or model_name in litellm.open_ai_embedding_models + ): + is_azure_ai_studio_model: bool = False + if custom_llm_provider == "azure": + if litellm.utils._is_non_openai_azure_model(model_name): + is_azure_ai_studio_model = True + custom_llm_provider = "openai" + # remove azure prefx from model_name + model_name = model_name.replace("azure/", "") + # glorified / complicated reading of configs + # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env + # we do this here because we init clients for Azure, OpenAI and we need to set the right key + api_key = litellm_params.get("api_key") or default_api_key + if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"): + api_key_env_name = api_key.replace("os.environ/", "") + api_key = 
litellm.get_secret(api_key_env_name) + litellm_params["api_key"] = api_key + + api_base = litellm_params.get("api_base") + base_url = litellm_params.get("base_url") + api_base = ( + api_base or base_url or default_api_base + ) # allow users to pass in `api_base` or `base_url` for azure + if api_base and api_base.startswith("os.environ/"): + api_base_env_name = api_base.replace("os.environ/", "") + api_base = litellm.get_secret(api_base_env_name) + litellm_params["api_base"] = api_base + + ## AZURE AI STUDIO MISTRAL CHECK ## + """ + Make sure api base ends in /v1/ + + if not, add it - https://github.com/BerriAI/litellm/issues/2279 + """ + if ( + is_azure_ai_studio_model is True + and api_base is not None + and isinstance(api_base, str) + and not api_base.endswith("/v1/") + ): + # check if it ends with a trailing slash + if api_base.endswith("/"): + api_base += "v1/" + elif api_base.endswith("/v1"): + api_base += "/" + else: + api_base += "/v1/" + + api_version = litellm_params.get("api_version") + if api_version and api_version.startswith("os.environ/"): + api_version_env_name = api_version.replace("os.environ/", "") + api_version = litellm.get_secret(api_version_env_name) + litellm_params["api_version"] = api_version + + timeout = litellm_params.pop("timeout", None) or litellm.request_timeout + if isinstance(timeout, str) and timeout.startswith("os.environ/"): + timeout_env_name = timeout.replace("os.environ/", "") + timeout = litellm.get_secret(timeout_env_name) + litellm_params["timeout"] = timeout + + stream_timeout = litellm_params.pop( + "stream_timeout", timeout + ) # if no stream_timeout is set, default to timeout + if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"): + stream_timeout_env_name = stream_timeout.replace("os.environ/", "") + stream_timeout = litellm.get_secret(stream_timeout_env_name) + litellm_params["stream_timeout"] = stream_timeout + + max_retries = litellm_params.pop("max_retries", 0) # router handles retry logic + if isinstance(max_retries, str) and max_retries.startswith("os.environ/"): + max_retries_env_name = max_retries.replace("os.environ/", "") + max_retries = litellm.get_secret(max_retries_env_name) + litellm_params["max_retries"] = max_retries + + # proxy support + import os + + import httpx + + # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly. + http_proxy = os.getenv("HTTP_PROXY", None) + https_proxy = os.getenv("HTTPS_PROXY", None) + no_proxy = os.getenv("NO_PROXY", None) + + # Create the proxies dictionary only if the environment variables are set. 
+ sync_proxy_mounts = None + async_proxy_mounts = None + if http_proxy is not None and https_proxy is not None: + sync_proxy_mounts = { + "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)), + "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)), + } + async_proxy_mounts = { + "http://": httpx.AsyncHTTPTransport(proxy=httpx.Proxy(url=http_proxy)), + "https://": httpx.AsyncHTTPTransport( + proxy=httpx.Proxy(url=https_proxy) + ), + } + + # assume no_proxy is a list of comma separated urls + if no_proxy is not None and isinstance(no_proxy, str): + no_proxy_urls = no_proxy.split(",") + + for url in no_proxy_urls: # set no-proxy support for specific urls + sync_proxy_mounts[url] = None # type: ignore + async_proxy_mounts[url] = None # type: ignore + + organization = litellm_params.get("organization", None) + if isinstance(organization, str) and organization.startswith("os.environ/"): + organization_env_name = organization.replace("os.environ/", "") + organization = litellm.get_secret(organization_env_name) + litellm_params["organization"] = organization + + if custom_llm_provider == "azure" or custom_llm_provider == "azure_text": + if api_base is None or not isinstance(api_base, str): + filtered_litellm_params = { + k: v for k, v in model["litellm_params"].items() if k != "api_key" + } + _filtered_model = { + "model_name": model["model_name"], + "litellm_params": filtered_litellm_params, + } + raise ValueError( + f"api_base is required for Azure OpenAI. Set it on your config. Model - {_filtered_model}" + ) + azure_ad_token = litellm_params.get("azure_ad_token") + if azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) + if api_version is None: + api_version = litellm.AZURE_DEFAULT_API_VERSION + + if "gateway.ai.cloudflare.com" in api_base: + if not api_base.endswith("/"): + api_base += "/" + azure_model = model_name.replace("azure/", "") + api_base += f"{azure_model}" + cache_key = f"{model_id}_async_client" + _client = openai.AsyncAzureOpenAI( + api_key=api_key, + azure_ad_token=azure_ad_token, + base_url=api_base, + api_version=api_version, + timeout=timeout, + max_retries=max_retries, + http_client=httpx.AsyncClient( + transport=AsyncCustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=async_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + cache_key = f"{model_id}_client" + _client = openai.AzureOpenAI( # type: ignore + api_key=api_key, + azure_ad_token=azure_ad_token, + base_url=api_base, + api_version=api_version, + timeout=timeout, + max_retries=max_retries, + http_client=httpx.Client( + transport=CustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=sync_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + # streaming clients can have diff timeouts + cache_key = f"{model_id}_stream_async_client" + _client = openai.AsyncAzureOpenAI( # type: ignore + api_key=api_key, + azure_ad_token=azure_ad_token, + base_url=api_base, + api_version=api_version, + 
timeout=stream_timeout, + max_retries=max_retries, + http_client=httpx.AsyncClient( + transport=AsyncCustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=async_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + cache_key = f"{model_id}_stream_client" + _client = openai.AzureOpenAI( # type: ignore + api_key=api_key, + azure_ad_token=azure_ad_token, + base_url=api_base, + api_version=api_version, + timeout=stream_timeout, + max_retries=max_retries, + http_client=httpx.Client( + transport=CustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=sync_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + else: + _api_key = api_key + if _api_key is not None and isinstance(_api_key, str): + # only show first 5 chars of api_key + _api_key = _api_key[:8] + "*" * 15 + verbose_router_logger.debug( + f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}" + ) + azure_client_params = { + "api_key": api_key, + "azure_endpoint": api_base, + "api_version": api_version, + "azure_ad_token": azure_ad_token, + } + from litellm.llms.azure import select_azure_base_url_or_endpoint + + # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client + # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client + azure_client_params = select_azure_base_url_or_endpoint( + azure_client_params + ) + + cache_key = f"{model_id}_async_client" + _client = openai.AsyncAzureOpenAI( # type: ignore + **azure_client_params, + timeout=timeout, + max_retries=max_retries, + http_client=httpx.AsyncClient( + transport=AsyncCustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=async_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + cache_key = f"{model_id}_client" + _client = openai.AzureOpenAI( # type: ignore + **azure_client_params, + timeout=timeout, + max_retries=max_retries, + http_client=httpx.Client( + transport=CustomHTTPTransport( + verify=litellm.ssl_verify, + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + ), + mounts=sync_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + # streaming clients should have diff timeouts + cache_key = f"{model_id}_stream_async_client" + _client = openai.AsyncAzureOpenAI( # type: ignore + **azure_client_params, + timeout=stream_timeout, + max_retries=max_retries, + http_client=httpx.AsyncClient( + transport=AsyncCustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=async_proxy_mounts, + ), + ) + 
litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + cache_key = f"{model_id}_stream_client" + _client = openai.AzureOpenAI( # type: ignore + **azure_client_params, + timeout=stream_timeout, + max_retries=max_retries, + http_client=httpx.Client( + transport=CustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=sync_proxy_mounts, + ), + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + else: + _api_key = api_key # type: ignore + if _api_key is not None and isinstance(_api_key, str): + # only show first 5 chars of api_key + _api_key = _api_key[:8] + "*" * 15 + verbose_router_logger.debug( + f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{_api_key}" + ) + cache_key = f"{model_id}_async_client" + _client = openai.AsyncOpenAI( # type: ignore + api_key=api_key, + base_url=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + http_client=httpx.AsyncClient( + transport=AsyncCustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=async_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + cache_key = f"{model_id}_client" + _client = openai.OpenAI( # type: ignore + api_key=api_key, + base_url=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + http_client=httpx.Client( + transport=CustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=sync_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + # streaming clients should have diff timeouts + cache_key = f"{model_id}_stream_async_client" + _client = openai.AsyncOpenAI( # type: ignore + api_key=api_key, + base_url=api_base, + timeout=stream_timeout, + max_retries=max_retries, + organization=organization, + http_client=httpx.AsyncClient( + transport=AsyncCustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=async_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + # streaming clients should have diff timeouts + cache_key = f"{model_id}_stream_client" + _client = openai.OpenAI( # type: ignore + api_key=api_key, + base_url=api_base, + timeout=stream_timeout, + max_retries=max_retries, + organization=organization, + http_client=httpx.Client( + transport=CustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=sync_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + 
key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr diff --git a/litellm/tests/test_router_init.py b/litellm/tests/test_router_init.py index f0f0cc541..13167c10f 100644 --- a/litellm/tests/test_router_init.py +++ b/litellm/tests/test_router_init.py @@ -1,16 +1,22 @@ # this tests if the router is initialized correctly -import sys, os, time -import traceback, asyncio +import asyncio +import os +import sys +import time +import traceback + import pytest sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor + +from dotenv import load_dotenv + import litellm from litellm import Router -from concurrent.futures import ThreadPoolExecutor -from collections import defaultdict -from dotenv import load_dotenv load_dotenv() @@ -24,6 +30,7 @@ load_dotenv() def test_init_clients(): litellm.set_verbose = True import logging + from litellm._logging import verbose_router_logger verbose_router_logger.setLevel(logging.DEBUG) @@ -489,6 +496,7 @@ def test_init_clients_azure_command_r_plus(): # For azure/command-r-plus we need to use openai.OpenAI because of how the Azure provider requires requests being sent litellm.set_verbose = True import logging + from litellm._logging import verbose_router_logger verbose_router_logger.setLevel(logging.DEBUG) @@ -585,3 +593,46 @@ async def test_text_completion_with_organization(): except Exception as e: pytest.fail(f"Error occurred: {e}") + + +def test_init_clients_async_mode(): + litellm.set_verbose = True + import logging + + from litellm._logging import verbose_router_logger + from litellm.types.router import RouterGeneralSettings + + verbose_router_logger.setLevel(logging.DEBUG) + try: + print("testing init 4 clients with diff timeouts") + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + "timeout": 0.01, + "stream_timeout": 0.000_001, + "max_retries": 7, + }, + }, + ] + router = Router( + model_list=model_list, + set_verbose=True, + router_general_settings=RouterGeneralSettings(async_only_mode=True), + ) + for elem in router.model_list: + model_id = elem["model_info"]["id"] + + # sync clients not initialized in async_only_mode=True + assert router.cache.get_cache(f"{model_id}_client") is None + assert router.cache.get_cache(f"{model_id}_stream_client") is None + + # only async clients initialized in async_only_mode=True + assert router.cache.get_cache(f"{model_id}_async_client") is not None + assert router.cache.get_cache(f"{model_id}_stream_async_client") is not None + except Exception as e: + pytest.fail(f"Error occurred: {e}") diff --git a/litellm/types/router.py b/litellm/types/router.py index 78d516d6c..46fc0c9e7 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -517,3 +517,9 @@ class CustomRoutingStrategyBase: """ pass + + +class RouterGeneralSettings(BaseModel): + async_only_mode: bool = Field( + default=False + ) # this will only initialize async clients. Good for memory utils
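
For reference, a minimal sketch of how the new `router_general_settings` flag is exercised, mirroring the `test_init_clients_async_mode` test added in this patch. The model entry below is illustrative (it reuses the Azure deployment from the test); the only new surface area is passing `RouterGeneralSettings(async_only_mode=True)` to `Router`, which makes `set_client` skip the sync OpenAI/Azure clients and initialize only the async ones.

    import os

    from litellm import Router
    from litellm.types.router import RouterGeneralSettings

    # async_only_mode=True skips creation of the sync clients, so only the
    # {model_id}_async_client / {model_id}_stream_async_client cache entries exist.
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "azure/chatgpt-v-2",  # example deployment, as in the test
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                },
            }
        ],
        router_general_settings=RouterGeneralSettings(async_only_mode=True),
    )

    model_id = router.model_list[0]["model_info"]["id"]
    # sync clients are not initialized in async_only_mode
    assert router.cache.get_cache(f"{model_id}_client") is None
    assert router.cache.get_cache(f"{model_id}_stream_client") is None
    # async clients are still initialized
    assert router.cache.get_cache(f"{model_id}_async_client") is not None
    assert router.cache.get_cache(f"{model_id}_stream_async_client") is not None

This is also what the proxy now does by default: both `litellm.Router(...)` call sites in `proxy_server.py` pass `async_only_mode=True`, since the proxy only serves async traffic and skipping the sync `openai.OpenAI` / `openai.AzureOpenAI` clients reduces memory usage per deployment.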