Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

feat(dynamic_rate_limiter.py): working e2e

This commit is contained in:
  parent 532f24bfb7
  commit a31a05d45d

7 changed files with 420 additions and 24 deletions

@@ -737,6 +737,7 @@ from .utils import (
     client,
     exception_type,
     get_optional_params,
+    get_response_string,
     modify_integration,
     token_counter,
     create_pretrained_tokenizer,

@@ -30,6 +30,7 @@ model_list:
     api_key: os.environ/AZURE_API_KEY
     api_version: 2024-02-15-preview
     model: azure/chatgpt-v-2
+    tpm: 100
   model_name: gpt-3.5-turbo
 - litellm_params:
     model: anthropic.claude-3-sonnet-20240229-v1:0

@@ -40,6 +41,7 @@ model_list:
     api_version: 2024-02-15-preview
     model: azure/chatgpt-v-2
     drop_params: True
+    tpm: 100
   model_name: gpt-3.5-turbo
 - model_name: tts
   litellm_params:

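The `tpm: 100` entries above set a deployment-level tokens-per-minute budget inside `litellm_params`. The same limit can be expressed in code when constructing a `Router`, mirroring the test setup later in this diff; the key and base values below are placeholders, not values from this commit:

from litellm import Router

# In-code equivalent of the yaml change above: a deployment-level tpm
# budget on litellm_params. "my-key" / "my-base" are placeholder values.
llm_router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": "my-key",
                "api_base": "my-base",
                "tpm": 100,
            },
        }
    ]
)
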
@@ -55,11 +55,21 @@ class DynamicRateLimiterCache:
         Raises:
         - Exception, if unable to connect to cache client (if redis caching enabled)
         """
-        dt = get_utc_datetime()
-        current_minute = dt.strftime("%H-%M")
-
-        key_name = "{}:{}".format(current_minute, model)
-        await self.cache.async_set_cache_sadd(key=key_name, value=value, ttl=self.ttl)
+        try:
+            dt = get_utc_datetime()
+            current_minute = dt.strftime("%H-%M")
+
+            key_name = "{}:{}".format(current_minute, model)
+            await self.cache.async_set_cache_sadd(
+                key=key_name, value=value, ttl=self.ttl
+            )
+        except Exception as e:
+            verbose_proxy_logger.error(
+                "litellm.proxy.hooks.dynamic_rate_limiter.py::async_set_cache_sadd(): Exception occured - {}\n{}".format(
+                    str(e), traceback.format_exc()
+                )
+            )
+            raise e
 
 
 class _PROXY_DynamicRateLimitHandler(CustomLogger):

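The try/except above wraps a cache write whose key is scoped to the current UTC minute plus the model name (and written with a TTL), so the set of active projects tracked for a model rolls over each minute. A small standalone illustration of that key layout, substituting datetime for litellm's get_utc_datetime helper:

from datetime import datetime, timezone

# Illustration of the "<HH-MM>:<model>" key format used above;
# litellm's get_utc_datetime() is stood in for by datetime.now(timezone.utc).
dt = datetime.now(timezone.utc)
current_minute = dt.strftime("%H-%M")
key_name = "{}:{}".format(current_minute, "gpt-3.5-turbo")
print(key_name)  # e.g. "14-07:gpt-3.5-turbo"
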
@@ -79,7 +89,7 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
 
         Returns
         - Tuple[available_tpm, model_tpm, active_projects]
-        - available_tpm: int or null
+        - available_tpm: int or null - always 0 or positive.
         - remaining_model_tpm: int or null. If available tpm is int, then this will be too.
         - active_projects: int or null
         """

@@ -108,6 +118,8 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
         else:
             available_tpm = remaining_model_tpm
 
+        if available_tpm is not None and available_tpm < 0:
+            available_tpm = 0
         return available_tpm, remaining_model_tpm, active_projects
 
     async def async_pre_call_hook(

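Taken together, the two hunks above mean `check_available_tpm` reports the model's remaining TPM, shares it across the projects currently active on that model, and (with this commit) never reports a negative per-project figure. A minimal sketch of that rule, assuming the even-split behaviour the tests further down assert; this is an illustration, not the handler's actual implementation:

from typing import Optional


def split_remaining_tpm(
    remaining_model_tpm: Optional[int], active_projects: Optional[int]
) -> Optional[int]:
    """Illustrative only: even split of remaining model TPM per project,
    clamped at 0 as the new `available_tpm < 0` check does."""
    if remaining_model_tpm is None:
        return None
    if active_projects:
        available_tpm = int(remaining_model_tpm / active_projects)
    else:
        available_tpm = remaining_model_tpm
    if available_tpm < 0:
        available_tpm = 0
    return available_tpm


assert split_remaining_tpm(100, 2) == 50
assert split_remaining_tpm(-10, 1) == 0  # clamp added in this commit
assert split_remaining_tpm(100, None) == 100
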
@@ -150,9 +162,44 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
         elif available_tpm is not None:
             ## UPDATE CACHE WITH ACTIVE PROJECT
             asyncio.create_task(
-                self.internal_usage_cache.async_set_cache_sadd(
+                self.internal_usage_cache.async_set_cache_sadd(  # this is a set
                     model=data["model"],  # type: ignore
                     value=[user_api_key_dict.team_id or "default_team"],
                 )
             )
         return None
+
+    async def async_post_call_success_hook(
+        self, user_api_key_dict: UserAPIKeyAuth, response
+    ):
+        try:
+            if isinstance(response, ModelResponse):
+                model_info = self.llm_router.get_model_info(
+                    id=response._hidden_params["model_id"]
+                )
+                assert (
+                    model_info is not None
+                ), "Model info for model with id={} is None".format(
+                    response._hidden_params["model_id"]
+                )
+                available_tpm, remaining_model_tpm, active_projects = (
+                    await self.check_available_tpm(model=model_info["model_name"])
+                )
+                response._hidden_params["additional_headers"] = {
+                    "x-litellm-model_group": model_info["model_name"],
+                    "x-ratelimit-remaining-litellm-project-tokens": available_tpm,
+                    "x-ratelimit-remaining-model-tokens": remaining_model_tpm,
+                    "x-ratelimit-current-active-projects": active_projects,
+                }
+
+                return response
+            return await super().async_post_call_success_hook(
+                user_api_key_dict, response
+            )
+        except Exception as e:
+            verbose_proxy_logger.error(
+                "litellm.proxy.hooks.dynamic_rate_limiter.py::async_post_call_success_hook(): Exception occured - {}\n{}".format(
+                    str(e), traceback.format_exc()
+                )
+            )
+            return response

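The new `async_post_call_success_hook` attaches rate-limit metadata to `response._hidden_params["additional_headers"]`; the payload has the following shape (the numeric values below are invented purely for illustration):

# Illustrative shape of the additional_headers dict built above;
# the values are made up for the example.
additional_headers = {
    "x-litellm-model_group": "gpt-3.5-turbo",
    "x-ratelimit-remaining-litellm-project-tokens": 40,
    "x-ratelimit-remaining-model-tokens": 90,
    "x-ratelimit-current-active-projects": 1,
}
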
@@ -433,6 +433,7 @@ def get_custom_headers(
     version: Optional[str] = None,
     model_region: Optional[str] = None,
     fastest_response_batch_completion: Optional[bool] = None,
+    **kwargs,
 ) -> dict:
     exclude_values = {"", None}
     headers = {

@@ -448,6 +449,7 @@ def get_custom_headers(
             if fastest_response_batch_completion is not None
             else None
         ),
+        **{k: str(v) for k, v in kwargs.items()},
     }
     try:
         return {

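With the new `**kwargs` parameter, any extra header forwarded into `get_custom_headers` is stringified before being merged into the header dict. A standalone illustration of that conversion:

# The **{k: str(v) ...} expansion above turns arbitrary values into
# header-safe strings, e.g.:
additional_headers = {
    "x-ratelimit-remaining-model-tokens": 40,
    "x-ratelimit-current-active-projects": 2,
}
assert {k: str(v) for k, v in additional_headers.items()} == {
    "x-ratelimit-remaining-model-tokens": "40",
    "x-ratelimit-current-active-projects": "2",
}
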
@@ -3063,6 +3065,14 @@ async def chat_completion(
             headers=custom_headers,
         )
 
+        ### CALL HOOKS ### - modify outgoing data
+        response = await proxy_logging_obj.post_call_success_hook(
+            user_api_key_dict=user_api_key_dict, response=response
+        )
+
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
+
         fastapi_response.headers.update(
             get_custom_headers(
                 user_api_key_dict=user_api_key_dict,

@@ -3072,14 +3082,10 @@ async def chat_completion(
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 fastest_response_batch_completion=fastest_response_batch_completion,
+                **additional_headers,
             )
         )
 
-        ### CALL HOOKS ### - modify outgoing data
-        response = await proxy_logging_obj.post_call_success_hook(
-            user_api_key_dict=user_api_key_dict, response=response
-        )
-
         return response
     except RejectedRequestError as e:
         _data = e.request_data

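Because `chat_completion` now runs the success hook before building headers and spreads `additional_headers` into `get_custom_headers`, the rate-limit metadata from the hook ends up on the HTTP response. A hypothetical client-side check of those headers; the proxy URL and API key below are placeholders, not values from this diff:

import asyncio

import httpx


async def show_rate_limit_headers() -> None:
    # Placeholder proxy URL and key; adjust to your deployment.
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            "http://localhost:4000/v1/chat/completions",
            headers={"Authorization": "Bearer sk-1234"},
            json={
                "model": "gpt-3.5-turbo",
                "messages": [{"role": "user", "content": "hey!"}],
            },
        )
        for header in (
            "x-ratelimit-remaining-litellm-project-tokens",
            "x-ratelimit-remaining-model-tokens",
            "x-ratelimit-current-active-projects",
        ):
            print(header, resp.headers.get(header))


asyncio.run(show_rate_limit_headers())
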
@@ -584,8 +584,15 @@ class ProxyLogging:
 
         for callback in litellm.callbacks:
             try:
-                if isinstance(callback, CustomLogger):
-                    await callback.async_post_call_failure_hook(
+                _callback: Optional[CustomLogger] = None
+                if isinstance(callback, str):
+                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                        callback
+                    )
+                else:
+                    _callback = callback  # type: ignore
+                if _callback is not None and isinstance(_callback, CustomLogger):
+                    await _callback.async_post_call_failure_hook(
                         user_api_key_dict=user_api_key_dict,
                         original_exception=original_exception,
                     )

@@ -606,8 +613,15 @@ class ProxyLogging:
         """
         for callback in litellm.callbacks:
             try:
-                if isinstance(callback, CustomLogger):
-                    await callback.async_post_call_success_hook(
+                _callback: Optional[CustomLogger] = None
+                if isinstance(callback, str):
+                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                        callback
+                    )
+                else:
+                    _callback = callback  # type: ignore
+                if _callback is not None and isinstance(_callback, CustomLogger):
+                    await _callback.async_post_call_success_hook(
                         user_api_key_dict=user_api_key_dict, response=response
                     )
             except Exception as e:

@@ -625,11 +639,22 @@ class ProxyLogging:
         Covers:
         1. /chat/completions
         """
+        response_str: Optional[str] = None
+        if isinstance(response, ModelResponse):
+            response_str = litellm.get_response_string(response_obj=response)
+        if response_str is not None:
             for callback in litellm.callbacks:
                 try:
-                    if isinstance(callback, CustomLogger):
-                        await callback.async_post_call_streaming_hook(
-                            user_api_key_dict=user_api_key_dict, response=response
+                    _callback: Optional[CustomLogger] = None
+                    if isinstance(callback, str):
+                        _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                            callback
+                        )
+                    else:
+                        _callback = callback  # type: ignore
+                    if _callback is not None and isinstance(_callback, CustomLogger):
+                        await _callback.async_post_call_streaming_hook(
+                            user_api_key_dict=user_api_key_dict, response=response_str
                         )
                 except Exception as e:
                     raise e

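All three ProxyLogging hooks now share the same resolution step: string entries in `litellm.callbacks` are mapped to their CustomLogger-compatible class, object entries are used as-is, and anything that does not resolve to a `CustomLogger` is skipped. A condensed sketch of that pattern, factored into a helper purely for illustration (import paths are assumed from the calls visible in the diff):

from typing import Optional, Union

from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils import litellm_logging


def resolve_custom_logger(
    callback: Union[str, CustomLogger],
) -> Optional[CustomLogger]:
    # Condensed version of the lookup repeated in the three hooks above:
    # string callback names resolve via get_custom_logger_compatible_class,
    # objects are used as-is, anything else is skipped.
    if isinstance(callback, str):
        _callback = litellm_logging.get_custom_logger_compatible_class(callback)
    else:
        _callback = callback
    if _callback is not None and isinstance(_callback, CustomLogger):
        return _callback
    return None
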
@@ -22,6 +22,7 @@ import pytest
 
 import litellm
 from litellm import DualCache, Router
+from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.hooks.dynamic_rate_limiter import (
     _PROXY_DynamicRateLimitHandler as DynamicRateLimitHandler,
 )

@@ -74,6 +75,11 @@ def mock_response() -> litellm.ModelResponse:
     )
 
 
+@pytest.fixture
+def user_api_key_auth() -> UserAPIKeyAuth:
+    return UserAPIKeyAuth()
+
+
 @pytest.mark.parametrize("num_projects", [1, 2, 100])
 @pytest.mark.asyncio
 async def test_available_tpm(num_projects, dynamic_rate_limit_handler):

@@ -112,6 +118,62 @@ async def test_available_tpm(num_projects, dynamic_rate_limit_handler):
     assert availability == expected_availability
 
 
+@pytest.mark.asyncio
+async def test_rate_limit_raised(dynamic_rate_limit_handler, user_api_key_auth):
+    """
+    Unit test. Tests if rate limit error raised when quota exhausted.
+    """
+    from fastapi import HTTPException
+
+    model = "my-fake-model"
+    ## SET CACHE W/ ACTIVE PROJECTS
+    projects = [str(uuid.uuid4())]
+
+    await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
+        model=model, value=projects
+    )
+
+    model_tpm = 0
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": model,
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": "my-key",
+                    "api_base": "my-base",
+                    "tpm": model_tpm,
+                },
+            }
+        ]
+    )
+    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
+
+    ## CHECK AVAILABLE TPM PER PROJECT
+
+    availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
+        model=model
+    )
+
+    expected_availability = int(model_tpm / 1)
+
+    assert availability == expected_availability
+
+    ## CHECK if exception raised
+
+    try:
+        await dynamic_rate_limit_handler.async_pre_call_hook(
+            user_api_key_dict=user_api_key_auth,
+            cache=DualCache(),
+            data={"model": model},
+            call_type="completion",
+        )
+        pytest.fail("Expected this to raise HTTPexception")
+    except HTTPException as e:
+        assert e.status_code == 429  # check if rate limit error raised
+        pass
+
+
 @pytest.mark.asyncio
 async def test_base_case(dynamic_rate_limit_handler, mock_response):
     """

@@ -122,9 +184,12 @@ async def test_base_case(dynamic_rate_limit_handler, mock_response):
     = allow request to go through
     - update token usage
     - exhaust all tpm with just 1 project
+    - assert ratelimiterror raised at 100%+1 tpm
     """
     model = "my-fake-model"
+    ## model tpm - 50
     model_tpm = 50
+    ## tpm per request - 10
     setattr(
         mock_response,
         "usage",

@@ -148,7 +213,9 @@ async def test_base_case(dynamic_rate_limit_handler, mock_response):
     dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
 
     prev_availability: Optional[int] = None
+    allowed_fails = 1
     for _ in range(5):
+        try:
             # check availability
             availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
                 model=model

@@ -172,3 +239,248 @@ async def test_base_case(dynamic_rate_limit_handler, mock_response):
             )
 
             await asyncio.sleep(3)
+        except Exception:
+            if allowed_fails > 0:
+                allowed_fails -= 1
+            else:
+                raise
+
+
+@pytest.mark.asyncio
+async def test_update_cache(
+    dynamic_rate_limit_handler, mock_response, user_api_key_auth
+):
+    """
+    Check if active project correctly updated
+    """
+    model = "my-fake-model"
+    model_tpm = 50
+
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": model,
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": "my-key",
+                    "api_base": "my-base",
+                    "tpm": model_tpm,
+                    "mock_response": mock_response,
+                },
+            }
+        ]
+    )
+    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
+
+    ## INITIAL ACTIVE PROJECTS - ASSERT NONE
+    _, _, active_projects = await dynamic_rate_limit_handler.check_available_tpm(
+        model=model
+    )
+
+    assert active_projects is None
+
+    ## MAKE CALL
+    await dynamic_rate_limit_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_auth,
+        cache=DualCache(),
+        data={"model": model},
+        call_type="completion",
+    )
+
+    await asyncio.sleep(2)
+    ## INITIAL ACTIVE PROJECTS - ASSERT 1
+    _, _, active_projects = await dynamic_rate_limit_handler.check_available_tpm(
+        model=model
+    )
+
+    assert active_projects == 1
+
+
+@pytest.mark.parametrize("num_projects", [2])
+@pytest.mark.asyncio
+async def test_multiple_projects(
+    dynamic_rate_limit_handler, mock_response, num_projects
+):
+    """
+    If 2 active project
+
+    it should split 50% each
+
+    - assert available tpm is 0 after 50%+1 tpm calls
+    """
+    model = "my-fake-model"
+    model_tpm = 50
+    total_tokens_per_call = 10
+    step_tokens_per_call_per_project = total_tokens_per_call / num_projects
+
+    available_tpm_per_project = int(model_tpm / num_projects)
+
+    ## SET CACHE W/ ACTIVE PROJECTS
+    projects = [str(uuid.uuid4()) for _ in range(num_projects)]
+    await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
+        model=model, value=projects
+    )
+
+    expected_runs = int(available_tpm_per_project / step_tokens_per_call_per_project)
+
+    setattr(
+        mock_response,
+        "usage",
+        litellm.Usage(
+            prompt_tokens=5, completion_tokens=5, total_tokens=total_tokens_per_call
+        ),
+    )
+
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": model,
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": "my-key",
+                    "api_base": "my-base",
+                    "tpm": model_tpm,
+                    "mock_response": mock_response,
+                },
+            }
+        ]
+    )
+    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
+
+    prev_availability: Optional[int] = None
+
+    print("expected_runs: {}".format(expected_runs))
+    for i in range(expected_runs + 1):
+        # check availability
+        availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
+            model=model
+        )
+
+        ## assert availability updated
+        if prev_availability is not None and availability is not None:
+            assert (
+                availability == prev_availability - step_tokens_per_call_per_project
+            ), "Current Availability: Got={}, Expected={}, Step={}, Tokens per step={}, Initial model tpm={}".format(
+                availability,
+                prev_availability - 10,
+                i,
+                step_tokens_per_call_per_project,
+                model_tpm,
+            )
+
+        print(
+            "prev_availability={}, availability={}".format(
+                prev_availability, availability
+            )
+        )
+
+        prev_availability = availability
+
+        # make call
+        await llm_router.acompletion(
+            model=model, messages=[{"role": "user", "content": "hey!"}]
+        )
+
+        await asyncio.sleep(3)
+
+    # check availability
+    availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
+        model=model
+    )
+    assert availability == 0
+
+
+@pytest.mark.parametrize("num_projects", [2])
+@pytest.mark.asyncio
+async def test_multiple_projects_e2e(
+    dynamic_rate_limit_handler, mock_response, num_projects
+):
+    """
+    2 parallel calls with different keys, same model
+
+    If 2 active project
+
+    it should split 50% each
+
+    - assert available tpm is 0 after 50%+1 tpm calls
+    """
+    model = "my-fake-model"
+    model_tpm = 50
+    total_tokens_per_call = 10
+    step_tokens_per_call_per_project = total_tokens_per_call / num_projects
+
+    available_tpm_per_project = int(model_tpm / num_projects)
+
+    ## SET CACHE W/ ACTIVE PROJECTS
+    projects = [str(uuid.uuid4()) for _ in range(num_projects)]
+    await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
+        model=model, value=projects
+    )
+
+    expected_runs = int(available_tpm_per_project / step_tokens_per_call_per_project)
+
+    setattr(
+        mock_response,
+        "usage",
+        litellm.Usage(
+            prompt_tokens=5, completion_tokens=5, total_tokens=total_tokens_per_call
+        ),
+    )
+
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": model,
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": "my-key",
+                    "api_base": "my-base",
+                    "tpm": model_tpm,
+                    "mock_response": mock_response,
+                },
+            }
+        ]
+    )
+    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
+
+    prev_availability: Optional[int] = None
+
+    print("expected_runs: {}".format(expected_runs))
+    for i in range(expected_runs + 1):
+        # check availability
+        availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
+            model=model
+        )
+
+        ## assert availability updated
+        if prev_availability is not None and availability is not None:
+            assert (
+                availability == prev_availability - step_tokens_per_call_per_project
+            ), "Current Availability: Got={}, Expected={}, Step={}, Tokens per step={}, Initial model tpm={}".format(
+                availability,
+                prev_availability - 10,
+                i,
+                step_tokens_per_call_per_project,
+                model_tpm,
+            )
+
+        print(
+            "prev_availability={}, availability={}".format(
+                prev_availability, availability
+            )
+        )
+
+        prev_availability = availability
+
+        # make call
+        await llm_router.acompletion(
+            model=model, messages=[{"role": "user", "content": "hey!"}]
+        )
+
+        await asyncio.sleep(3)
+
+    # check availability
+    availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
+        model=model
+    )
+    assert availability == 0

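The expected iteration count in the multi-project tests follows directly from the arithmetic in the test bodies; plugging in the values used there (model_tpm=50, two projects, 10 tokens per mocked call):

model_tpm = 50
num_projects = 2
total_tokens_per_call = 10

available_tpm_per_project = int(model_tpm / num_projects)  # 25 tokens per project
step_tokens_per_call_per_project = total_tokens_per_call / num_projects  # 5.0 per call
expected_runs = int(available_tpm_per_project / step_tokens_per_call_per_project)

assert expected_runs == 5  # so the loop runs expected_runs + 1 = 6 times
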
@@ -1768,6 +1768,9 @@ def mock_response() -> litellm.ModelResponse:
 
 @pytest.mark.asyncio
 async def test_router_model_usage(mock_response):
+    """
+    Test if tracking used model tpm works as expected
+    """
     model = "my-fake-model"
     model_tpm = 100
     setattr(