test(test_parallel_request_limiter.py): fix tests to pass internal_usage_cache to MaxParallelRequestsHandler and read usage counters from it

This commit is contained in:
Krrish Dholakia 2024-06-13 17:13:44 -07:00
parent 76c9b715f2
commit 8d56f72d5a

View file

@@ -40,7 +40,9 @@ async def test_global_max_parallel_requests():
_api_key = hash_token("sk-12345") _api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=100) user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=100)
local_cache = DualCache() local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler() parallel_request_handler = MaxParallelRequestsHandler(
internal_usage_cache=local_cache
)
for _ in range(3): for _ in range(3):
try: try:
@@ -68,7 +70,9 @@ async def test_pre_call_hook():
_api_key = hash_token("sk-12345") _api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1) user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache() local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler() parallel_request_handler = MaxParallelRequestsHandler(
internal_usage_cache=local_cache
)
await parallel_request_handler.async_pre_call_hook( await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type="" user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -81,10 +85,12 @@ async def test_pre_call_hook():
request_count_api_key = f"{_api_key}::{precise_minute}::request_count" request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
print( print(
parallel_request_handler.user_api_key_cache.get_cache(key=request_count_api_key) parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)
) )
assert ( assert (
parallel_request_handler.user_api_key_cache.get_cache( parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key key=request_count_api_key
)["current_requests"] )["current_requests"]
== 1 == 1
@@ -101,7 +107,9 @@ async def test_pre_call_hook_rpm_limits():
api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1 api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1
) )
local_cache = DualCache() local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler() parallel_request_handler = MaxParallelRequestsHandler(
internal_usage_cache=local_cache
)
await parallel_request_handler.async_pre_call_hook( await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type="" user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -148,7 +156,9 @@ async def test_pre_call_hook_team_rpm_limits():
team_id=_team_id, team_id=_team_id,
) )
local_cache = DualCache() local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler() parallel_request_handler = MaxParallelRequestsHandler(
internal_usage_cache=local_cache
)
await parallel_request_handler.async_pre_call_hook( await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type="" user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -194,7 +204,9 @@ async def test_pre_call_hook_tpm_limits():
api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=10 api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=10
) )
local_cache = DualCache() local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler() parallel_request_handler = MaxParallelRequestsHandler(
internal_usage_cache=local_cache
)
await parallel_request_handler.async_pre_call_hook( await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type="" user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -244,7 +256,9 @@ async def test_pre_call_hook_user_tpm_limits():
res = dict(user_api_key_dict) res = dict(user_api_key_dict)
print("dict user", res) print("dict user", res)
parallel_request_handler = MaxParallelRequestsHandler() parallel_request_handler = MaxParallelRequestsHandler(
internal_usage_cache=local_cache
)
await parallel_request_handler.async_pre_call_hook( await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type="" user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -287,7 +301,9 @@ async def test_success_call_hook():
_api_key = hash_token("sk-12345") _api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1) user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache() local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler() parallel_request_handler = MaxParallelRequestsHandler(
internal_usage_cache=local_cache
)
await parallel_request_handler.async_pre_call_hook( await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type="" user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -300,7 +316,7 @@ async def test_success_call_hook():
request_count_api_key = f"{_api_key}::{precise_minute}::request_count" request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
assert ( assert (
parallel_request_handler.user_api_key_cache.get_cache( parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key key=request_count_api_key
)["current_requests"] )["current_requests"]
== 1 == 1
@@ -313,7 +329,7 @@ async def test_success_call_hook():
) )
assert ( assert (
parallel_request_handler.user_api_key_cache.get_cache( parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key key=request_count_api_key
)["current_requests"] )["current_requests"]
== 0 == 0
@@ -329,7 +345,9 @@ async def test_failure_call_hook():
_api_key = hash_token(_api_key) _api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1) user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache() local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler() parallel_request_handler = MaxParallelRequestsHandler(
internal_usage_cache=local_cache
)
await parallel_request_handler.async_pre_call_hook( await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type="" user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
@@ -342,7 +360,7 @@ async def test_failure_call_hook():
request_count_api_key = f"{_api_key}::{precise_minute}::request_count" request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
assert ( assert (
parallel_request_handler.user_api_key_cache.get_cache( parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key key=request_count_api_key
)["current_requests"] )["current_requests"]
== 1 == 1
@@ -358,7 +376,7 @@ async def test_failure_call_hook():
) )
assert ( assert (
parallel_request_handler.user_api_key_cache.get_cache( parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key key=request_count_api_key
)["current_requests"] )["current_requests"]
== 0 == 0
@@ -423,7 +441,7 @@ async def test_normal_router_call():
request_count_api_key = f"{_api_key}::{precise_minute}::request_count" request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
assert ( assert (
parallel_request_handler.user_api_key_cache.get_cache( parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key key=request_count_api_key
)["current_requests"] )["current_requests"]
== 1 == 1
@@ -440,7 +458,7 @@ async def test_normal_router_call():
print(f"response: {response}") print(f"response: {response}")
assert ( assert (
parallel_request_handler.user_api_key_cache.get_cache( parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key key=request_count_api_key
)["current_requests"] )["current_requests"]
== 0 == 0
@@ -504,7 +522,7 @@ async def test_normal_router_tpm_limit():
print("Test: Checking current_requests for precise_minute=", precise_minute) print("Test: Checking current_requests for precise_minute=", precise_minute)
assert ( assert (
parallel_request_handler.user_api_key_cache.get_cache( parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key key=request_count_api_key
)["current_requests"] )["current_requests"]
== 1 == 1
@@ -522,7 +540,7 @@ async def test_normal_router_tpm_limit():
try: try:
assert ( assert (
parallel_request_handler.user_api_key_cache.get_cache( parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key key=request_count_api_key
)["current_tpm"] )["current_tpm"]
> 0 > 0
@@ -583,7 +601,7 @@ async def test_streaming_router_call():
request_count_api_key = f"{_api_key}::{precise_minute}::request_count" request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
assert ( assert (
parallel_request_handler.user_api_key_cache.get_cache( parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key key=request_count_api_key
)["current_requests"] )["current_requests"]
== 1 == 1
@@ -601,7 +619,7 @@ async def test_streaming_router_call():
continue continue
await asyncio.sleep(1) # success is done in a separate thread await asyncio.sleep(1) # success is done in a separate thread
assert ( assert (
parallel_request_handler.user_api_key_cache.get_cache( parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key key=request_count_api_key
)["current_requests"] )["current_requests"]
== 0 == 0
@@ -661,7 +679,7 @@ async def test_streaming_router_tpm_limit():
request_count_api_key = f"{_api_key}::{precise_minute}::request_count" request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
assert ( assert (
parallel_request_handler.user_api_key_cache.get_cache( parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key key=request_count_api_key
)["current_requests"] )["current_requests"]
== 1 == 1
@@ -680,7 +698,7 @@ async def test_streaming_router_tpm_limit():
await asyncio.sleep(5) # success is done in a separate thread await asyncio.sleep(5) # success is done in a separate thread
assert ( assert (
parallel_request_handler.user_api_key_cache.get_cache( parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key key=request_count_api_key
)["current_tpm"] )["current_tpm"]
> 0 > 0
@@ -738,7 +756,7 @@ async def test_bad_router_call():
request_count_api_key = f"{_api_key}::{precise_minute}::request_count" request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
assert ( assert (
parallel_request_handler.user_api_key_cache.get_cache( # type: ignore parallel_request_handler.internal_usage_cache.get_cache( # type: ignore
key=request_count_api_key key=request_count_api_key
)["current_requests"] )["current_requests"]
== 1 == 1
@@ -755,7 +773,7 @@ async def test_bad_router_call():
except: except:
pass pass
assert ( assert (
parallel_request_handler.user_api_key_cache.get_cache( # type: ignore parallel_request_handler.internal_usage_cache.get_cache( # type: ignore
key=request_count_api_key key=request_count_api_key
)["current_requests"] )["current_requests"]
== 0 == 0
@@ -814,7 +832,7 @@ async def test_bad_router_tpm_limit():
request_count_api_key = f"{_api_key}::{precise_minute}::request_count" request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
assert ( assert (
parallel_request_handler.user_api_key_cache.get_cache( parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key key=request_count_api_key
)["current_requests"] )["current_requests"]
== 1 == 1
@@ -833,7 +851,7 @@ async def test_bad_router_tpm_limit():
await asyncio.sleep(1) # success is done in a separate thread await asyncio.sleep(1) # success is done in a separate thread
assert ( assert (
parallel_request_handler.user_api_key_cache.get_cache( parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key key=request_count_api_key
)["current_tpm"] )["current_tpm"]
== 0 == 0