Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
feat(dynamic_rate_limiter.py): working e2e
Parent: 532f24bfb7
Commit: a31a05d45d
7 changed files with 420 additions and 24 deletions
@@ -737,6 +737,7 @@ from .utils import (
    client,
    exception_type,
    get_optional_params,
    get_response_string,
    modify_integration,
    token_counter,
    create_pretrained_tokenizer,
@@ -30,6 +30,7 @@ model_list:
      api_key: os.environ/AZURE_API_KEY
      api_version: 2024-02-15-preview
      model: azure/chatgpt-v-2
      tpm: 100
    model_name: gpt-3.5-turbo
  - litellm_params:
      model: anthropic.claude-3-sonnet-20240229-v1:0
@@ -40,6 +41,7 @@ model_list:
      api_version: 2024-02-15-preview
      model: azure/chatgpt-v-2
      drop_params: True
      tpm: 100
    model_name: gpt-3.5-turbo
  - model_name: tts
    litellm_params:
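The `tpm` value set on a deployment in the proxy config above is the same per-minute token budget the dynamic rate limiter divides across projects. For reference, a hedged example of the equivalent programmatic setup, mirroring the Router configuration used by the tests added in this commit (the key and base values are placeholders, not real credentials):

```python
# Mirrors the Router setup used by the tests in this commit; values are placeholders.
from litellm import Router

llm_router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": "my-key",
                "api_base": "my-base",
                "tpm": 100,  # per-minute token budget the dynamic rate limiter splits across projects
            },
        }
    ]
)
```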
@@ -55,11 +55,21 @@ class DynamicRateLimiterCache:
        Raises:
        - Exception, if unable to connect to cache client (if redis caching enabled)
        """
        try:
            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")

            key_name = "{}:{}".format(current_minute, model)
-           await self.cache.async_set_cache_sadd(key=key_name, value=value, ttl=self.ttl)
+           await self.cache.async_set_cache_sadd(
+               key=key_name, value=value, ttl=self.ttl
+           )
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.dynamic_rate_limiter.py::async_set_cache_sadd(): Exception occured - {}\n{}".format(
                    str(e), traceback.format_exc()
                )
            )
            raise e


class _PROXY_DynamicRateLimitHandler(CustomLogger):
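The hunk above writes each requesting team/project into a set keyed by the current UTC minute and the model name, so counting set members gives the number of active projects for that minute. A minimal sketch of that bookkeeping, using an in-memory dict of sets as a stand-in for litellm's DualCache/Redis backend (the class and method names below are illustrative assumptions; in the real cache a TTL lets per-minute keys expire):

```python
from datetime import datetime, timezone
from typing import Dict, List, Set


class InMemoryActiveProjectTracker:
    """Illustrative stand-in for the set-based cache used by DynamicRateLimiterCache."""

    def __init__(self) -> None:
        self.store: Dict[str, Set[str]] = {}

    def _key(self, model: str) -> str:
        # Same key scheme as the hunk above: "<HH-MM>:<model>" for the current UTC minute.
        current_minute = datetime.now(timezone.utc).strftime("%H-%M")
        return "{}:{}".format(current_minute, model)

    def set_cache_sadd(self, model: str, value: List[str]) -> None:
        # Set-add semantics: each team/project id is counted at most once per minute.
        self.store.setdefault(self._key(model), set()).update(value)

    def active_projects(self, model: str) -> int:
        return len(self.store.get(self._key(model), set()))


tracker = InMemoryActiveProjectTracker()
tracker.set_cache_sadd(model="gpt-3.5-turbo", value=["team-a"])
tracker.set_cache_sadd(model="gpt-3.5-turbo", value=["team-a", "team-b"])
print(tracker.active_projects("gpt-3.5-turbo"))  # 2
```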
@@ -79,7 +89,7 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):

        Returns
        - Tuple[available_tpm, model_tpm, active_projects]
-         - available_tpm: int or null
+         - available_tpm: int or null - always 0 or positive.
          - remaining_model_tpm: int or null. If available tpm is int, then this will be too.
          - active_projects: int or null
        """
@@ -108,6 +118,8 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
            else:
                available_tpm = remaining_model_tpm

        if available_tpm is not None and available_tpm < 0:
            available_tpm = 0
        return available_tpm, remaining_model_tpm, active_projects

    async def async_pre_call_hook(
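Taken together with the tests further down, the clamp above implies a simple fair-share rule: the model's remaining TPM is divided evenly across active projects and never reported below zero. A hedged sketch of that arithmetic (the function name and arguments are illustrative, not the handler's actual signature):

```python
from typing import Optional, Tuple


def check_available_tpm_sketch(
    model_tpm: Optional[int],
    current_model_usage: int,
    active_projects: Optional[int],
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
    # No tpm configured for the model group -> nothing to ration.
    if model_tpm is None:
        return None, None, active_projects
    remaining_model_tpm = model_tpm - current_model_usage
    if active_projects:
        available_tpm = int(remaining_model_tpm / active_projects)
    else:
        available_tpm = remaining_model_tpm
    # Mirrors the clamp added in this commit: never report negative availability.
    if available_tpm < 0:
        available_tpm = 0
    return available_tpm, remaining_model_tpm, active_projects


# 100 TPM model, 30 tokens already used this minute, 2 active projects -> 35 each.
print(check_available_tpm_sketch(100, 30, 2))  # (35, 70, 2)
```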
@@ -150,9 +162,44 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
        elif available_tpm is not None:
            ## UPDATE CACHE WITH ACTIVE PROJECT
            asyncio.create_task(
-               self.internal_usage_cache.async_set_cache_sadd(
+               self.internal_usage_cache.async_set_cache_sadd(  # this is a set
                    model=data["model"],  # type: ignore
                    value=[user_api_key_dict.team_id or "default_team"],
                )
            )
        return None

    async def async_post_call_success_hook(
        self, user_api_key_dict: UserAPIKeyAuth, response
    ):
        try:
            if isinstance(response, ModelResponse):
                model_info = self.llm_router.get_model_info(
                    id=response._hidden_params["model_id"]
                )
                assert (
                    model_info is not None
                ), "Model info for model with id={} is None".format(
                    response._hidden_params["model_id"]
                )
                available_tpm, remaining_model_tpm, active_projects = (
                    await self.check_available_tpm(model=model_info["model_name"])
                )
                response._hidden_params["additional_headers"] = {
                    "x-litellm-model_group": model_info["model_name"],
                    "x-ratelimit-remaining-litellm-project-tokens": available_tpm,
                    "x-ratelimit-remaining-model-tokens": remaining_model_tpm,
                    "x-ratelimit-current-active-projects": active_projects,
                }

                return response
            return await super().async_post_call_success_hook(
                user_api_key_dict, response
            )
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.dynamic_rate_limiter.py::async_post_call_success_hook(): Exception occured - {}\n{}".format(
                    str(e), traceback.format_exc()
                )
            )
            return response
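The new async_post_call_success_hook surfaces the availability numbers as response headers via `_hidden_params["additional_headers"]`. A small sketch of how a client might read those headers off a proxy response (the header names come from the hunk above; the parsing helper itself is illustrative):

```python
from typing import Mapping, Optional


def parse_dynamic_rate_limit_headers(headers: Mapping[str, str]) -> dict:
    # Header names are taken from the hunk above; this parser is illustrative.
    def _maybe_int(name: str) -> Optional[int]:
        raw = headers.get(name)
        if raw is None:
            return None
        try:
            return int(raw)
        except ValueError:
            return None

    return {
        "model_group": headers.get("x-litellm-model_group"),
        "remaining_project_tokens": _maybe_int(
            "x-ratelimit-remaining-litellm-project-tokens"
        ),
        "remaining_model_tokens": _maybe_int("x-ratelimit-remaining-model-tokens"),
        "active_projects": _maybe_int("x-ratelimit-current-active-projects"),
    }


example = {
    "x-litellm-model_group": "gpt-3.5-turbo",
    "x-ratelimit-remaining-litellm-project-tokens": "35",
    "x-ratelimit-remaining-model-tokens": "70",
    "x-ratelimit-current-active-projects": "2",
}
print(parse_dynamic_rate_limit_headers(example))
```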
@@ -433,6 +433,7 @@ def get_custom_headers(
    version: Optional[str] = None,
    model_region: Optional[str] = None,
    fastest_response_batch_completion: Optional[bool] = None,
    **kwargs,
) -> dict:
    exclude_values = {"", None}
    headers = {
@@ -448,6 +449,7 @@ def get_custom_headers(
            if fastest_response_batch_completion is not None
            else None
        ),
        **{k: str(v) for k, v in kwargs.items()},
    }
    try:
        return {
@@ -3063,6 +3065,14 @@ async def chat_completion(
                headers=custom_headers,
            )

+       ### CALL HOOKS ### - modify outgoing data
+       response = await proxy_logging_obj.post_call_success_hook(
+           user_api_key_dict=user_api_key_dict, response=response
+       )
+
+       hidden_params = getattr(response, "_hidden_params", {}) or {}
+       additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
+
        fastapi_response.headers.update(
            get_custom_headers(
                user_api_key_dict=user_api_key_dict,
@@ -3072,14 +3082,10 @@ async def chat_completion(
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                fastest_response_batch_completion=fastest_response_batch_completion,
+               **additional_headers,
            )
        )

-       ### CALL HOOKS ### - modify outgoing data
-       response = await proxy_logging_obj.post_call_success_hook(
-           user_api_key_dict=user_api_key_dict, response=response
-       )

        return response
    except RejectedRequestError as e:
        _data = e.request_data
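The two get_custom_headers hunks and the chat_completion change above form one pipeline: the success hook runs before the response headers are built, its `additional_headers` dict is pulled off `_hidden_params`, splatted into get_custom_headers as `**kwargs`, stringified, and empty values are dropped. A simplified stand-in for that pattern (not the proxy's actual header builder; the fixed header keys shown are illustrative):

```python
from typing import Optional


def build_custom_headers_sketch(
    model_id: Optional[str] = None,
    api_base: Optional[str] = None,
    **kwargs,
) -> dict:
    # Extra key/value pairs (e.g. the rate-limit headers) arrive via **kwargs,
    # get stringified, and empty/None values are dropped.
    exclude_values = {"", None}
    headers = {
        "x-litellm-model-id": model_id,
        "x-litellm-api-base": api_base,
        **{k: str(v) for k, v in kwargs.items()},
    }
    return {k: v for k, v in headers.items() if v not in exclude_values}


additional_headers = {"x-ratelimit-remaining-model-tokens": 70}
print(build_custom_headers_sketch(model_id="abc123", **additional_headers))
# {'x-litellm-model-id': 'abc123', 'x-ratelimit-remaining-model-tokens': '70'}
```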
@@ -584,8 +584,15 @@ class ProxyLogging:

        for callback in litellm.callbacks:
            try:
-               if isinstance(callback, CustomLogger):
-                   await callback.async_post_call_failure_hook(
+               _callback: Optional[CustomLogger] = None
+               if isinstance(callback, str):
+                   _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                       callback
+                   )
+               else:
+                   _callback = callback  # type: ignore
+               if _callback is not None and isinstance(_callback, CustomLogger):
+                   await _callback.async_post_call_failure_hook(
                        user_api_key_dict=user_api_key_dict,
                        original_exception=original_exception,
                    )
@@ -606,8 +613,15 @@ class ProxyLogging:
        """
        for callback in litellm.callbacks:
            try:
-               if isinstance(callback, CustomLogger):
-                   await callback.async_post_call_success_hook(
+               _callback: Optional[CustomLogger] = None
+               if isinstance(callback, str):
+                   _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                       callback
+                   )
+               else:
+                   _callback = callback  # type: ignore
+               if _callback is not None and isinstance(_callback, CustomLogger):
+                   await _callback.async_post_call_success_hook(
                        user_api_key_dict=user_api_key_dict, response=response
                    )
            except Exception as e:
@@ -625,11 +639,22 @@ class ProxyLogging:
        Covers:
        1. /chat/completions
        """
+       response_str: Optional[str] = None
+       if isinstance(response, ModelResponse):
+           response_str = litellm.get_response_string(response_obj=response)
+       if response_str is not None:
            for callback in litellm.callbacks:
                try:
-                   if isinstance(callback, CustomLogger):
-                       await callback.async_post_call_streaming_hook(
-                           user_api_key_dict=user_api_key_dict, response=response
+                   _callback: Optional[CustomLogger] = None
+                   if isinstance(callback, str):
+                       _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                           callback
+                       )
+                   else:
+                       _callback = callback  # type: ignore
+                   if _callback is not None and isinstance(_callback, CustomLogger):
+                       await _callback.async_post_call_streaming_hook(
+                           user_api_key_dict=user_api_key_dict, response=response_str
                        )
                except Exception as e:
                    raise e
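All three ProxyLogging hunks apply the same change: callbacks registered as strings are first resolved to a CustomLogger-compatible instance (via litellm's get_custom_logger_compatible_class helper) and skipped if resolution fails. A toy sketch of that dispatch pattern, with a hypothetical registry and logger class standing in for the real helper:

```python
import asyncio
from typing import Dict, Optional, Union


class CustomLoggerSketch:
    """Hypothetical stand-in for litellm's CustomLogger-compatible classes."""

    async def async_post_call_success_hook(self, response: str) -> None:
        print("success hook saw:", response)


# Hypothetical registry; in litellm this lookup is what
# get_custom_logger_compatible_class() performs.
_REGISTRY: Dict[str, CustomLoggerSketch] = {"dynamic_rate_limiter": CustomLoggerSketch()}


def resolve_callback(
    callback: Union[str, CustomLoggerSketch]
) -> Optional[CustomLoggerSketch]:
    if isinstance(callback, str):
        return _REGISTRY.get(callback)  # unknown strings resolve to None and are skipped
    return callback


async def run_success_hooks(callbacks, response: str) -> None:
    for callback in callbacks:
        _callback = resolve_callback(callback)
        if _callback is not None and isinstance(_callback, CustomLoggerSketch):
            await _callback.async_post_call_success_hook(response=response)


asyncio.run(run_success_hooks(["dynamic_rate_limiter", "unknown"], "hello"))
```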
@@ -22,6 +22,7 @@ import pytest

import litellm
from litellm import DualCache, Router
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.dynamic_rate_limiter import (
    _PROXY_DynamicRateLimitHandler as DynamicRateLimitHandler,
)
@@ -74,6 +75,11 @@ def mock_response() -> litellm.ModelResponse:
    )


@pytest.fixture
def user_api_key_auth() -> UserAPIKeyAuth:
    return UserAPIKeyAuth()


@pytest.mark.parametrize("num_projects", [1, 2, 100])
@pytest.mark.asyncio
async def test_available_tpm(num_projects, dynamic_rate_limit_handler):
@@ -112,6 +118,62 @@ async def test_available_tpm(num_projects, dynamic_rate_limit_handler):
    assert availability == expected_availability


@pytest.mark.asyncio
async def test_rate_limit_raised(dynamic_rate_limit_handler, user_api_key_auth):
    """
    Unit test. Tests if rate limit error raised when quota exhausted.
    """
    from fastapi import HTTPException

    model = "my-fake-model"
    ## SET CACHE W/ ACTIVE PROJECTS
    projects = [str(uuid.uuid4())]

    await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
        model=model, value=projects
    )

    model_tpm = 0
    llm_router = Router(
        model_list=[
            {
                "model_name": model,
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "my-key",
                    "api_base": "my-base",
                    "tpm": model_tpm,
                },
            }
        ]
    )
    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)

    ## CHECK AVAILABLE TPM PER PROJECT

    availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
        model=model
    )

    expected_availability = int(model_tpm / 1)

    assert availability == expected_availability

    ## CHECK if exception raised

    try:
        await dynamic_rate_limit_handler.async_pre_call_hook(
            user_api_key_dict=user_api_key_auth,
            cache=DualCache(),
            data={"model": model},
            call_type="completion",
        )
        pytest.fail("Expected this to raise HTTPexception")
    except HTTPException as e:
        assert e.status_code == 429  # check if rate limit error raised
        pass


@pytest.mark.asyncio
async def test_base_case(dynamic_rate_limit_handler, mock_response):
    """
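test_rate_limit_raised above sets the model's TPM to 0 and expects the pre-call hook to reject the request with a 429. A hedged sketch of that rejection logic, illustrating only the behavior the test asserts (the function name and error message are illustrative, not the handler's actual code):

```python
from typing import Optional

from fastapi import HTTPException


async def reject_if_quota_exhausted(available_tpm: Optional[int], model: str) -> None:
    # Mirrors the behavior the test asserts: zero available TPM -> HTTP 429.
    if available_tpm is not None and available_tpm == 0:
        raise HTTPException(
            status_code=429,
            detail="Dynamic rate limit exceeded for model={}".format(model),
        )
    # Otherwise the request proceeds and the calling project is recorded as active.
```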
@@ -122,9 +184,12 @@ async def test_base_case(dynamic_rate_limit_handler, mock_response):
    = allow request to go through
    - update token usage
    - exhaust all tpm with just 1 project
    - assert ratelimiterror raised at 100%+1 tpm
    """
    model = "my-fake-model"
    ## model tpm - 50
    model_tpm = 50
    ## tpm per request - 10
    setattr(
        mock_response,
        "usage",
@@ -148,7 +213,9 @@ async def test_base_case(dynamic_rate_limit_handler, mock_response):
    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)

    prev_availability: Optional[int] = None
    allowed_fails = 1
    for _ in range(5):
        try:
            # check availability
            availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
                model=model
@@ -172,3 +239,248 @@ async def test_base_case(dynamic_rate_limit_handler, mock_response):
            )

            await asyncio.sleep(3)
        except Exception:
            if allowed_fails > 0:
                allowed_fails -= 1
            else:
                raise


@pytest.mark.asyncio
async def test_update_cache(
    dynamic_rate_limit_handler, mock_response, user_api_key_auth
):
    """
    Check if active project correctly updated
    """
    model = "my-fake-model"
    model_tpm = 50

    llm_router = Router(
        model_list=[
            {
                "model_name": model,
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "my-key",
                    "api_base": "my-base",
                    "tpm": model_tpm,
                    "mock_response": mock_response,
                },
            }
        ]
    )
    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)

    ## INITIAL ACTIVE PROJECTS - ASSERT NONE
    _, _, active_projects = await dynamic_rate_limit_handler.check_available_tpm(
        model=model
    )

    assert active_projects is None

    ## MAKE CALL
    await dynamic_rate_limit_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_auth,
        cache=DualCache(),
        data={"model": model},
        call_type="completion",
    )

    await asyncio.sleep(2)
    ## INITIAL ACTIVE PROJECTS - ASSERT 1
    _, _, active_projects = await dynamic_rate_limit_handler.check_available_tpm(
        model=model
    )

    assert active_projects == 1


@pytest.mark.parametrize("num_projects", [2])
@pytest.mark.asyncio
async def test_multiple_projects(
    dynamic_rate_limit_handler, mock_response, num_projects
):
    """
    If 2 active project

    it should split 50% each

    - assert available tpm is 0 after 50%+1 tpm calls
    """
    model = "my-fake-model"
    model_tpm = 50
    total_tokens_per_call = 10
    step_tokens_per_call_per_project = total_tokens_per_call / num_projects

    available_tpm_per_project = int(model_tpm / num_projects)

    ## SET CACHE W/ ACTIVE PROJECTS
    projects = [str(uuid.uuid4()) for _ in range(num_projects)]
    await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
        model=model, value=projects
    )

    expected_runs = int(available_tpm_per_project / step_tokens_per_call_per_project)

    setattr(
        mock_response,
        "usage",
        litellm.Usage(
            prompt_tokens=5, completion_tokens=5, total_tokens=total_tokens_per_call
        ),
    )

    llm_router = Router(
        model_list=[
            {
                "model_name": model,
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "my-key",
                    "api_base": "my-base",
                    "tpm": model_tpm,
                    "mock_response": mock_response,
                },
            }
        ]
    )
    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)

    prev_availability: Optional[int] = None

    print("expected_runs: {}".format(expected_runs))
    for i in range(expected_runs + 1):
        # check availability
        availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
            model=model
        )

        ## assert availability updated
        if prev_availability is not None and availability is not None:
            assert (
                availability == prev_availability - step_tokens_per_call_per_project
            ), "Current Availability: Got={}, Expected={}, Step={}, Tokens per step={}, Initial model tpm={}".format(
                availability,
                prev_availability - 10,
                i,
                step_tokens_per_call_per_project,
                model_tpm,
            )

        print(
            "prev_availability={}, availability={}".format(
                prev_availability, availability
            )
        )

        prev_availability = availability

        # make call
        await llm_router.acompletion(
            model=model, messages=[{"role": "user", "content": "hey!"}]
        )

        await asyncio.sleep(3)

    # check availability
    availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
        model=model
    )
    assert availability == 0


@pytest.mark.parametrize("num_projects", [2])
@pytest.mark.asyncio
async def test_multiple_projects_e2e(
    dynamic_rate_limit_handler, mock_response, num_projects
):
    """
    2 parallel calls with different keys, same model

    If 2 active project

    it should split 50% each

    - assert available tpm is 0 after 50%+1 tpm calls
    """
    model = "my-fake-model"
    model_tpm = 50
    total_tokens_per_call = 10
    step_tokens_per_call_per_project = total_tokens_per_call / num_projects

    available_tpm_per_project = int(model_tpm / num_projects)

    ## SET CACHE W/ ACTIVE PROJECTS
    projects = [str(uuid.uuid4()) for _ in range(num_projects)]
    await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
        model=model, value=projects
    )

    expected_runs = int(available_tpm_per_project / step_tokens_per_call_per_project)

    setattr(
        mock_response,
        "usage",
        litellm.Usage(
            prompt_tokens=5, completion_tokens=5, total_tokens=total_tokens_per_call
        ),
    )

    llm_router = Router(
        model_list=[
            {
                "model_name": model,
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "my-key",
                    "api_base": "my-base",
                    "tpm": model_tpm,
                    "mock_response": mock_response,
                },
            }
        ]
    )
    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)

    prev_availability: Optional[int] = None

    print("expected_runs: {}".format(expected_runs))
    for i in range(expected_runs + 1):
        # check availability
        availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
            model=model
        )

        ## assert availability updated
        if prev_availability is not None and availability is not None:
            assert (
                availability == prev_availability - step_tokens_per_call_per_project
            ), "Current Availability: Got={}, Expected={}, Step={}, Tokens per step={}, Initial model tpm={}".format(
                availability,
                prev_availability - 10,
                i,
                step_tokens_per_call_per_project,
                model_tpm,
            )

        print(
            "prev_availability={}, availability={}".format(
                prev_availability, availability
            )
        )

        prev_availability = availability

        # make call
        await llm_router.acompletion(
            model=model, messages=[{"role": "user", "content": "hey!"}]
        )

        await asyncio.sleep(3)

    # check availability
    availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
        model=model
    )
    assert availability == 0
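For reference, the expected_runs arithmetic used by both multi-project tests above works out as follows (values taken directly from the tests):

```python
# Values taken directly from the tests above.
model_tpm = 50
num_projects = 2
total_tokens_per_call = 10

available_tpm_per_project = int(model_tpm / num_projects)  # 25
step_tokens_per_call_per_project = total_tokens_per_call / num_projects  # 5.0
expected_runs = int(available_tpm_per_project / step_tokens_per_call_per_project)
print(expected_runs)  # 5 calls before a project's share is exhausted
```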
@@ -1768,6 +1768,9 @@ def mock_response() -> litellm.ModelResponse:

@pytest.mark.asyncio
async def test_router_model_usage(mock_response):
    """
    Test if tracking used model tpm works as expected
    """
    model = "my-fake-model"
    model_tpm = 100
    setattr(