test: refactor testing to handle the hash token fix

Krrish Dholakia 2024-04-17 17:31:39 -07:00
parent bafb008b44
commit 473e667bdf
5 changed files with 19 additions and 6 deletions
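
The pattern applied throughout the diff below: each test now passes its raw virtual key through hash_token before constructing UserAPIKeyAuth, presumably so the key the test hands to the parallel-request limiter matches the hashed token the proxy hooks operate on after the hash-token fix (inferred from the commit title; the fix itself is not in this file). A minimal sketch of that pattern, using only the imports and calls that appear in this diff:

from litellm.proxy.utils import hash_token
from litellm.proxy._types import UserAPIKeyAuth

# Hash the raw key first; the proxy-side hooks are assumed to key their
# cache entries on the hashed token, so the tests must do the same.
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)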


@@ -15,7 +15,7 @@ sys.path.insert(
 import pytest
 import litellm
 from litellm import Router
-from litellm.proxy.utils import ProxyLogging
+from litellm.proxy.utils import ProxyLogging, hash_token
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
 from litellm.proxy.hooks.parallel_request_limiter import (
@@ -34,6 +34,7 @@ async def test_pre_call_hook():
     Test if cache updated on call being received
     """
     _api_key = "sk-12345"
+    _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
     local_cache = DualCache()
     parallel_request_handler = MaxParallelRequestsHandler()
@@ -248,6 +249,7 @@ async def test_success_call_hook():
     Test if on success, cache correctly decremented
     """
     _api_key = "sk-12345"
+    _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
     local_cache = DualCache()
     parallel_request_handler = MaxParallelRequestsHandler()
@@ -289,6 +291,7 @@ async def test_failure_call_hook():
     Test if on failure, cache correctly decremented
     """
     _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
     local_cache = DualCache()
     parallel_request_handler = MaxParallelRequestsHandler()
@@ -366,6 +369,7 @@ async def test_normal_router_call():
     ) # type: ignore
 
     _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
     local_cache = DualCache()
     pl = ProxyLogging(user_api_key_cache=local_cache)
@@ -443,6 +447,7 @@ async def test_normal_router_tpm_limit():
     ) # type: ignore
 
     _api_key = "sk-12345"
+    _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(
         api_key=_api_key, max_parallel_requests=10, tpm_limit=10
     )
@@ -524,6 +529,7 @@ async def test_streaming_router_call():
     ) # type: ignore
 
     _api_key = "sk-12345"
+    _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
     local_cache = DualCache()
     pl = ProxyLogging(user_api_key_cache=local_cache)
@@ -599,6 +605,7 @@ async def test_streaming_router_tpm_limit():
     ) # type: ignore
 
     _api_key = "sk-12345"
+    _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(
         api_key=_api_key, max_parallel_requests=10, tpm_limit=10
     )
@@ -677,6 +684,7 @@ async def test_bad_router_call():
     ) # type: ignore
 
     _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
     local_cache = DualCache()
     pl = ProxyLogging(user_api_key_cache=local_cache)
@@ -750,6 +758,7 @@ async def test_bad_router_tpm_limit():
     ) # type: ignore
 
     _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
     user_api_key_dict = UserAPIKeyAuth(
         api_key=_api_key, max_parallel_requests=10, tpm_limit=10
     )