test: refactor testing to handle the hash token fix

This commit is contained in:
Krrish Dholakia 2024-04-17 17:31:39 -07:00
parent bafb008b44
commit 473e667bdf
5 changed files with 19 additions and 6 deletions

View file

@ -19,7 +19,7 @@ from litellm.proxy.enterprise.enterprise_hooks.banned_keywords import (
_ENTERPRISE_BannedKeywords,
)
from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging
from litellm.proxy.utils import ProxyLogging, hash_token
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
@ -36,6 +36,7 @@ async def test_banned_keywords_check():
banned_keywords_obj = _ENTERPRISE_BannedKeywords()
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
local_cache = DualCache()

View file

@ -20,7 +20,7 @@ from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
_ENTERPRISE_BlockedUserList,
)
from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging
from litellm.proxy.utils import ProxyLogging, hash_token
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token
@ -106,6 +106,7 @@ async def test_block_user_check(prisma_client):
)
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
local_cache = DualCache()

View file

@ -18,7 +18,7 @@ import pytest
import litellm
from litellm.proxy.enterprise.enterprise_hooks.llm_guard import _ENTERPRISE_LLMGuard
from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging
from litellm.proxy.utils import ProxyLogging, hash_token
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
@ -40,6 +40,7 @@ async def test_llm_guard_valid_response():
)
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
local_cache = DualCache()
@ -76,6 +77,7 @@ async def test_llm_guard_error_raising():
)
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
local_cache = DualCache()

View file

@ -16,7 +16,7 @@ sys.path.insert(
import pytest
import litellm
from litellm import Router
from litellm.proxy.utils import ProxyLogging
from litellm.proxy.utils import ProxyLogging, hash_token
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache, RedisCache
from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
@ -29,7 +29,7 @@ async def test_pre_call_hook_rpm_limits():
Test if error raised on hitting rpm limits
"""
litellm.set_verbose = True
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, tpm_limit=9, rpm_limit=1)
local_cache = DualCache()
# redis_usage_cache = RedisCache()

View file

@ -15,7 +15,7 @@ sys.path.insert(
import pytest
import litellm
from litellm import Router
from litellm.proxy.utils import ProxyLogging
from litellm.proxy.utils import ProxyLogging, hash_token
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import (
@ -34,6 +34,7 @@ async def test_pre_call_hook():
Test if cache updated on call being received
"""
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler()
@ -248,6 +249,7 @@ async def test_success_call_hook():
Test if on success, cache correctly decremented
"""
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler()
@ -289,6 +291,7 @@ async def test_failure_call_hook():
Test if on failure, cache correctly decremented
"""
_api_key = "sk-12345"
_api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler()
@ -366,6 +369,7 @@ async def test_normal_router_call():
) # type: ignore
_api_key = "sk-12345"
_api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache()
pl = ProxyLogging(user_api_key_cache=local_cache)
@ -443,6 +447,7 @@ async def test_normal_router_tpm_limit():
) # type: ignore
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key, max_parallel_requests=10, tpm_limit=10
)
@ -524,6 +529,7 @@ async def test_streaming_router_call():
) # type: ignore
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache()
pl = ProxyLogging(user_api_key_cache=local_cache)
@ -599,6 +605,7 @@ async def test_streaming_router_tpm_limit():
) # type: ignore
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key, max_parallel_requests=10, tpm_limit=10
)
@ -677,6 +684,7 @@ async def test_bad_router_call():
) # type: ignore
_api_key = "sk-12345"
_api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache()
pl = ProxyLogging(user_api_key_cache=local_cache)
@ -750,6 +758,7 @@ async def test_bad_router_tpm_limit():
) # type: ignore
_api_key = "sk-12345"
_api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key, max_parallel_requests=10, tpm_limit=10
)