Merge pull request #3192 from BerriAI/litellm_calculate_max_parallel_requests

fix(router.py): Make TPM limits concurrency-safe
Krish Dholakia 2024-04-20 13:24:29 -07:00 committed by GitHub
commit fcde3ba213
7 changed files with 208 additions and 19 deletions
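In short: instead of sizing each deployment's request semaphore from `rpm` alone, the router now derives a concurrency cap from `max_parallel_requests`, `rpm`, or `tpm` (in that order), with a new router-wide `default_max_parallel_requests` fallback. A minimal usage sketch, adapted from the unit tests added in this PR (the model name and id are placeholders, and an `OPENAI_API_KEY` is assumed to be available in the environment):

```python
import asyncio
from typing import Optional

import litellm

deployment = {
    "model_name": "gpt-3.5-turbo",
    "litellm_params": {"model": "gpt-3.5-turbo", "tpm": 300000},  # only tpm is set
    "model_info": {"id": "my-unique-id"},
}

router = litellm.Router(
    model_list=[deployment],
    default_max_parallel_requests=None,  # new Router argument in this PR
)

# The router caches one asyncio.Semaphore per deployment under
# "<model_id>_max_parallel_requests_client"; _get_client returns it.
mpr_semaphore: Optional[asyncio.Semaphore] = router._get_client(
    deployment=deployment,
    kwargs={},
    client_type="max_parallel_requests",
)
print(mpr_semaphore._value if mpr_semaphore else None)  # 50, i.e. int(300000 / 1000 / 6)
```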


@@ -609,6 +609,7 @@ def completion(
"client",
"rpm",
"tpm",
"max_parallel_requests",
"input_cost_per_token",
"output_cost_per_token",
"input_cost_per_second",
@@ -2560,6 +2561,7 @@ def embedding(
client = kwargs.pop("client", None)
rpm = kwargs.pop("rpm", None)
tpm = kwargs.pop("tpm", None)
max_parallel_requests = kwargs.pop("max_parallel_requests", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", None)
encoding_format = kwargs.get("encoding_format", None)
@@ -2617,6 +2619,7 @@ def embedding(
"client",
"rpm",
"tpm",
"max_parallel_requests",
"input_cost_per_token",
"output_cost_per_token",
"input_cost_per_second",
@@ -3476,6 +3479,7 @@ def image_generation(
"client",
"rpm",
"tpm",
"max_parallel_requests",
"input_cost_per_token",
"output_cost_per_token",
"hf_model_name",


@@ -26,7 +26,12 @@ from litellm.llms.custom_httpx.azure_dall_e_2 import (
CustomHTTPTransport,
AsyncCustomHTTPTransport,
)
from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
from litellm.utils import (
ModelResponse,
CustomStreamWrapper,
get_utc_datetime,
calculate_max_parallel_requests,
)
import copy
from litellm._logging import verbose_router_logger
import logging
@@ -61,6 +66,7 @@ class Router:
num_retries: int = 0,
timeout: Optional[float] = None,
default_litellm_params={}, # default params for Router.chat.completion.create
default_max_parallel_requests: Optional[int] = None,
set_verbose: bool = False,
debug_level: Literal["DEBUG", "INFO"] = "INFO",
fallbacks: List = [],
@@ -198,6 +204,7 @@ class Router:
) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
self.default_deployment = None # use this to track the users default deployment, when they want to use model = *
self.default_max_parallel_requests = default_max_parallel_requests
if model_list:
model_list = copy.deepcopy(model_list)
@@ -213,6 +220,7 @@ class Router:
) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
self.num_retries = num_retries or litellm.num_retries or 0
self.timeout = timeout or litellm.request_timeout
self.retry_after = retry_after
self.routing_strategy = routing_strategy
self.fallbacks = fallbacks or litellm.fallbacks
@@ -298,7 +306,7 @@ class Router:
else:
litellm.failure_callback = [self.deployment_callback_on_failure]
verbose_router_logger.info(
f"Intialized router with Routing strategy: {self.routing_strategy}\n\nRouting fallbacks: {self.fallbacks}\n\nRouting context window fallbacks: {self.context_window_fallbacks}"
f"Intialized router with Routing strategy: {self.routing_strategy}\n\nRouting fallbacks: {self.fallbacks}\n\nRouting context window fallbacks: {self.context_window_fallbacks}\n\nRouter Redis Caching={self.cache.redis_cache}"
)
self.routing_strategy_args = routing_strategy_args
@@ -496,7 +504,9 @@
)
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
@@ -681,7 +691,9 @@
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
@@ -803,7 +815,9 @@
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
@@ -1049,7 +1063,9 @@
)
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
@@ -1243,7 +1259,9 @@
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
@@ -1862,17 +1880,23 @@ class Router:
model_id = model["model_info"]["id"]
# ### IF RPM SET - initialize a semaphore ###
rpm = litellm_params.get("rpm", None)
if rpm:
semaphore = asyncio.Semaphore(rpm)
cache_key = f"{model_id}_rpm_client"
tpm = litellm_params.get("tpm", None)
max_parallel_requests = litellm_params.get("max_parallel_requests", None)
calculated_max_parallel_requests = calculate_max_parallel_requests(
rpm=rpm,
max_parallel_requests=max_parallel_requests,
tpm=tpm,
default_max_parallel_requests=self.default_max_parallel_requests,
)
if calculated_max_parallel_requests:
semaphore = asyncio.Semaphore(calculated_max_parallel_requests)
cache_key = f"{model_id}_max_parallel_requests_client"
self.cache.set_cache(
key=cache_key,
value=semaphore,
local_only=True,
)
# print("STORES SEMAPHORE IN CACHE")
#### for OpenAI / Azure we need to initalize the Client for High Traffic ########
custom_llm_provider = litellm_params.get("custom_llm_provider")
custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
@@ -2537,8 +2561,8 @@ class Router:
The appropriate client based on the given client_type and kwargs.
"""
model_id = deployment["model_info"]["id"]
if client_type == "rpm_client":
cache_key = "{}_rpm_client".format(model_id)
if client_type == "max_parallel_requests":
cache_key = "{}_max_parallel_requests_client".format(model_id)
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
elif client_type == "async":
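Within router.py, every call path that previously fetched an "rpm_client" semaphore now asks for a "max_parallel_requests" client, and the semaphore's size comes from `calculate_max_parallel_requests(...)` rather than from `rpm` alone. A condensed sketch of how that cached semaphore gates a request (`call_with_concurrency_cap` and `make_call` are illustrative stand-ins, not real Router methods):

```python
import asyncio
from typing import Any, Awaitable, Callable


async def call_with_concurrency_cap(
    router: Any,
    deployment: dict,
    kwargs: dict,
    make_call: Callable[[], Awaitable[Any]],
) -> Any:
    """Condensed sketch of the gate the router applies around each request.

    `make_call` stands in for the actual provider call; this helper is
    illustrative, not a real Router method.
    """
    semaphore = router._get_client(
        deployment=deployment,
        kwargs=kwargs,
        client_type="max_parallel_requests",
    )
    if semaphore is not None and isinstance(semaphore, asyncio.Semaphore):
        # Waits here once the deployment's computed cap is reached.
        async with semaphore:
            return await make_call()
    return await make_call()
```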


@@ -81,7 +81,7 @@ def test_async_fallbacks(caplog):
# Define the expected log messages
# - error request, falling back notice, success notice
expected_logs = [
"Intialized router with Routing strategy: simple-shuffle\n\nRouting fallbacks: [{'gpt-3.5-turbo': ['azure/gpt-3.5-turbo']}]\n\nRouting context window fallbacks: None",
"Intialized router with Routing strategy: simple-shuffle\n\nRouting fallbacks: [{'gpt-3.5-turbo': ['azure/gpt-3.5-turbo']}]\n\nRouting context window fallbacks: None\n\nRouter Redis Caching=None",
"litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}\x1b[0m",
"Falling back to model_group = azure/gpt-3.5-turbo",
"litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m",


@@ -0,0 +1,115 @@
# What is this?
## Unit tests for the max_parallel_requests feature on Router
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath("../.."))
import litellm
from litellm.utils import calculate_max_parallel_requests
from typing import Optional
"""
- only rpm
- only tpm
- only max_parallel_requests
- max_parallel_requests + rpm
- max_parallel_requests + tpm
- max_parallel_requests + tpm + rpm
"""
max_parallel_requests_values = [None, 10]
tpm_values = [None, 20, 300000]
rpm_values = [None, 30]
default_max_parallel_requests = [None, 40]
@pytest.mark.parametrize(
"max_parallel_requests, tpm, rpm, default_max_parallel_requests",
[
(mp, tp, rp, dmp)
for mp in max_parallel_requests_values
for tp in tpm_values
for rp in rpm_values
for dmp in default_max_parallel_requests
],
)
def test_scenario(max_parallel_requests, tpm, rpm, default_max_parallel_requests):
calculated_max_parallel_requests = calculate_max_parallel_requests(
max_parallel_requests=max_parallel_requests,
rpm=rpm,
tpm=tpm,
default_max_parallel_requests=default_max_parallel_requests,
)
if max_parallel_requests is not None:
assert max_parallel_requests == calculated_max_parallel_requests
elif rpm is not None:
assert rpm == calculated_max_parallel_requests
elif tpm is not None:
calculated_rpm = int(tpm / 1000 / 6)
if calculated_rpm == 0:
calculated_rpm = 1
print(
f"test calculated_rpm: {calculated_rpm}, calculated_max_parallel_requests={calculated_max_parallel_requests}"
)
assert calculated_rpm == calculated_max_parallel_requests
elif default_max_parallel_requests is not None:
assert calculated_max_parallel_requests == default_max_parallel_requests
else:
assert calculated_max_parallel_requests is None
@pytest.mark.parametrize(
"max_parallel_requests, tpm, rpm, default_max_parallel_requests",
[
(mp, tp, rp, dmp)
for mp in max_parallel_requests_values
for tp in tpm_values
for rp in rpm_values
for dmp in default_max_parallel_requests
],
)
def test_setting_mpr_limits_per_model(
max_parallel_requests, tpm, rpm, default_max_parallel_requests
):
deployment = {
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"max_parallel_requests": max_parallel_requests,
"tpm": tpm,
"rpm": rpm,
},
"model_info": {"id": "my-unique-id"},
}
router = litellm.Router(
model_list=[deployment],
default_max_parallel_requests=default_max_parallel_requests,
)
mpr_client: Optional[asyncio.Semaphore] = router._get_client(
deployment=deployment,
kwargs={},
client_type="max_parallel_requests",
)
if max_parallel_requests is not None:
assert max_parallel_requests == mpr_client._value
elif rpm is not None:
assert rpm == mpr_client._value
elif tpm is not None:
calculated_rpm = int(tpm / 1000 / 6)
if calculated_rpm == 0:
calculated_rpm = 1
print(
f"test calculated_rpm: {calculated_rpm}, calculated_max_parallel_requests={mpr_client._value}"
)
assert calculated_rpm == mpr_client._value
elif default_max_parallel_requests is not None:
assert mpr_client._value == default_max_parallel_requests
else:
assert mpr_client is None
# raise Exception("it worked!")
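The assertions above read `asyncio.Semaphore._value`, the semaphore's internal count of currently available permits, to check which cap the router computed. It is a private attribute, but the new tests rely on it the same way; a quick illustration:

```python
import asyncio

sem = asyncio.Semaphore(10)
print(sem._value)  # 10 -> number of acquisitions still available


async def main() -> None:
    async with sem:
        print(sem._value)  # 9 while one permit is held


asyncio.run(main())
```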


@@ -5395,6 +5395,49 @@ def get_optional_params(
return optional_params
def calculate_max_parallel_requests(
max_parallel_requests: Optional[int],
rpm: Optional[int],
tpm: Optional[int],
default_max_parallel_requests: Optional[int],
) -> Optional[int]:
"""
Returns the max parallel requests to send to a deployment.
Used in semaphore for async requests on router.
Parameters:
- max_parallel_requests - Optional[int] - max_parallel_requests allowed for that deployment
- rpm - Optional[int] - requests per minute allowed for that deployment
- tpm - Optional[int] - tokens per minute allowed for that deployment
- default_max_parallel_requests - Optional[int] - default_max_parallel_requests allowed for any deployment
Returns:
- int or None (if all params are None)
Order:
max_parallel_requests > rpm > tpm / 6 (azure formula) > default max_parallel_requests
Azure RPM formula:
6 rpm per 1000 TPM
https://learn.microsoft.com/en-us/azure/ai-services/openai/quotas-limits
"""
if max_parallel_requests is not None:
return max_parallel_requests
elif rpm is not None:
return rpm
elif tpm is not None:
calculated_rpm = int(tpm / 1000 / 6)
if calculated_rpm == 0:
calculated_rpm = 1
return calculated_rpm
elif default_max_parallel_requests is not None:
return default_max_parallel_requests
return None
def get_api_base(model: str, optional_params: dict) -> Optional[str]:
"""
Returns the api base used for calling the model.
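The precedence in `calculate_max_parallel_requests` works out as follows; the values below mirror the new unit tests. Note that, as implemented, the tpm branch yields roughly one parallel request per 6,000 TPM (`int(tpm / 1000 / 6)`), a more conservative cap than a literal reading of the "6 RPM per 1000 TPM" rule of thumb quoted in the docstring:

```python
from litellm.utils import calculate_max_parallel_requests

# Explicit max_parallel_requests always wins.
assert calculate_max_parallel_requests(10, rpm=30, tpm=300_000, default_max_parallel_requests=40) == 10
# Otherwise rpm is used directly as the cap.
assert calculate_max_parallel_requests(None, rpm=30, tpm=300_000, default_max_parallel_requests=40) == 30
# Otherwise tpm is converted: int(300_000 / 1000 / 6) == 50.
assert calculate_max_parallel_requests(None, rpm=None, tpm=300_000, default_max_parallel_requests=40) == 50
# Tiny tpm values clamp to 1: int(20 / 1000 / 6) == 0 -> 1.
assert calculate_max_parallel_requests(None, rpm=None, tpm=20, default_max_parallel_requests=40) == 1
# With nothing set on the deployment, the router-wide default applies.
assert calculate_max_parallel_requests(None, rpm=None, tpm=None, default_max_parallel_requests=40) == 40
# All None -> no semaphore is created.
assert calculate_max_parallel_requests(None, rpm=None, tpm=None, default_max_parallel_requests=None) is None
```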


@@ -96,9 +96,9 @@ litellm_settings:
router_settings:
routing_strategy: usage-based-routing-v2
redis_host: os.environ/REDIS_HOST
redis_password: os.environ/REDIS_PASSWORD
redis_port: os.environ/REDIS_PORT
# redis_host: os.environ/REDIS_HOST
# redis_password: os.environ/REDIS_PASSWORD
# redis_port: os.environ/REDIS_PORT
enable_pre_call_checks: true
general_settings:


@@ -260,7 +260,10 @@ async def test_chat_completion_ratelimit():
await asyncio.gather(*tasks)
pytest.fail("Expected at least 1 call to fail")
except Exception as e:
pass
if "Request did not return a 200 status code: 429" in str(e):
pass
else:
pytest.fail(f"Wrong error received - {str(e)}")
@pytest.mark.asyncio