test(test_parallel_request_limiter.py): fix test

This commit is contained in:
Krrish Dholakia 2024-06-13 17:13:44 -07:00
parent 76c9b715f2
commit 8d56f72d5a

View file

@@ -40,7 +40,9 @@ async def test_global_max_parallel_requests():
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=100)
local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler()
parallel_request_handler = MaxParallelRequestsHandler(
internal_usage_cache=local_cache
)
for _ in range(3):
try:
@@ -68,7 +70,9 @@ async def test_pre_call_hook():
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler()
parallel_request_handler = MaxParallelRequestsHandler(
internal_usage_cache=local_cache
)
await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -81,10 +85,12 @@ async def test_pre_call_hook():
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
print(
parallel_request_handler.user_api_key_cache.get_cache(key=request_count_api_key)
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)
)
assert (
parallel_request_handler.user_api_key_cache.get_cache(
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_requests"]
== 1
@@ -101,7 +107,9 @@ async def test_pre_call_hook_rpm_limits():
api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1
)
local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler()
parallel_request_handler = MaxParallelRequestsHandler(
internal_usage_cache=local_cache
)
await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -148,7 +156,9 @@ async def test_pre_call_hook_team_rpm_limits():
team_id=_team_id,
)
local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler()
parallel_request_handler = MaxParallelRequestsHandler(
internal_usage_cache=local_cache
)
await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -194,7 +204,9 @@ async def test_pre_call_hook_tpm_limits():
api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=10
)
local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler()
parallel_request_handler = MaxParallelRequestsHandler(
internal_usage_cache=local_cache
)
await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -244,7 +256,9 @@ async def test_pre_call_hook_user_tpm_limits():
res = dict(user_api_key_dict)
print("dict user", res)
parallel_request_handler = MaxParallelRequestsHandler()
parallel_request_handler = MaxParallelRequestsHandler(
internal_usage_cache=local_cache
)
await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -287,7 +301,9 @@ async def test_success_call_hook():
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler()
parallel_request_handler = MaxParallelRequestsHandler(
internal_usage_cache=local_cache
)
await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -300,7 +316,7 @@ async def test_success_call_hook():
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
assert (
parallel_request_handler.user_api_key_cache.get_cache(
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_requests"]
== 1
@@ -313,7 +329,7 @@ async def test_success_call_hook():
)
assert (
parallel_request_handler.user_api_key_cache.get_cache(
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_requests"]
== 0
@@ -329,7 +345,9 @@ async def test_failure_call_hook():
_api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler()
parallel_request_handler = MaxParallelRequestsHandler(
internal_usage_cache=local_cache
)
await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -342,7 +360,7 @@ async def test_failure_call_hook():
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
assert (
parallel_request_handler.user_api_key_cache.get_cache(
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_requests"]
== 1
@@ -358,7 +376,7 @@ async def test_failure_call_hook():
)
assert (
parallel_request_handler.user_api_key_cache.get_cache(
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_requests"]
== 0
@@ -423,7 +441,7 @@ async def test_normal_router_call():
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
assert (
parallel_request_handler.user_api_key_cache.get_cache(
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_requests"]
== 1
@@ -440,7 +458,7 @@ async def test_normal_router_call():
print(f"response: {response}")
assert (
parallel_request_handler.user_api_key_cache.get_cache(
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_requests"]
== 0
@@ -504,7 +522,7 @@ async def test_normal_router_tpm_limit():
print("Test: Checking current_requests for precise_minute=", precise_minute)
assert (
parallel_request_handler.user_api_key_cache.get_cache(
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_requests"]
== 1
@@ -522,7 +540,7 @@ async def test_normal_router_tpm_limit():
try:
assert (
parallel_request_handler.user_api_key_cache.get_cache(
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_tpm"]
> 0
@@ -583,7 +601,7 @@ async def test_streaming_router_call():
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
assert (
parallel_request_handler.user_api_key_cache.get_cache(
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_requests"]
== 1
@@ -601,7 +619,7 @@ async def test_streaming_router_call():
continue
await asyncio.sleep(1) # success is done in a separate thread
assert (
parallel_request_handler.user_api_key_cache.get_cache(
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_requests"]
== 0
@@ -661,7 +679,7 @@ async def test_streaming_router_tpm_limit():
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
assert (
parallel_request_handler.user_api_key_cache.get_cache(
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_requests"]
== 1
@@ -680,7 +698,7 @@ async def test_streaming_router_tpm_limit():
await asyncio.sleep(5) # success is done in a separate thread
assert (
parallel_request_handler.user_api_key_cache.get_cache(
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_tpm"]
> 0
@@ -738,7 +756,7 @@ async def test_bad_router_call():
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
assert (
parallel_request_handler.user_api_key_cache.get_cache( # type: ignore
parallel_request_handler.internal_usage_cache.get_cache( # type: ignore
key=request_count_api_key
)["current_requests"]
== 1
@@ -755,7 +773,7 @@ async def test_bad_router_call():
except:
pass
assert (
parallel_request_handler.user_api_key_cache.get_cache( # type: ignore
parallel_request_handler.internal_usage_cache.get_cache( # type: ignore
key=request_count_api_key
)["current_requests"]
== 0
@@ -814,7 +832,7 @@ async def test_bad_router_tpm_limit():
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
assert (
parallel_request_handler.user_api_key_cache.get_cache(
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_requests"]
== 1
@@ -833,7 +851,7 @@ async def test_bad_router_tpm_limit():
await asyncio.sleep(1) # success is done in a separate thread
assert (
parallel_request_handler.user_api_key_cache.get_cache(
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_tpm"]
== 0