feat(parallel_request_limiter.py): add support for tpm/rpm limits

Krrish Dholakia 2024-01-18 13:52:15 -08:00
parent 2e06e00413
commit aef59c554f
2 changed files with 166 additions and 50 deletions


@@ -19,6 +19,7 @@ from litellm.proxy.utils import ProxyLogging
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
 from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
+from datetime import datetime
 ## On Request received
 ## On Request success
@@ -39,15 +40,19 @@ async def test_pre_call_hook():
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
     )
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    current_hour = datetime.now().strftime("%H")
+    current_minute = datetime.now().strftime("%M")
+    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
     print(
-        parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+        parallel_request_handler.user_api_key_cache.get_cache(key=request_count_api_key)
     )
     assert (
         parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+            key=request_count_api_key
+        )["current_requests"]
         == 1
     )
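
The hunk above replaces the flat {api_key}_request_count cache key with one scoped to the current minute. A minimal sketch of that scheme, with a hypothetical helper name; the ::-delimited format is the one the test builds by hand:

from datetime import datetime

def make_request_count_key(api_key: str) -> str:
    # One counter bucket per API key per minute: when the minute rolls over,
    # the key changes, so per-minute rpm/tpm counters reset implicitly.
    precise_minute = datetime.now().strftime("%Y-%m-%d-%H-%M")
    return f"{api_key}::{precise_minute}::request_count"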
@@ -66,10 +71,16 @@ async def test_success_call_hook():
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
     )
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    current_hour = datetime.now().strftime("%H")
+    current_minute = datetime.now().strftime("%M")
+    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
     assert (
         parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+            key=request_count_api_key
+        )["current_requests"]
         == 1
     )
@@ -81,8 +92,8 @@ async def test_success_call_hook():
     assert (
         parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+            key=request_count_api_key
+        )["current_requests"]
         == 0
     )
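
The cached value changes shape as well: the old asserts compared a bare integer, while the new ones index a dict via ["current_requests"]. A hedged sketch of the bookkeeping this implies, assuming a cache with the get_cache/set_cache interface of the DualCache used in these tests; the current_tpm/current_rpm fields are assumptions based on the commit title, since the tests only read current_requests:

def increment_on_request(cache, key: str) -> None:
    # Pre-call: claim a parallel-request slot and count one request
    # toward this minute's rpm budget.
    current = cache.get_cache(key=key) or {
        "current_requests": 0,
        "current_tpm": 0,  # assumed field: tokens used this minute
        "current_rpm": 0,  # assumed field: requests made this minute
    }
    cache.set_cache(
        key=key,
        value={
            "current_requests": current["current_requests"] + 1,
            "current_tpm": current["current_tpm"],
            "current_rpm": current["current_rpm"] + 1,
        },
    )

def decrement_on_success(cache, key: str, total_tokens: int) -> None:
    # Post-call: release the slot and charge the response's token usage,
    # which is why the assert above sees current_requests drop back to 0.
    current = cache.get_cache(key=key)
    if current is None:
        return
    cache.set_cache(
        key=key,
        value={
            "current_requests": max(current["current_requests"] - 1, 0),
            "current_tpm": current["current_tpm"] + total_tokens,
            "current_rpm": current["current_rpm"],
        },
    )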
@@ -101,10 +112,16 @@ async def test_failure_call_hook():
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
     )
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    current_hour = datetime.now().strftime("%H")
+    current_minute = datetime.now().strftime("%M")
+    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
     assert (
         parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+            key=request_count_api_key
+        )["current_requests"]
         == 1
     )
@@ -119,8 +136,8 @@ async def test_failure_call_hook():
     assert (
         parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+            key=request_count_api_key
+        )["current_requests"]
         == 0
     )
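
Success and failure hooks both drop current_requests back to 0, so the counters stay accurate for enforcement. The enforcement itself is not visible in this test file; a sketch of what the pre-call check presumably looks like given such counters, where the parameter names and the 429 detail string are illustrative rather than taken from the source:

from fastapi import HTTPException

def check_limits(current: dict, max_parallel_requests: int, tpm_limit: int, rpm_limit: int) -> None:
    # Illustrative enforcement: reject with 429 once any budget for the
    # current minute's bucket is exhausted.
    if (
        current["current_requests"] >= max_parallel_requests
        or current["current_tpm"] >= tpm_limit
        or current["current_rpm"] >= rpm_limit
    ):
        raise HTTPException(status_code=429, detail="Max parallel request limit reached")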
@@ -175,10 +192,16 @@ async def test_normal_router_call():
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
     )
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    current_hour = datetime.now().strftime("%H")
+    current_minute = datetime.now().strftime("%M")
+    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
     assert (
         parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+            key=request_count_api_key
+        )["current_requests"]
         == 1
     )
@@ -190,12 +213,13 @@ async def test_normal_router_call():
     )
     await asyncio.sleep(1)  # success is done in a separate thread
     print(f"response: {response}")
-    value = parallel_request_handler.user_api_key_cache.get_cache(
-        key=f"{_api_key}_request_count"
-    )
-    print(f"cache value: {value}")
-    assert value == 0
+    assert (
+        parallel_request_handler.user_api_key_cache.get_cache(
+            key=request_count_api_key
+        )["current_requests"]
+        == 0
+    )
 @pytest.mark.asyncio
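
The router-call tests all follow the same rhythm: run the pre-call hook, make the call, await asyncio.sleep(1) so the success callback (which runs off the request path) can fire, then assert the counter returned to 0. If the fixed sleep ever proves flaky, a small polling helper along these lines, purely illustrative and not part of this commit, states the same intent without the race:

import asyncio

async def wait_for_current_requests(cache, key: str, expected: int, timeout: float = 5.0) -> None:
    # Poll until the callback has updated the cache, rather than hoping
    # it finishes within a fixed one-second sleep.
    deadline = asyncio.get_running_loop().time() + timeout
    while asyncio.get_running_loop().time() < deadline:
        value = cache.get_cache(key=key)
        if value is not None and value["current_requests"] == expected:
            return
        await asyncio.sleep(0.05)
    raise AssertionError(f"current_requests never reached {expected}")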
@@ -240,10 +264,16 @@ async def test_streaming_router_call():
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
     )
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    current_hour = datetime.now().strftime("%H")
+    current_minute = datetime.now().strftime("%M")
+    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
     assert (
         parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+            key=request_count_api_key
+        )["current_requests"]
         == 1
     )
@@ -257,12 +287,12 @@ async def test_streaming_router_call():
     async for chunk in response:
         continue
     await asyncio.sleep(1)  # success is done in a separate thread
-    value = parallel_request_handler.user_api_key_cache.get_cache(
-        key=f"{_api_key}_request_count"
-    )
-    print(f"cache value: {value}")
-    assert value == 0
+    assert (
+        parallel_request_handler.user_api_key_cache.get_cache(
+            key=request_count_api_key
+        )["current_requests"]
+        == 0
+    )
 @pytest.mark.asyncio
@@ -307,10 +337,16 @@ async def test_bad_router_call():
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
     )
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    current_hour = datetime.now().strftime("%H")
+    current_minute = datetime.now().strftime("%M")
+    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
     assert (
         parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+            key=request_count_api_key
+        )["current_requests"]
         == 1
     )
@@ -324,9 +360,9 @@ async def test_bad_router_call():
         )
     except:
         pass
-    value = parallel_request_handler.user_api_key_cache.get_cache(
-        key=f"{_api_key}_request_count"
-    )
-    print(f"cache value: {value}")
-    assert value == 0
+    assert (
+        parallel_request_handler.user_api_key_cache.get_cache(
+            key=request_count_api_key
+        )["current_requests"]
+        == 0
+    )
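
test_bad_router_call closes the loop: even when acompletion raises, current_requests must return to 0, so the failure hook has to release the slot just like the success hook does. An illustrative mirror of the success path, where the names are assumptions rather than the library's internals:

async def on_failure(cache, key: str) -> None:
    # A failed call still occupied a parallel-request slot; give it back
    # so the key is not permanently penalized for errors.
    current = cache.get_cache(key=key)
    if current is not None:
        cache.set_cache(
            key=key,
            value={**current, "current_requests": max(current["current_requests"] - 1, 0)},
        )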