fix(router.py): calculate max_parallel_requests from given tpm limits

Use the Azure formula to calculate rpm -> max_parallel_requests based on a deployment's tpm limits.
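
For reference, a minimal sketch of what the new calculate_max_parallel_requests helper is expected to do. The signature below mirrors the call site in the diff; the precedence order and the 6-RPM-per-1,000-TPM ratio are assumptions based on Azure OpenAI's documented quota mapping, not a copy of the litellm implementation:

from typing import Optional

def calculate_max_parallel_requests(
    max_parallel_requests: Optional[int],
    rpm: Optional[int],
    tpm: Optional[int],
    default_max_parallel_requests: Optional[int],
) -> Optional[int]:
    # Sketch only. Assumed precedence: explicit max_parallel_requests,
    # then rpm, then a value derived from tpm, then the router-level default.
    if max_parallel_requests is not None:
        return max_parallel_requests
    if rpm is not None:
        return rpm
    if tpm is not None:
        # Assumed Azure-style mapping: roughly 6 RPM granted per 1,000 TPM.
        return max(int(tpm / 1000 * 6), 1)
    return default_max_parallel_requests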
Krrish Dholakia 2024-04-20 10:43:18 -07:00
parent 0ce2fb83b0
commit 4c78f8f309
2 changed files with 76 additions and 13 deletions

litellm/router.py

@@ -26,7 +26,12 @@ from litellm.llms.custom_httpx.azure_dall_e_2 import (
CustomHTTPTransport,
AsyncCustomHTTPTransport,
)
from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
from litellm.utils import (
ModelResponse,
CustomStreamWrapper,
get_utc_datetime,
calculate_max_parallel_requests,
)
import copy
from litellm._logging import verbose_router_logger
import logging
@@ -61,6 +66,7 @@ class Router:
num_retries: int = 0,
timeout: Optional[float] = None,
default_litellm_params={}, # default params for Router.chat.completion.create
default_max_parallel_requests: Optional[int] = None,
set_verbose: bool = False,
debug_level: Literal["DEBUG", "INFO"] = "INFO",
fallbacks: List = [],
@@ -213,6 +219,7 @@
) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
self.num_retries = num_retries or litellm.num_retries or 0
self.timeout = timeout or litellm.request_timeout
self.default_max_parallel_requests = default_max_parallel_requests
self.retry_after = retry_after
self.routing_strategy = routing_strategy
self.fallbacks = fallbacks or litellm.fallbacks
@@ -496,7 +503,9 @@
)
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
@@ -681,7 +690,9 @@
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
@@ -803,7 +814,9 @@
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
@@ -1049,7 +1062,9 @@
)
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
@@ -1243,7 +1258,9 @@
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
@@ -1862,17 +1879,23 @@
model_id = model["model_info"]["id"]
# ### IF RPM SET - initialize a semaphore ###
rpm = litellm_params.get("rpm", None)
if rpm:
semaphore = asyncio.Semaphore(rpm)
cache_key = f"{model_id}_rpm_client"
tpm = litellm_params.get("tpm", None)
max_parallel_requests = litellm_params.get("max_parallel_requests", None)
calculated_max_parallel_requests = calculate_max_parallel_requests(
rpm=rpm,
max_parallel_requests=max_parallel_requests,
tpm=tpm,
default_max_parallel_requests=self.default_max_parallel_requests,
)
if calculated_max_parallel_requests:
semaphore = asyncio.Semaphore(calculated_max_parallel_requests)
cache_key = f"{model_id}_max_parallel_requests_client"
self.cache.set_cache(
key=cache_key,
value=semaphore,
local_only=True,
)
# print("STORES SEMAPHORE IN CACHE")
#### for OpenAI / Azure we need to initalize the Client for High Traffic ########
custom_llm_provider = litellm_params.get("custom_llm_provider")
custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
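
The semaphore cached above is what the _get_client(..., client_type="max_parallel_requests") calls earlier in the diff retrieve. Roughly, such a per-deployment semaphore would be used to cap in-flight requests as in this illustrative sketch (not the router's exact code path):

import asyncio

async def call_with_cap(semaphore: asyncio.Semaphore, make_request):
    # Waits here once `calculated_max_parallel_requests` calls are already in flight,
    # so a deployment never exceeds its derived concurrency limit.
    async with semaphore:
        return await make_request()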
@@ -2537,8 +2560,8 @@
The appropriate client based on the given client_type and kwargs.
"""
model_id = deployment["model_info"]["id"]
if client_type == "rpm_client":
cache_key = "{}_rpm_client".format(model_id)
if client_type == "max_parallel_requests":
cache_key = "{}_max_parallel_requests".format(model_id)
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
elif client_type == "async":
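
Taken together, a hypothetical Router configuration that exercises the new path: the deployment declares only tpm, so the concurrency cap has to be derived from it, with default_max_parallel_requests as the router-wide fallback (model names and keys below are placeholders):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",  # placeholder alias
            "litellm_params": {
                "model": "gpt-3.5-turbo",   # placeholder model
                "api_key": "sk-...",        # placeholder key
                "tpm": 240_000,             # only tpm is set; rpm / max_parallel_requests omitted
            },
        }
    ],
    default_max_parallel_requests=50,       # new kwarg: fallback cap when nothing can be derived
)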