feat(dynamic_rate_limiter.py): working e2e

Krrish Dholakia 2024-06-22 14:41:22 -07:00
parent 532f24bfb7
commit a31a05d45d
7 changed files with 420 additions and 24 deletions

View file

@@ -737,6 +737,7 @@ from .utils import (
     client,
     exception_type,
     get_optional_params,
+    get_response_string,
     modify_integration,
     token_counter,
     create_pretrained_tokenizer,

View file

@@ -30,6 +30,7 @@ model_list:
       api_key: os.environ/AZURE_API_KEY
       api_version: 2024-02-15-preview
       model: azure/chatgpt-v-2
+      tpm: 100
     model_name: gpt-3.5-turbo
   - litellm_params:
       model: anthropic.claude-3-sonnet-20240229-v1:0
@@ -40,6 +41,7 @@ model_list:
       api_version: 2024-02-15-preview
       model: azure/chatgpt-v-2
       drop_params: True
+      tpm: 100
     model_name: gpt-3.5-turbo
   - model_name: tts
     litellm_params:
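
Note: the tpm value added to each deployment above is the same per-deployment tokens-per-minute setting that can be passed programmatically through litellm_params, as the tests in this commit do. A minimal sketch of that equivalent router config (the api_key and api_base values are placeholders, not real credentials):

from litellm import Router

# Hypothetical deployment config mirroring the YAML change above:
# "tpm" on litellm_params is the tokens-per-minute budget for this deployment.
llm_router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": "my-key",    # placeholder
                "api_base": "my-base",  # placeholder
                "tpm": 100,             # per-deployment tokens-per-minute limit
            },
        }
    ]
)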

View file

@@ -55,11 +55,21 @@ class DynamicRateLimiterCache:
        Raises:
        - Exception, if unable to connect to cache client (if redis caching enabled)
        """
-        dt = get_utc_datetime()
-        current_minute = dt.strftime("%H-%M")
-        key_name = "{}:{}".format(current_minute, model)
-        await self.cache.async_set_cache_sadd(key=key_name, value=value, ttl=self.ttl)
+        try:
+            dt = get_utc_datetime()
+            current_minute = dt.strftime("%H-%M")
+            key_name = "{}:{}".format(current_minute, model)
+            await self.cache.async_set_cache_sadd(
+                key=key_name, value=value, ttl=self.ttl
+            )
+        except Exception as e:
+            verbose_proxy_logger.error(
+                "litellm.proxy.hooks.dynamic_rate_limiter.py::async_set_cache_sadd(): Exception occured - {}\n{}".format(
+                    str(e), traceback.format_exc()
+                )
+            )
+            raise e


 class _PROXY_DynamicRateLimitHandler(CustomLogger):
@@ -79,7 +89,7 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
        Returns
        - Tuple[available_tpm, model_tpm, active_projects]
-        - available_tpm: int or null
+        - available_tpm: int or null - always 0 or positive.
        - remaining_model_tpm: int or null. If available tpm is int, then this will be too.
        - active_projects: int or null
        """
@@ -108,6 +118,8 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
         else:
             available_tpm = remaining_model_tpm

+        if available_tpm is not None and available_tpm < 0:
+            available_tpm = 0
         return available_tpm, remaining_model_tpm, active_projects

     async def async_pre_call_hook(
@@ -150,9 +162,44 @@
         elif available_tpm is not None:
             ## UPDATE CACHE WITH ACTIVE PROJECT
             asyncio.create_task(
-                self.internal_usage_cache.async_set_cache_sadd(
+                self.internal_usage_cache.async_set_cache_sadd(  # this is a set
                     model=data["model"],  # type: ignore
                     value=[user_api_key_dict.team_id or "default_team"],
                 )
             )

         return None

+    async def async_post_call_success_hook(
+        self, user_api_key_dict: UserAPIKeyAuth, response
+    ):
+        try:
+            if isinstance(response, ModelResponse):
+                model_info = self.llm_router.get_model_info(
+                    id=response._hidden_params["model_id"]
+                )
+                assert (
+                    model_info is not None
+                ), "Model info for model with id={} is None".format(
+                    response._hidden_params["model_id"]
+                )
+
+                available_tpm, remaining_model_tpm, active_projects = (
+                    await self.check_available_tpm(model=model_info["model_name"])
+                )
+
+                response._hidden_params["additional_headers"] = {
+                    "x-litellm-model_group": model_info["model_name"],
+                    "x-ratelimit-remaining-litellm-project-tokens": available_tpm,
+                    "x-ratelimit-remaining-model-tokens": remaining_model_tpm,
+                    "x-ratelimit-current-active-projects": active_projects,
+                }
+
+                return response
+            return await super().async_post_call_success_hook(
+                user_api_key_dict, response
+            )
+        except Exception as e:
+            verbose_proxy_logger.error(
+                "litellm.proxy.hooks.dynamic_rate_limiter.py::async_post_call_success_hook(): Exception occured - {}\n{}".format(
+                    str(e), traceback.format_exc()
+                )
+            )
+            return response
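
The additional_headers set on _hidden_params above are merged into the proxy's HTTP response headers by the get_custom_headers change further down. A rough sketch of how a client could read them, assuming a proxy running at http://localhost:4000 with key sk-1234 (both placeholders):

import httpx

# Hypothetical client-side check of the dynamic rate limit headers added above.
resp = httpx.post(
    "http://localhost:4000/chat/completions",      # assumed proxy URL and route
    headers={"Authorization": "Bearer sk-1234"},   # assumed proxy key
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hey!"}],
    },
)
# Remaining tokens for this project (team) and for the model overall,
# plus the number of projects currently sharing the model's tpm.
print(resp.headers.get("x-ratelimit-remaining-litellm-project-tokens"))
print(resp.headers.get("x-ratelimit-remaining-model-tokens"))
print(resp.headers.get("x-ratelimit-current-active-projects"))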

View file

@@ -433,6 +433,7 @@ def get_custom_headers(
     version: Optional[str] = None,
     model_region: Optional[str] = None,
     fastest_response_batch_completion: Optional[bool] = None,
+    **kwargs,
 ) -> dict:
     exclude_values = {"", None}
     headers = {
@@ -448,6 +449,7 @@
             if fastest_response_batch_completion is not None
             else None
         ),
+        **{k: str(v) for k, v in kwargs.items()},
     }
     try:
         return {
@@ -3063,6 +3065,14 @@ async def chat_completion(
                 headers=custom_headers,
             )

+        ### CALL HOOKS ### - modify outgoing data
+        response = await proxy_logging_obj.post_call_success_hook(
+            user_api_key_dict=user_api_key_dict, response=response
+        )
+
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
+
         fastapi_response.headers.update(
             get_custom_headers(
                 user_api_key_dict=user_api_key_dict,
@@ -3072,14 +3082,10 @@
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 fastest_response_batch_completion=fastest_response_batch_completion,
+                **additional_headers,
             )
         )

-        ### CALL HOOKS ### - modify outgoing data
-        response = await proxy_logging_obj.post_call_success_hook(
-            user_api_key_dict=user_api_key_dict, response=response
-        )
-
         return response
     except RejectedRequestError as e:
         _data = e.request_data
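
With **kwargs accepted by get_custom_headers, any extra key/value pairs (here, the rate-limit headers pulled from _hidden_params["additional_headers"]) are stringified and merged into the response headers, and empty values are dropped. A simplified standalone sketch of that pattern, not the proxy's actual function:

from typing import Optional


def build_headers(model_id: Optional[str] = None, **kwargs) -> dict:
    """Simplified version of the header-merging pattern: a few fixed fields
    plus arbitrary extras, with empty values filtered out at the end."""
    exclude_values = {"", None}
    headers = {
        "x-litellm-model-id": model_id,
        **{k: str(v) for k, v in kwargs.items()},  # extras are stringified
    }
    return {k: v for k, v in headers.items() if v not in exclude_values}


# Example: extras coming from response._hidden_params["additional_headers"]
print(build_headers(model_id="abc123", **{"x-ratelimit-remaining-model-tokens": 40}))
# -> {'x-litellm-model-id': 'abc123', 'x-ratelimit-remaining-model-tokens': '40'}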

View file

@@ -584,8 +584,15 @@ class ProxyLogging:
         for callback in litellm.callbacks:
             try:
-                if isinstance(callback, CustomLogger):
-                    await callback.async_post_call_failure_hook(
+                _callback: Optional[CustomLogger] = None
+                if isinstance(callback, str):
+                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                        callback
+                    )
+                else:
+                    _callback = callback  # type: ignore
+                if _callback is not None and isinstance(_callback, CustomLogger):
+                    await _callback.async_post_call_failure_hook(
                         user_api_key_dict=user_api_key_dict,
                         original_exception=original_exception,
                     )
@@ -606,8 +613,15 @@
         """
         for callback in litellm.callbacks:
             try:
-                if isinstance(callback, CustomLogger):
-                    await callback.async_post_call_success_hook(
+                _callback: Optional[CustomLogger] = None
+                if isinstance(callback, str):
+                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                        callback
+                    )
+                else:
+                    _callback = callback  # type: ignore
+                if _callback is not None and isinstance(_callback, CustomLogger):
+                    await _callback.async_post_call_success_hook(
                         user_api_key_dict=user_api_key_dict, response=response
                     )
             except Exception as e:
@@ -625,11 +639,22 @@
         Covers:
         1. /chat/completions
         """
+        response_str: Optional[str] = None
+        if isinstance(response, ModelResponse):
+            response_str = litellm.get_response_string(response_obj=response)
+
+        if response_str is not None:
             for callback in litellm.callbacks:
                 try:
-                    if isinstance(callback, CustomLogger):
-                        await callback.async_post_call_streaming_hook(
-                            user_api_key_dict=user_api_key_dict, response=response
+                    _callback: Optional[CustomLogger] = None
+                    if isinstance(callback, str):
+                        _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                            callback
+                        )
+                    else:
+                        _callback = callback  # type: ignore
+                    if _callback is not None and isinstance(_callback, CustomLogger):
+                        await _callback.async_post_call_streaming_hook(
+                            user_api_key_dict=user_api_key_dict, response=response_str
                         )
                 except Exception as e:
                     raise e
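
The repeated block above lets litellm.callbacks hold either CustomLogger instances or string names that must first be resolved to a logger class; the hook only runs when a CustomLogger results. A generic sketch of the same dispatch pattern, with a made-up resolver and registry purely for illustration:

import asyncio
from typing import Optional


class CustomLogger:
    async def async_post_call_success_hook(self, response):
        print("hook ran with:", response)


def resolve_callback(name: str) -> Optional[CustomLogger]:
    # Stand-in for the real get_custom_logger_compatible_class(): map known
    # names to logger instances, return None for anything unknown.
    registry = {"dynamic_rate_limiter": CustomLogger()}
    return registry.get(name)


async def run_success_hooks(callbacks: list, response) -> None:
    for callback in callbacks:
        _callback: Optional[CustomLogger] = None
        if isinstance(callback, str):
            _callback = resolve_callback(callback)  # resolve by name
        else:
            _callback = callback                    # already an instance
        if _callback is not None and isinstance(_callback, CustomLogger):
            await _callback.async_post_call_success_hook(response)


asyncio.run(run_success_hooks(["dynamic_rate_limiter", CustomLogger()], {"ok": True}))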

View file

@@ -22,6 +22,7 @@ import pytest
 import litellm
 from litellm import DualCache, Router
+from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.hooks.dynamic_rate_limiter import (
     _PROXY_DynamicRateLimitHandler as DynamicRateLimitHandler,
 )
@@ -74,6 +75,11 @@ def mock_response() -> litellm.ModelResponse:
     )


+@pytest.fixture
+def user_api_key_auth() -> UserAPIKeyAuth:
+    return UserAPIKeyAuth()
+
+
 @pytest.mark.parametrize("num_projects", [1, 2, 100])
 @pytest.mark.asyncio
 async def test_available_tpm(num_projects, dynamic_rate_limit_handler):
@@ -112,6 +118,62 @@ async def test_available_tpm(num_projects, dynamic_rate_limit_handler):
     assert availability == expected_availability


+@pytest.mark.asyncio
+async def test_rate_limit_raised(dynamic_rate_limit_handler, user_api_key_auth):
+    """
+    Unit test. Tests if rate limit error raised when quota exhausted.
+    """
+    from fastapi import HTTPException
+
+    model = "my-fake-model"
+
+    ## SET CACHE W/ ACTIVE PROJECTS
+    projects = [str(uuid.uuid4())]
+    await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
+        model=model, value=projects
+    )
+
+    model_tpm = 0
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": model,
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": "my-key",
+                    "api_base": "my-base",
+                    "tpm": model_tpm,
+                },
+            }
+        ]
+    )
+    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
+
+    ## CHECK AVAILABLE TPM PER PROJECT
+    availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
+        model=model
+    )
+
+    expected_availability = int(model_tpm / 1)
+
+    assert availability == expected_availability
+
+    ## CHECK if exception raised
+    try:
+        await dynamic_rate_limit_handler.async_pre_call_hook(
+            user_api_key_dict=user_api_key_auth,
+            cache=DualCache(),
+            data={"model": model},
+            call_type="completion",
+        )
+        pytest.fail("Expected this to raise HTTPexception")
+    except HTTPException as e:
+        assert e.status_code == 429  # check if rate limit error raised
+        pass
+
+
 @pytest.mark.asyncio
 async def test_base_case(dynamic_rate_limit_handler, mock_response):
     """
@@ -122,9 +184,12 @@ async def test_base_case(dynamic_rate_limit_handler, mock_response):
     = allow request to go through
     - update token usage
     - exhaust all tpm with just 1 project
+    - assert ratelimiterror raised at 100%+1 tpm
     """
     model = "my-fake-model"
+    ## model tpm - 50
     model_tpm = 50
+    ## tpm per request - 10
     setattr(
         mock_response,
         "usage",
@@ -148,7 +213,9 @@ async def test_base_case(dynamic_rate_limit_handler, mock_response):
     dynamic_rate_limit_handler.update_variables(llm_router=llm_router)

     prev_availability: Optional[int] = None
+    allowed_fails = 1
     for _ in range(5):
+        try:
             # check availability
             availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
                 model=model
@@ -172,3 +239,248 @@ async def test_base_case(dynamic_rate_limit_handler, mock_response):
             )

             await asyncio.sleep(3)
+        except Exception:
+            if allowed_fails > 0:
+                allowed_fails -= 1
+            else:
+                raise
+
+
+@pytest.mark.asyncio
+async def test_update_cache(
+    dynamic_rate_limit_handler, mock_response, user_api_key_auth
+):
+    """
+    Check if active project correctly updated
+    """
+    model = "my-fake-model"
+    model_tpm = 50
+
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": model,
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": "my-key",
+                    "api_base": "my-base",
+                    "tpm": model_tpm,
+                    "mock_response": mock_response,
+                },
+            }
+        ]
+    )
+    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
+
+    ## INITIAL ACTIVE PROJECTS - ASSERT NONE
+    _, _, active_projects = await dynamic_rate_limit_handler.check_available_tpm(
+        model=model
+    )
+
+    assert active_projects is None
+
+    ## MAKE CALL
+    await dynamic_rate_limit_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_auth,
+        cache=DualCache(),
+        data={"model": model},
+        call_type="completion",
+    )
+
+    await asyncio.sleep(2)
+
+    ## INITIAL ACTIVE PROJECTS - ASSERT 1
+    _, _, active_projects = await dynamic_rate_limit_handler.check_available_tpm(
+        model=model
+    )
+
+    assert active_projects == 1
+
+
+@pytest.mark.parametrize("num_projects", [2])
+@pytest.mark.asyncio
+async def test_multiple_projects(
+    dynamic_rate_limit_handler, mock_response, num_projects
+):
+    """
+    If 2 active project
+    it should split 50% each
+    - assert available tpm is 0 after 50%+1 tpm calls
+    """
+    model = "my-fake-model"
+    model_tpm = 50
+    total_tokens_per_call = 10
+    step_tokens_per_call_per_project = total_tokens_per_call / num_projects
+
+    available_tpm_per_project = int(model_tpm / num_projects)
+
+    ## SET CACHE W/ ACTIVE PROJECTS
+    projects = [str(uuid.uuid4()) for _ in range(num_projects)]
+    await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
+        model=model, value=projects
+    )
+
+    expected_runs = int(available_tpm_per_project / step_tokens_per_call_per_project)
+
+    setattr(
+        mock_response,
+        "usage",
+        litellm.Usage(
+            prompt_tokens=5, completion_tokens=5, total_tokens=total_tokens_per_call
+        ),
+    )
+
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": model,
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": "my-key",
+                    "api_base": "my-base",
+                    "tpm": model_tpm,
+                    "mock_response": mock_response,
+                },
+            }
+        ]
+    )
+    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
+
+    prev_availability: Optional[int] = None
+
+    print("expected_runs: {}".format(expected_runs))
+    for i in range(expected_runs + 1):
+        # check availability
+        availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
+            model=model
+        )
+
+        ## assert availability updated
+        if prev_availability is not None and availability is not None:
+            assert (
+                availability == prev_availability - step_tokens_per_call_per_project
+            ), "Current Availability: Got={}, Expected={}, Step={}, Tokens per step={}, Initial model tpm={}".format(
+                availability,
+                prev_availability - 10,
+                i,
+                step_tokens_per_call_per_project,
+                model_tpm,
+            )
+
+        print(
+            "prev_availability={}, availability={}".format(
+                prev_availability, availability
+            )
+        )
+
+        prev_availability = availability
+
+        # make call
+        await llm_router.acompletion(
+            model=model, messages=[{"role": "user", "content": "hey!"}]
+        )
+
+        await asyncio.sleep(3)
+
+    # check availability
+    availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
+        model=model
+    )
+    assert availability == 0
+
+
+@pytest.mark.parametrize("num_projects", [2])
+@pytest.mark.asyncio
+async def test_multiple_projects_e2e(
+    dynamic_rate_limit_handler, mock_response, num_projects
+):
+    """
+    2 parallel calls with different keys, same model
+    If 2 active project
+    it should split 50% each
+    - assert available tpm is 0 after 50%+1 tpm calls
+    """
+    model = "my-fake-model"
+    model_tpm = 50
+    total_tokens_per_call = 10
+    step_tokens_per_call_per_project = total_tokens_per_call / num_projects
+
+    available_tpm_per_project = int(model_tpm / num_projects)
+
+    ## SET CACHE W/ ACTIVE PROJECTS
+    projects = [str(uuid.uuid4()) for _ in range(num_projects)]
+    await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
+        model=model, value=projects
+    )
+
+    expected_runs = int(available_tpm_per_project / step_tokens_per_call_per_project)
+
+    setattr(
+        mock_response,
+        "usage",
+        litellm.Usage(
+            prompt_tokens=5, completion_tokens=5, total_tokens=total_tokens_per_call
+        ),
+    )
+
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": model,
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": "my-key",
+                    "api_base": "my-base",
+                    "tpm": model_tpm,
+                    "mock_response": mock_response,
+                },
+            }
+        ]
+    )
+    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
+
+    prev_availability: Optional[int] = None
+
+    print("expected_runs: {}".format(expected_runs))
+    for i in range(expected_runs + 1):
+        # check availability
+        availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
+            model=model
+        )
+
+        ## assert availability updated
+        if prev_availability is not None and availability is not None:
+            assert (
+                availability == prev_availability - step_tokens_per_call_per_project
+            ), "Current Availability: Got={}, Expected={}, Step={}, Tokens per step={}, Initial model tpm={}".format(
+                availability,
+                prev_availability - 10,
+                i,
+                step_tokens_per_call_per_project,
+                model_tpm,
+            )
+
+        print(
+            "prev_availability={}, availability={}".format(
+                prev_availability, availability
+            )
+        )
+
+        prev_availability = availability
+
+        # make call
+        await llm_router.acompletion(
+            model=model, messages=[{"role": "user", "content": "hey!"}]
+        )
+
+        await asyncio.sleep(3)
+
+    # check availability
+    availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
+        model=model
+    )
+    assert availability == 0
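
For reference, the expected values in these tests follow from splitting the deployment's tpm evenly across active projects: each project gets about model_tpm // num_projects tokens per minute, and each mocked 10-token call reduces a project's share by total_tokens_per_call / num_projects. A quick check of the arithmetic assumed above (the 50-tpm, 2-project case):

model_tpm = 50
num_projects = 2
total_tokens_per_call = 10

available_tpm_per_project = model_tpm // num_projects                     # 25 tokens/minute each
step_tokens_per_call_per_project = total_tokens_per_call / num_projects   # 5 tokens per call per project
expected_runs = int(available_tpm_per_project / step_tokens_per_call_per_project)

print(available_tpm_per_project, step_tokens_per_call_per_project, expected_runs)
# -> 25 5.0 5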

View file

@@ -1768,6 +1768,9 @@ def mock_response() -> litellm.ModelResponse:
 @pytest.mark.asyncio
 async def test_router_model_usage(mock_response):
+    """
+    Test if tracking used model tpm works as expected
+    """
     model = "my-fake-model"
     model_tpm = 100
     setattr(