Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
feat(parallel_request_limiter.py): add support for tpm/rpm limits
This commit is contained in:
parent 2e06e00413
commit aef59c554f
2 changed files with 166 additions and 50 deletions
|
@@ -19,6 +19,7 @@ from litellm.proxy.utils import ProxyLogging
|
|||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
|
||||
from datetime import datetime
|
||||
|
||||
## On Request received
|
||||
## On Request success
|
||||
|
@@ -39,15 +40,19 @@ async def test_pre_call_hook():
|
|||
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
|
||||
)
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
||||
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
|
||||
|
||||
print(
|
||||
parallel_request_handler.user_api_key_cache.get_cache(
|
||||
key=f"{_api_key}_request_count"
|
||||
)
|
||||
parallel_request_handler.user_api_key_cache.get_cache(key=request_count_api_key)
|
||||
)
|
||||
assert (
|
||||
parallel_request_handler.user_api_key_cache.get_cache(
|
||||
key=f"{_api_key}_request_count"
|
||||
)
|
||||
key=request_count_api_key
|
||||
)["current_requests"]
|
||||
== 1
|
||||
)
|
||||
|
||||
|
@@ -66,10 +71,16 @@ async def test_success_call_hook():
|
|||
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
|
||||
)
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
||||
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
|
||||
|
||||
assert (
|
||||
parallel_request_handler.user_api_key_cache.get_cache(
|
||||
key=f"{_api_key}_request_count"
|
||||
)
|
||||
key=request_count_api_key
|
||||
)["current_requests"]
|
||||
== 1
|
||||
)
|
||||
|
||||
|
@@ -81,8 +92,8 @@ async def test_success_call_hook():
|
|||
|
||||
assert (
|
||||
parallel_request_handler.user_api_key_cache.get_cache(
|
||||
key=f"{_api_key}_request_count"
|
||||
)
|
||||
key=request_count_api_key
|
||||
)["current_requests"]
|
||||
== 0
|
||||
)
|
||||
|
||||
|
@@ -101,10 +112,16 @@ async def test_failure_call_hook():
|
|||
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
|
||||
)
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
||||
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
|
||||
|
||||
assert (
|
||||
parallel_request_handler.user_api_key_cache.get_cache(
|
||||
key=f"{_api_key}_request_count"
|
||||
)
|
||||
key=request_count_api_key
|
||||
)["current_requests"]
|
||||
== 1
|
||||
)
|
||||
|
||||
|
@@ -119,8 +136,8 @@ async def test_failure_call_hook():
|
|||
|
||||
assert (
|
||||
parallel_request_handler.user_api_key_cache.get_cache(
|
||||
key=f"{_api_key}_request_count"
|
||||
)
|
||||
key=request_count_api_key
|
||||
)["current_requests"]
|
||||
== 0
|
||||
)
|
||||
|
||||
|
@@ -175,10 +192,16 @@ async def test_normal_router_call():
|
|||
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
|
||||
)
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
||||
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
|
||||
|
||||
assert (
|
||||
parallel_request_handler.user_api_key_cache.get_cache(
|
||||
key=f"{_api_key}_request_count"
|
||||
)
|
||||
key=request_count_api_key
|
||||
)["current_requests"]
|
||||
== 1
|
||||
)
|
||||
|
||||
|
@@ -190,12 +213,13 @@ async def test_normal_router_call():
|
|||
)
|
||||
await asyncio.sleep(1) # success is done in a separate thread
|
||||
print(f"response: {response}")
|
||||
value = parallel_request_handler.user_api_key_cache.get_cache(
|
||||
key=f"{_api_key}_request_count"
|
||||
)
|
||||
print(f"cache value: {value}")
|
||||
|
||||
assert value == 0
|
||||
assert (
|
||||
parallel_request_handler.user_api_key_cache.get_cache(
|
||||
key=request_count_api_key
|
||||
)["current_requests"]
|
||||
== 0
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@@ -240,10 +264,16 @@ async def test_streaming_router_call():
|
|||
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
|
||||
)
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
||||
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
|
||||
|
||||
assert (
|
||||
parallel_request_handler.user_api_key_cache.get_cache(
|
||||
key=f"{_api_key}_request_count"
|
||||
)
|
||||
key=request_count_api_key
|
||||
)["current_requests"]
|
||||
== 1
|
||||
)
|
||||
|
||||
|
@@ -257,12 +287,12 @@ async def test_streaming_router_call():
|
|||
async for chunk in response:
|
||||
continue
|
||||
await asyncio.sleep(1) # success is done in a separate thread
|
||||
value = parallel_request_handler.user_api_key_cache.get_cache(
|
||||
key=f"{_api_key}_request_count"
|
||||
assert (
|
||||
parallel_request_handler.user_api_key_cache.get_cache(
|
||||
key=request_count_api_key
|
||||
)["current_requests"]
|
||||
== 0
|
||||
)
|
||||
print(f"cache value: {value}")
|
||||
|
||||
assert value == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@@ -307,10 +337,16 @@ async def test_bad_router_call():
|
|||
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
|
||||
)
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
||||
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
|
||||
|
||||
assert (
|
||||
parallel_request_handler.user_api_key_cache.get_cache(
|
||||
key=f"{_api_key}_request_count"
|
||||
)
|
||||
key=request_count_api_key
|
||||
)["current_requests"]
|
||||
== 1
|
||||
)
|
||||
|
||||
|
@@ -324,9 +360,9 @@ async def test_bad_router_call():
|
|||
)
|
||||
except:
|
||||
pass
|
||||
value = parallel_request_handler.user_api_key_cache.get_cache(
|
||||
key=f"{_api_key}_request_count"
|
||||
assert (
|
||||
parallel_request_handler.user_api_key_cache.get_cache(
|
||||
key=request_count_api_key
|
||||
)["current_requests"]
|
||||
== 0
|
||||
)
|
||||
print(f"cache value: {value}")
|
||||
|
||||
assert value == 0
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue