forked from phoenix/litellm-mirror
fix(router.py): calculate max_parallel_requests from given tpm limits

Use the Azure quota formula (roughly 6 RPM per 1,000 TPM) to derive a deployment's max_parallel_requests cap from its tpm limit when no explicit rpm or max_parallel_requests is set.
This commit is contained in:
parent 0ce2fb83b0
commit 4c78f8f309

2 changed files with 76 additions and 13 deletions
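As a rough worked example of the ratio this commit applies (the TPM value below is illustrative, not taken from the diff):

    # Azure quota guidance: roughly 6 RPM per 1,000 TPM.
    tpm = 60_000                        # a deployment's tokens-per-minute quota (illustrative)
    max_parallel = int(tpm / 1000 / 6)  # -> 10, the concurrency cap derived by this commit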
litellm/router.py

@@ -26,7 +26,12 @@ from litellm.llms.custom_httpx.azure_dall_e_2 import (
     CustomHTTPTransport,
     AsyncCustomHTTPTransport,
 )
-from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
+from litellm.utils import (
+    ModelResponse,
+    CustomStreamWrapper,
+    get_utc_datetime,
+    calculate_max_parallel_requests,
+)
 import copy
 from litellm._logging import verbose_router_logger
 import logging
@@ -61,6 +66,7 @@ class Router:
         num_retries: int = 0,
         timeout: Optional[float] = None,
         default_litellm_params={},  # default params for Router.chat.completion.create
+        default_max_parallel_requests: Optional[int] = None,
         set_verbose: bool = False,
         debug_level: Literal["DEBUG", "INFO"] = "INFO",
         fallbacks: List = [],
@@ -213,6 +219,7 @@ class Router:
         )  # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
         self.num_retries = num_retries or litellm.num_retries or 0
         self.timeout = timeout or litellm.request_timeout
+        self.default_max_parallel_requests = default_max_parallel_requests
         self.retry_after = retry_after
         self.routing_strategy = routing_strategy
         self.fallbacks = fallbacks or litellm.fallbacks
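For context, a hedged sketch of how these limits might be supplied to the router; the model name, key, and numbers below are illustrative, not taken from the diff:

    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-4",  # illustrative alias
                "litellm_params": {
                    "model": "azure/gpt-4-deployment",  # illustrative deployment
                    "api_key": "sk-...",                # placeholder
                    "tpm": 240_000,  # tokens-per-minute quota for this deployment
                    # "rpm": 1_440,                 # or a requests-per-minute cap
                    # "max_parallel_requests": 10,  # or an explicit concurrency cap
                },
            }
        ],
        default_max_parallel_requests=50,  # new: router-wide fallback cap
    )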
@@ -496,7 +503,9 @@ class Router:
             )
 
             rpm_semaphore = self._get_client(
-                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+                deployment=deployment,
+                kwargs=kwargs,
+                client_type="max_parallel_requests",
             )
 
             if rpm_semaphore is not None and isinstance(
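A minimal sketch of what the acquired semaphore is for, with the surrounding request logic reduced to a stand-in coroutine (not the router's actual code path):

    import asyncio

    async def _gated_call(rpm_semaphore, make_request):
        # When a max_parallel_requests semaphore exists for the deployment,
        # acquire it so at most N requests run against that deployment at once.
        if rpm_semaphore is not None and isinstance(rpm_semaphore, asyncio.Semaphore):
            async with rpm_semaphore:
                return await make_request()
        # No limit configured for this deployment: call through directly.
        return await make_request()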
@@ -681,7 +690,9 @@ class Router:
 
             ### CONCURRENCY-SAFE RPM CHECKS ###
             rpm_semaphore = self._get_client(
-                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+                deployment=deployment,
+                kwargs=kwargs,
+                client_type="max_parallel_requests",
             )
 
             if rpm_semaphore is not None and isinstance(
@@ -803,7 +814,9 @@ class Router:
 
             ### CONCURRENCY-SAFE RPM CHECKS ###
             rpm_semaphore = self._get_client(
-                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+                deployment=deployment,
+                kwargs=kwargs,
+                client_type="max_parallel_requests",
             )
 
             if rpm_semaphore is not None and isinstance(
@@ -1049,7 +1062,9 @@ class Router:
             )
 
             rpm_semaphore = self._get_client(
-                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+                deployment=deployment,
+                kwargs=kwargs,
+                client_type="max_parallel_requests",
             )
 
             if rpm_semaphore is not None and isinstance(
@@ -1243,7 +1258,9 @@ class Router:
 
             ### CONCURRENCY-SAFE RPM CHECKS ###
             rpm_semaphore = self._get_client(
-                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+                deployment=deployment,
+                kwargs=kwargs,
+                client_type="max_parallel_requests",
             )
 
             if rpm_semaphore is not None and isinstance(
@@ -1862,17 +1879,23 @@ class Router:
         model_id = model["model_info"]["id"]
         # ### IF RPM SET - initialize a semaphore ###
         rpm = litellm_params.get("rpm", None)
-        if rpm:
-            semaphore = asyncio.Semaphore(rpm)
-            cache_key = f"{model_id}_rpm_client"
+        tpm = litellm_params.get("tpm", None)
+        max_parallel_requests = litellm_params.get("max_parallel_requests", None)
+        calculated_max_parallel_requests = calculate_max_parallel_requests(
+            rpm=rpm,
+            max_parallel_requests=max_parallel_requests,
+            tpm=tpm,
+            default_max_parallel_requests=self.default_max_parallel_requests,
+        )
+        if calculated_max_parallel_requests:
+            semaphore = asyncio.Semaphore(calculated_max_parallel_requests)
+            cache_key = f"{model_id}_max_parallel_requests_client"
             self.cache.set_cache(
                 key=cache_key,
                 value=semaphore,
                 local_only=True,
             )
 
-            # print("STORES SEMAPHORE IN CACHE")
-
         #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
         custom_llm_provider = litellm_params.get("custom_llm_provider")
         custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
@@ -2537,8 +2560,8 @@ class Router:
         The appropriate client based on the given client_type and kwargs.
         """
         model_id = deployment["model_info"]["id"]
-        if client_type == "rpm_client":
-            cache_key = "{}_rpm_client".format(model_id)
+        if client_type == "max_parallel_requests":
+            cache_key = "{}_max_parallel_requests".format(model_id)
             client = self.cache.get_cache(key=cache_key, local_only=True)
             return client
         elif client_type == "async":
litellm/utils.py

@@ -5395,6 +5395,46 @@ def get_optional_params(
     return optional_params
 
 
+def calculate_max_parallel_requests(
+    max_parallel_requests: Optional[int],
+    rpm: Optional[int],
+    tpm: Optional[int],
+    default_max_parallel_requests: Optional[int],
+) -> Optional[int]:
+    """
+    Returns the max parallel requests to send to a deployment.
+
+    Used in semaphore for async requests on router.
+
+    Parameters:
+    - max_parallel_requests - Optional[int] - max_parallel_requests allowed for that deployment
+    - rpm - Optional[int] - requests per minute allowed for that deployment
+    - tpm - Optional[int] - tokens per minute allowed for that deployment
+    - default_max_parallel_requests - Optional[int] - default_max_parallel_requests allowed for any deployment
+
+    Returns:
+    - int or None (if all params are None)
+
+    Order:
+    max_parallel_requests > rpm > tpm / 6 (azure formula) > default max_parallel_requests
+
+    Azure RPM formula:
+    6 rpm per 1000 TPM
+    https://learn.microsoft.com/en-us/azure/ai-services/openai/quotas-limits
+
+
+    """
+    if max_parallel_requests is not None:
+        return max_parallel_requests
+    elif rpm is not None:
+        return rpm
+    elif tpm is not None:
+        return int(tpm / 1000 / 6)
+    elif default_max_parallel_requests is not None:
+        return default_max_parallel_requests
+    return None
+
+
 def get_api_base(model: str, optional_params: dict) -> Optional[str]:
     """
     Returns the api base used for calling the model.
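Assuming the helper lands in litellm.utils as shown above, a few checks of the fallback order with illustrative numbers:

    from litellm.utils import calculate_max_parallel_requests

    # An explicit max_parallel_requests wins over everything else.
    assert calculate_max_parallel_requests(10, rpm=100, tpm=60_000, default_max_parallel_requests=50) == 10
    # Otherwise rpm is used directly as the cap.
    assert calculate_max_parallel_requests(None, rpm=100, tpm=60_000, default_max_parallel_requests=50) == 100
    # With only tpm set, the Azure-derived value is int(tpm / 1000 / 6): 60_000 -> 10.
    assert calculate_max_parallel_requests(None, rpm=None, tpm=60_000, default_max_parallel_requests=50) == 10
    # Nothing set on the deployment: fall back to the router-wide default.
    assert calculate_max_parallel_requests(None, rpm=None, tpm=None, default_max_parallel_requests=50) == 50
    # All None -> None, and no semaphore is created.
    assert calculate_max_parallel_requests(None, rpm=None, tpm=None, default_max_parallel_requests=None) is None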