From 4c78f8f309e618b6875d34f54a79f85ce3f75834 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 20 Apr 2024 10:43:18 -0700
Subject: [PATCH] fix(router.py): calculate max_parallel_requests from given
 tpm limits

use the azure formula to calculate rpm -> max_parallel_requests based on a
deployment's tpm limits
---
 litellm/router.py | 49 ++++++++++++++++++++++++++++++++++-------------
 litellm/utils.py  | 40 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 76 insertions(+), 13 deletions(-)

diff --git a/litellm/router.py b/litellm/router.py
index 8145ef619..d7988aaba 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -26,7 +26,12 @@ from litellm.llms.custom_httpx.azure_dall_e_2 import (
     CustomHTTPTransport,
     AsyncCustomHTTPTransport,
 )
-from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
+from litellm.utils import (
+    ModelResponse,
+    CustomStreamWrapper,
+    get_utc_datetime,
+    calculate_max_parallel_requests,
+)
 import copy
 from litellm._logging import verbose_router_logger
 import logging
@@ -61,6 +66,7 @@ class Router:
         num_retries: int = 0,
         timeout: Optional[float] = None,
         default_litellm_params={},  # default params for Router.chat.completion.create
+        default_max_parallel_requests: Optional[int] = None,
         set_verbose: bool = False,
         debug_level: Literal["DEBUG", "INFO"] = "INFO",
         fallbacks: List = [],
@@ -213,6 +219,7 @@ class Router:
         )  # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
         self.num_retries = num_retries or litellm.num_retries or 0
         self.timeout = timeout or litellm.request_timeout
+        self.default_max_parallel_requests = default_max_parallel_requests
         self.retry_after = retry_after
         self.routing_strategy = routing_strategy
         self.fallbacks = fallbacks or litellm.fallbacks
@@ -496,7 +503,9 @@ class Router:
             )
 
             rpm_semaphore = self._get_client(
-                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+                deployment=deployment,
+                kwargs=kwargs,
+                client_type="max_parallel_requests",
             )
 
             if rpm_semaphore is not None and isinstance(
@@ -681,7 +690,9 @@ class Router:
 
             ### CONCURRENCY-SAFE RPM CHECKS ###
             rpm_semaphore = self._get_client(
-                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+                deployment=deployment,
+                kwargs=kwargs,
+                client_type="max_parallel_requests",
             )
 
             if rpm_semaphore is not None and isinstance(
@@ -803,7 +814,9 @@ class Router:
 
             ### CONCURRENCY-SAFE RPM CHECKS ###
             rpm_semaphore = self._get_client(
-                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+                deployment=deployment,
+                kwargs=kwargs,
+                client_type="max_parallel_requests",
             )
 
             if rpm_semaphore is not None and isinstance(
@@ -1049,7 +1062,9 @@ class Router:
             )
 
             rpm_semaphore = self._get_client(
-                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+                deployment=deployment,
+                kwargs=kwargs,
+                client_type="max_parallel_requests",
             )
 
             if rpm_semaphore is not None and isinstance(
@@ -1243,7 +1258,9 @@ class Router:
 
             ### CONCURRENCY-SAFE RPM CHECKS ###
             rpm_semaphore = self._get_client(
-                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+                deployment=deployment,
+                kwargs=kwargs,
+                client_type="max_parallel_requests",
             )
 
             if rpm_semaphore is not None and isinstance(
@@ -1862,17 +1879,23 @@ class Router:
         model_id = model["model_info"]["id"]
         # ### IF RPM SET - initialize a semaphore ###
         rpm = litellm_params.get("rpm", None)
-        if rpm:
-            semaphore = asyncio.Semaphore(rpm)
-            cache_key = f"{model_id}_rpm_client"
+        tpm = litellm_params.get("tpm", None)
+        max_parallel_requests = litellm_params.get("max_parallel_requests", None)
+        calculated_max_parallel_requests = calculate_max_parallel_requests(
+            rpm=rpm,
+            max_parallel_requests=max_parallel_requests,
+            tpm=tpm,
+            default_max_parallel_requests=self.default_max_parallel_requests,
+        )
+        if calculated_max_parallel_requests:
+            semaphore = asyncio.Semaphore(calculated_max_parallel_requests)
+            cache_key = f"{model_id}_max_parallel_requests_client"
             self.cache.set_cache(
                 key=cache_key,
                 value=semaphore,
                 local_only=True,
             )
-            # print("STORES SEMAPHORE IN CACHE")
-
 
         #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
         custom_llm_provider = litellm_params.get("custom_llm_provider")
         custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
@@ -2537,8 +2560,8 @@ class Router:
         The appropriate client based on the given client_type and kwargs.
         """
         model_id = deployment["model_info"]["id"]
-        if client_type == "rpm_client":
-            cache_key = "{}_rpm_client".format(model_id)
+        if client_type == "max_parallel_requests":
+            cache_key = "{}_max_parallel_requests".format(model_id)
             client = self.cache.get_cache(key=cache_key, local_only=True)
             return client
         elif client_type == "async":
diff --git a/litellm/utils.py b/litellm/utils.py
index e230675e6..566ef2099 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -5395,6 +5395,46 @@ def get_optional_params(
     return optional_params
 
 
+def calculate_max_parallel_requests(
+    max_parallel_requests: Optional[int],
+    rpm: Optional[int],
+    tpm: Optional[int],
+    default_max_parallel_requests: Optional[int],
+) -> Optional[int]:
+    """
+    Returns the max parallel requests to send to a deployment.
+
+    Used in semaphore for async requests on router.
+
+    Parameters:
+    - max_parallel_requests - Optional[int] - max_parallel_requests allowed for that deployment
+    - rpm - Optional[int] - requests per minute allowed for that deployment
+    - tpm - Optional[int] - tokens per minute allowed for that deployment
+    - default_max_parallel_requests - Optional[int] - default_max_parallel_requests allowed for any deployment
+
+    Returns:
+    - int or None (if all params are None)
+
+    Order:
+    max_parallel_requests > rpm > tpm / 6 (azure formula) > default max_parallel_requests
+
+    Azure RPM formula:
+    6 rpm per 1000 TPM
+    https://learn.microsoft.com/en-us/azure/ai-services/openai/quotas-limits
+
+
+    """
+    if max_parallel_requests is not None:
+        return max_parallel_requests
+    elif rpm is not None:
+        return rpm
+    elif tpm is not None:
+        return int(tpm / 1000 / 6)
+    elif default_max_parallel_requests is not None:
+        return default_max_parallel_requests
+    return None
+
+
 def get_api_base(model: str, optional_params: dict) -> Optional[str]:
     """
     Returns the api base used for calling the model.
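
Note (illustrative, not part of the patch): the sketch below mirrors the precedence that calculate_max_parallel_requests implements in the diff above. The function name resolve_max_parallel_requests and the numeric limits are made up for the example; the tpm branch uses the same int(tpm / 1000 / 6) expression the patch uses for the Azure 6-RPM-per-1000-TPM quota.

    from typing import Optional

    def resolve_max_parallel_requests(
        max_parallel_requests: Optional[int],
        rpm: Optional[int],
        tpm: Optional[int],
        default_max_parallel_requests: Optional[int],
    ) -> Optional[int]:
        # Precedence mirrored from the patch:
        # explicit max_parallel_requests > rpm > tpm-derived value > router-wide default.
        if max_parallel_requests is not None:
            return max_parallel_requests
        if rpm is not None:
            return rpm
        if tpm is not None:
            return int(tpm / 1000 / 6)  # same expression as the patch's tpm fallback
        if default_max_parallel_requests is not None:
            return default_max_parallel_requests
        return None

    assert resolve_max_parallel_requests(10, 100, 60_000, 500) == 10    # explicit limit wins
    assert resolve_max_parallel_requests(None, 100, 60_000, 500) == 100  # then rpm
    assert resolve_max_parallel_requests(None, None, 60_000, 500) == 10  # int(60000 / 1000 / 6)
    assert resolve_max_parallel_requests(None, None, None, 500) == 500   # router default last
    assert resolve_max_parallel_requests(None, None, None, None) is None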
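
A second sketch, also not part of the patch, showing what the router-side change buys: the computed value sizes a per-deployment asyncio.Semaphore that caps in-flight requests. The fake_call coroutine and the tpm value here are stand-ins; in the patch the semaphore is stored in the router cache under the "{model_id}_max_parallel_requests_client" key and acquired around each deployment call.

    import asyncio

    async def fake_call(i: int) -> int:
        await asyncio.sleep(0.01)  # stand-in for an actual LLM request
        return i

    async def main() -> None:
        tpm = 60_000
        limit = int(tpm / 1000 / 6)           # 10, via the patch's tpm fallback
        semaphore = asyncio.Semaphore(limit)  # one semaphore per deployment

        async def guarded(i: int) -> int:
            async with semaphore:             # at most `limit` requests in flight
                return await fake_call(i)

        results = await asyncio.gather(*(guarded(i) for i in range(50)))
        print(f"{len(results)} requests completed, <= {limit} concurrently")

    asyncio.run(main())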