diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index d0c8eac4e..3c45fe028 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -800,6 +800,10 @@ class UserAPIKeyAuth(
     def check_api_key(cls, values):
         if values.get("api_key") is not None:
             values.update({"token": hash_token(values.get("api_key"))})
+            if isinstance(values.get("api_key"), str) and values.get(
+                "api_key"
+            ).startswith("sk-"):
+                values.update({"api_key": hash_token(values.get("api_key"))})
         return values
 
diff --git a/litellm/tests/test_banned_keyword_list.py b/litellm/tests/test_banned_keyword_list.py
index f8804df9a..54d8852e8 100644
--- a/litellm/tests/test_banned_keyword_list.py
+++ b/litellm/tests/test_banned_keyword_list.py
@@ -19,7 +19,7 @@ from litellm.proxy.enterprise.enterprise_hooks.banned_keywords import (
     _ENTERPRISE_BannedKeywords,
 )
 from litellm import Router, mock_completion
-from litellm.proxy.utils import ProxyLogging
+from litellm.proxy.utils import ProxyLogging, hash_token
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
@@ -36,6 +36,7 @@ async def test_banned_keywords_check():
     banned_keywords_obj = _ENTERPRISE_BannedKeywords()
 
     _api_key = "sk-12345"
+    _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
     local_cache = DualCache()
 
diff --git a/litellm/tests/test_blocked_user_list.py b/litellm/tests/test_blocked_user_list.py
index d3f9f6a1a..3c277a2d4 100644
--- a/litellm/tests/test_blocked_user_list.py
+++ b/litellm/tests/test_blocked_user_list.py
@@ -20,7 +20,7 @@ from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
     _ENTERPRISE_BlockedUserList,
 )
 from litellm import Router, mock_completion
-from litellm.proxy.utils import ProxyLogging
+from litellm.proxy.utils import ProxyLogging, hash_token
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
 from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token
@@ -106,6 +106,7 @@ async def test_block_user_check(prisma_client):
     )
 
     _api_key = "sk-12345"
+    _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
     local_cache = DualCache()
 
diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py
index 2a7928743..367dd8072 100644
--- a/litellm/tests/test_key_generate_prisma.py
+++ b/litellm/tests/test_key_generate_prisma.py
@@ -1925,3 +1925,46 @@ async def test_proxy_load_test_db(prisma_client):
         raise Exception(f"it worked! key={key.key}")
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")
+
+
+@pytest.mark.asyncio()
+async def test_master_key_hashing(prisma_client):
+    try:
+
+        print("prisma client=", prisma_client)
+
+        master_key = "sk-1234"
+
+        setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+        setattr(litellm.proxy.proxy_server, "master_key", master_key)
+
+        await litellm.proxy.proxy_server.prisma_client.connect()
+        from litellm.proxy.proxy_server import user_api_key_cache
+
+        _response = await new_user(
+            data=NewUserRequest(
+                models=["azure-gpt-3.5"],
+                team_id="ishaans-special-team",
+                tpm_limit=20,
+            )
+        )
+        print(_response)
+        assert _response.models == ["azure-gpt-3.5"]
+        assert _response.team_id == "ishaans-special-team"
+        assert _response.tpm_limit == 20
+
+        bearer_token = "Bearer " + master_key
+
+        request = Request(scope={"type": "http"})
+        request._url = URL(url="/chat/completions")
+
+        # use generated key to auth in
+        result: UserAPIKeyAuth = await user_api_key_auth(
+            request=request, api_key=bearer_token
+        )
+
+        assert result.api_key == hash_token(master_key)
+
+    except Exception as e:
+        print("Got Exception", e)
+        pytest.fail(f"Got exception {e}")
diff --git a/litellm/tests/test_llm_guard.py b/litellm/tests/test_llm_guard.py
index 97e8fd0ac..4775e065d 100644
--- a/litellm/tests/test_llm_guard.py
+++ b/litellm/tests/test_llm_guard.py
@@ -18,7 +18,7 @@ import pytest
 import litellm
 from litellm.proxy.enterprise.enterprise_hooks.llm_guard import _ENTERPRISE_LLMGuard
 from litellm import Router, mock_completion
-from litellm.proxy.utils import ProxyLogging
+from litellm.proxy.utils import ProxyLogging, hash_token
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
@@ -40,6 +40,7 @@ async def test_llm_guard_valid_response():
     )
 
     _api_key = "sk-12345"
+    _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
     local_cache = DualCache()
@@ -76,6 +77,7 @@ async def test_llm_guard_error_raising():
     )
 
     _api_key = "sk-12345"
+    _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
     local_cache = DualCache()
diff --git a/litellm/tests/test_max_tpm_rpm_limiter.py b/litellm/tests/test_max_tpm_rpm_limiter.py
index a906e2f8a..fbaf30c59 100644
--- a/litellm/tests/test_max_tpm_rpm_limiter.py
+++ b/litellm/tests/test_max_tpm_rpm_limiter.py
@@ -16,7 +16,7 @@ sys.path.insert(
 import pytest
 import litellm
 from litellm import Router
-from litellm.proxy.utils import ProxyLogging
+from litellm.proxy.utils import ProxyLogging, hash_token
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache, RedisCache
 from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
@@ -29,7 +29,7 @@ async def test_pre_call_hook_rpm_limits():
     """
     Test if error raised on hitting rpm limits
     """
     litellm.set_verbose = True
-    _api_key = "sk-12345"
+    _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, tpm_limit=9, rpm_limit=1)
     local_cache = DualCache()
     # redis_usage_cache = RedisCache()
@@ -87,6 +87,7 @@ async def test_pre_call_hook_team_rpm_limits(
         "team_id": _team_id,
     }
     user_api_key_dict = UserAPIKeyAuth(**_user_api_key_dict)  # type: ignore
+    _api_key = hash_token(_api_key)
     local_cache = DualCache()
     local_cache.set_cache(key=_api_key, value=_user_api_key_dict)
     internal_cache = DualCache(redis_cache=_redis_usage_cache)
diff --git a/litellm/tests/test_parallel_request_limiter.py b/litellm/tests/test_parallel_request_limiter.py
index 627e395cf..d0a28926e 100644
--- a/litellm/tests/test_parallel_request_limiter.py
+++ b/litellm/tests/test_parallel_request_limiter.py
@@ -15,7 +15,7 @@ sys.path.insert(
 import pytest
 import litellm
 from litellm import Router
-from litellm.proxy.utils import ProxyLogging
+from litellm.proxy.utils import ProxyLogging, hash_token
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
 from litellm.proxy.hooks.parallel_request_limiter import (
@@ -34,6 +34,7 @@ async def test_pre_call_hook():
     """
     Test if cache updated on call being received
     """
     _api_key = "sk-12345"
+    _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
     local_cache = DualCache()
     parallel_request_handler = MaxParallelRequestsHandler()
@@ -248,6 +249,7 @@ async def test_success_call_hook():
     """
     Test if on success, cache correctly decremented
     """
     _api_key = "sk-12345"
+    _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
     local_cache = DualCache()
     parallel_request_handler = MaxParallelRequestsHandler()
@@ -289,6 +291,7 @@ async def test_failure_call_hook():
     """
     Test if on failure, cache correctly decremented
     """
     _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
     local_cache = DualCache()
     parallel_request_handler = MaxParallelRequestsHandler()
@@ -366,6 +369,7 @@ async def test_normal_router_call():
     )  # type: ignore
 
     _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
     local_cache = DualCache()
     pl = ProxyLogging(user_api_key_cache=local_cache)
@@ -443,6 +447,7 @@ async def test_normal_router_tpm_limit():
     )  # type: ignore
 
     _api_key = "sk-12345"
+    _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(
         api_key=_api_key, max_parallel_requests=10, tpm_limit=10
     )
@@ -524,6 +529,7 @@ async def test_streaming_router_call():
     )  # type: ignore
 
     _api_key = "sk-12345"
+    _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
     local_cache = DualCache()
     pl = ProxyLogging(user_api_key_cache=local_cache)
@@ -599,6 +605,7 @@ async def test_streaming_router_tpm_limit():
     )  # type: ignore
 
     _api_key = "sk-12345"
+    _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(
         api_key=_api_key, max_parallel_requests=10, tpm_limit=10
     )
@@ -677,6 +684,7 @@ async def test_bad_router_call():
     )  # type: ignore
 
     _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
     local_cache = DualCache()
     pl = ProxyLogging(user_api_key_cache=local_cache)
@@ -750,6 +758,7 @@ async def test_bad_router_tpm_limit():
     )  # type: ignore
 
     _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
     user_api_key_dict = UserAPIKeyAuth(
         api_key=_api_key, max_parallel_requests=10, tpm_limit=10
     )
diff --git a/litellm/utils.py b/litellm/utils.py
index 8e7c31867..d6548dc40 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -7820,6 +7820,15 @@ def exception_type(
                     llm_provider="vertex_ai",
                     response=original_exception.response,
                 )
+            elif "None Unknown Error." in error_str:
+                exception_mapping_worked = True
+                raise APIError(
+                    message=f"VertexAIException - {error_str}",
+                    status_code=500,
+                    model=model,
+                    llm_provider="vertex_ai",
+                    request=original_exception.request,
+                )
             elif "403" in error_str:
                 exception_mapping_worked = True
                 raise BadRequestError(
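
Note on the change above: the `UserAPIKeyAuth.check_api_key` validator now stores any raw "sk-" key as its hash, which is why the updated tests set `_api_key = hash_token(...)` before seeding caches or asserting on `result.api_key`. The following is a minimal, self-contained sketch of that behavior, not the litellm source; it assumes `litellm.proxy.utils.hash_token` is a SHA-256 hexdigest of the raw key.

# Minimal sketch (assumption: hash_token == SHA-256 hexdigest of the raw key).
import hashlib


def hash_token(token: str) -> str:
    # Hypothetical stand-in for litellm.proxy.utils.hash_token.
    return hashlib.sha256(token.encode()).hexdigest()


def check_api_key(values: dict) -> dict:
    # Mirrors the validator logic added in litellm/proxy/_types.py above:
    # the token field is always the hash, and raw "sk-" keys are replaced
    # with their hash so only hashed values are cached or compared.
    api_key = values.get("api_key")
    if api_key is not None:
        values.update({"token": hash_token(api_key)})
        if isinstance(api_key, str) and api_key.startswith("sk-"):
            values.update({"api_key": hash_token(api_key)})
    return values


if __name__ == "__main__":
    values = check_api_key({"api_key": "sk-12345"})
    # The stored api_key now equals hash_token("sk-12345"), matching the
    # assertions the updated tests make against UserAPIKeyAuth.api_key.
    assert values["api_key"] == hash_token("sk-12345")
    print(values["api_key"])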