Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
test(test_parallel_request_limiter.py): fix test

parent 76c9b715f2
commit 8d56f72d5a

1 changed file with 44 additions and 26 deletions
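The diff below updates the limiter tests to MaxParallelRequestsHandler's revised constructor: the cache is now passed in as internal_usage_cache, and counters are read from parallel_request_handler.internal_usage_cache instead of the old user_api_key_cache attribute. A minimal sketch of the new setup pattern, assuming these import paths (the class and argument names come from the diff itself):

    # Sketch of the setup pattern the updated tests use.
    # Import paths are assumptions; class/argument names come from the diff.
    from litellm.caching import DualCache
    from litellm.proxy.hooks.parallel_request_limiter import (
        MaxParallelRequestsHandler,
    )

    # The handler now receives its usage cache at construction time.
    local_cache = DualCache()
    parallel_request_handler = MaxParallelRequestsHandler(
        internal_usage_cache=local_cache
    )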
test_parallel_request_limiter.py

@@ -40,7 +40,9 @@ async def test_global_max_parallel_requests():
     _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=100)
     local_cache = DualCache()
-    parallel_request_handler = MaxParallelRequestsHandler()
+    parallel_request_handler = MaxParallelRequestsHandler(
+        internal_usage_cache=local_cache
+    )

     for _ in range(3):
         try:
@@ -68,7 +70,9 @@ async def test_pre_call_hook():
     _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
     local_cache = DualCache()
-    parallel_request_handler = MaxParallelRequestsHandler()
+    parallel_request_handler = MaxParallelRequestsHandler(
+        internal_usage_cache=local_cache
+    )

     await parallel_request_handler.async_pre_call_hook(
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -81,10 +85,12 @@ async def test_pre_call_hook():
     request_count_api_key = f"{_api_key}::{precise_minute}::request_count"

     print(
-        parallel_request_handler.user_api_key_cache.get_cache(key=request_count_api_key)
+        parallel_request_handler.internal_usage_cache.get_cache(
+            key=request_count_api_key
+        )
     )
     assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.internal_usage_cache.get_cache(
             key=request_count_api_key
         )["current_requests"]
         == 1
@@ -101,7 +107,9 @@ async def test_pre_call_hook_rpm_limits():
         api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1
     )
     local_cache = DualCache()
-    parallel_request_handler = MaxParallelRequestsHandler()
+    parallel_request_handler = MaxParallelRequestsHandler(
+        internal_usage_cache=local_cache
+    )

     await parallel_request_handler.async_pre_call_hook(
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -148,7 +156,9 @@ async def test_pre_call_hook_team_rpm_limits():
         team_id=_team_id,
     )
     local_cache = DualCache()
-    parallel_request_handler = MaxParallelRequestsHandler()
+    parallel_request_handler = MaxParallelRequestsHandler(
+        internal_usage_cache=local_cache
+    )

     await parallel_request_handler.async_pre_call_hook(
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -194,7 +204,9 @@ async def test_pre_call_hook_tpm_limits():
         api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=10
     )
     local_cache = DualCache()
-    parallel_request_handler = MaxParallelRequestsHandler()
+    parallel_request_handler = MaxParallelRequestsHandler(
+        internal_usage_cache=local_cache
+    )

     await parallel_request_handler.async_pre_call_hook(
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -244,7 +256,9 @@ async def test_pre_call_hook_user_tpm_limits():
     res = dict(user_api_key_dict)
     print("dict user", res)

-    parallel_request_handler = MaxParallelRequestsHandler()
+    parallel_request_handler = MaxParallelRequestsHandler(
+        internal_usage_cache=local_cache
+    )

     await parallel_request_handler.async_pre_call_hook(
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -287,7 +301,9 @@ async def test_success_call_hook():
     _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
     local_cache = DualCache()
-    parallel_request_handler = MaxParallelRequestsHandler()
+    parallel_request_handler = MaxParallelRequestsHandler(
+        internal_usage_cache=local_cache
+    )

     await parallel_request_handler.async_pre_call_hook(
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -300,7 +316,7 @@ async def test_success_call_hook():
     request_count_api_key = f"{_api_key}::{precise_minute}::request_count"

     assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.internal_usage_cache.get_cache(
             key=request_count_api_key
         )["current_requests"]
         == 1
@@ -313,7 +329,7 @@ async def test_success_call_hook():
     )

     assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.internal_usage_cache.get_cache(
             key=request_count_api_key
         )["current_requests"]
         == 0
@@ -329,7 +345,9 @@ async def test_failure_call_hook():
     _api_key = hash_token(_api_key)
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
     local_cache = DualCache()
-    parallel_request_handler = MaxParallelRequestsHandler()
+    parallel_request_handler = MaxParallelRequestsHandler(
+        internal_usage_cache=local_cache
+    )

     await parallel_request_handler.async_pre_call_hook(
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -342,7 +360,7 @@ async def test_failure_call_hook():
     request_count_api_key = f"{_api_key}::{precise_minute}::request_count"

     assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.internal_usage_cache.get_cache(
             key=request_count_api_key
         )["current_requests"]
         == 1
@@ -358,7 +376,7 @@ async def test_failure_call_hook():
     )

     assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.internal_usage_cache.get_cache(
             key=request_count_api_key
         )["current_requests"]
         == 0
@@ -423,7 +441,7 @@ async def test_normal_router_call():
     request_count_api_key = f"{_api_key}::{precise_minute}::request_count"

     assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.internal_usage_cache.get_cache(
             key=request_count_api_key
         )["current_requests"]
         == 1
@@ -440,7 +458,7 @@ async def test_normal_router_call():
     print(f"response: {response}")

     assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.internal_usage_cache.get_cache(
             key=request_count_api_key
         )["current_requests"]
         == 0
@@ -504,7 +522,7 @@ async def test_normal_router_tpm_limit():
     print("Test: Checking current_requests for precise_minute=", precise_minute)

     assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.internal_usage_cache.get_cache(
             key=request_count_api_key
         )["current_requests"]
         == 1
@@ -522,7 +540,7 @@ async def test_normal_router_tpm_limit():

     try:
         assert (
-            parallel_request_handler.user_api_key_cache.get_cache(
+            parallel_request_handler.internal_usage_cache.get_cache(
                 key=request_count_api_key
             )["current_tpm"]
             > 0
@@ -583,7 +601,7 @@ async def test_streaming_router_call():
     request_count_api_key = f"{_api_key}::{precise_minute}::request_count"

     assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.internal_usage_cache.get_cache(
             key=request_count_api_key
         )["current_requests"]
         == 1
@@ -601,7 +619,7 @@ async def test_streaming_router_call():
         continue
     await asyncio.sleep(1) # success is done in a separate thread
     assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.internal_usage_cache.get_cache(
             key=request_count_api_key
         )["current_requests"]
         == 0
@@ -661,7 +679,7 @@ async def test_streaming_router_tpm_limit():
     request_count_api_key = f"{_api_key}::{precise_minute}::request_count"

     assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.internal_usage_cache.get_cache(
             key=request_count_api_key
         )["current_requests"]
         == 1
@@ -680,7 +698,7 @@ async def test_streaming_router_tpm_limit():
     await asyncio.sleep(5) # success is done in a separate thread

     assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.internal_usage_cache.get_cache(
             key=request_count_api_key
         )["current_tpm"]
         > 0
@@ -738,7 +756,7 @@ async def test_bad_router_call():
     request_count_api_key = f"{_api_key}::{precise_minute}::request_count"

     assert (
-        parallel_request_handler.user_api_key_cache.get_cache( # type: ignore
+        parallel_request_handler.internal_usage_cache.get_cache( # type: ignore
             key=request_count_api_key
         )["current_requests"]
         == 1
@@ -755,7 +773,7 @@ async def test_bad_router_call():
     except:
         pass
     assert (
-        parallel_request_handler.user_api_key_cache.get_cache( # type: ignore
+        parallel_request_handler.internal_usage_cache.get_cache( # type: ignore
             key=request_count_api_key
         )["current_requests"]
         == 0
@@ -814,7 +832,7 @@ async def test_bad_router_tpm_limit():
     request_count_api_key = f"{_api_key}::{precise_minute}::request_count"

     assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.internal_usage_cache.get_cache(
             key=request_count_api_key
         )["current_requests"]
         == 1
@@ -833,7 +851,7 @@ async def test_bad_router_tpm_limit():
     await asyncio.sleep(1) # success is done in a separate thread

     assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.internal_usage_cache.get_cache(
             key=request_count_api_key
         )["current_tpm"]
         == 0
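Taken together, the hunks make two mechanical changes: the handler is constructed with internal_usage_cache=local_cache, and every assertion reads counters through that attribute. A sketch of the read pattern the tests assert on (the key format and counter names come from the diff; _api_key, precise_minute, and the handler instance come from the surrounding test setup):

    # Counters are keyed per API key and per minute.
    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
    usage = parallel_request_handler.internal_usage_cache.get_cache(
        key=request_count_api_key
    )
    # After one pre-call hook, the tests expect current_requests == 1;
    # current_tpm is checked similarly in the TPM-limit tests.
    assert usage["current_requests"] == 1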