diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 86fc8ea8b..ff35ef2e3 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -233,9 +233,13 @@ def usage_telemetry( ).start() -def _get_bearer_token(api_key: str): - assert api_key.startswith("Bearer ") # ensure Bearer token passed in - api_key = api_key.replace("Bearer ", "") # extract the token +def _get_bearer_token( + api_key: str, +): + if api_key.startswith("Bearer "): # ensure Bearer token passed in + api_key = api_key.replace("Bearer ", "") # extract the token + else: + api_key = "" return api_key @@ -253,11 +257,14 @@ async def user_api_key_auth( global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client try: if isinstance(api_key, str): + passed_in_key = api_key api_key = _get_bearer_token(api_key=api_key) + ### USER-DEFINED AUTH FUNCTION ### if user_custom_auth is not None: response = await user_custom_auth(request=request, api_key=api_key) return UserAPIKeyAuth.model_validate(response) + ### LITELLM-DEFINED AUTH FUNCTION ### if master_key is None: if isinstance(api_key, str): @@ -265,6 +272,14 @@ async def user_api_key_auth( else: return UserAPIKeyAuth() + if api_key is None: + raise Exception("No API Key passed in. api_key is None") + if secrets.compare_digest(api_key, ""): + # missing 'Bearer ' prefix + raise Exception( + f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}" + ) + route: str = request.url.path if route == "/user/auth": if general_settings.get("allow_user_auth", False) == True: diff --git a/litellm/tests/test_configs/custom_auth.py b/litellm/tests/test_configs/custom_auth.py index e4747ee53..1b6bec43b 100644 --- a/litellm/tests/test_configs/custom_auth.py +++ b/litellm/tests/test_configs/custom_auth.py @@ -9,8 +9,14 @@ load_dotenv() async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth: try: print(f"api_key: {api_key}") + if api_key == "": + raise Exception( + f"CustomAuth - Malformed API Key passed in. Ensure Key has `Bearer` prefix" + ) if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234": return UserAPIKeyAuth(api_key=api_key) raise Exception - except: + except Exception as e: + if len(str(e)) > 0: + raise e raise Exception("Failed custom auth") diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 728514342..9d4318fe7 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -1154,6 +1154,7 @@ async def test_key_name_null(prisma_client): """ setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": False}) await litellm.proxy.proxy_server.prisma_client.connect() try: request = GenerateKeyRequest() @@ -1212,3 +1213,107 @@ async def test_default_key_params(prisma_client): except Exception as e: print("Got Exception", e) pytest.fail(f"Got exception {e}") + + +def test_get_bearer_token(): + from litellm.proxy.proxy_server import _get_bearer_token + + # Test valid Bearer token + api_key = "Bearer valid_token" + result = _get_bearer_token(api_key) + assert result == "valid_token", f"Expected 'valid_token', got '{result}'" + + # Test empty API key + api_key = "" + result = _get_bearer_token(api_key) + assert result == "", f"Expected '', got '{result}'" + + # Test API key without Bearer prefix + api_key = "invalid_token" + result = _get_bearer_token(api_key) + assert result == "", f"Expected '', got '{result}'" + + # Test API key with Bearer prefix in lowercase + api_key = "bearer valid_token" + result = _get_bearer_token(api_key) + assert result == "", f"Expected '', got '{result}'" + + # Test API key with Bearer prefix and extra spaces + api_key = " Bearer valid_token " + result = _get_bearer_token(api_key) + assert result == "", f"Expected '', got '{result}'" + + # Test API key with Bearer prefix and no token + api_key = "Bearer sk-1234" + result = _get_bearer_token(api_key) + assert result == "sk-1234", f"Expected 'valid_token', got '{result}'" + + +@pytest.mark.asyncio +async def test_user_api_key_auth(prisma_client): + from litellm.proxy.proxy_server import ProxyException + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True}) + await litellm.proxy.proxy_server.prisma_client.connect() + + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + # Test case: No API Key passed in + try: + await user_api_key_auth(request, api_key=None) + pytest.fail(f"This should have failed!. IT's an invalid key") + except ProxyException as exc: + print(exc.message) + assert ( + exc.message == "Authentication Error, No API Key passed in. api_key is None" + ) + + # Test case: Malformed API Key (missing 'Bearer ' prefix) + try: + await user_api_key_auth(request, api_key="my_token") + pytest.fail(f"This should have failed!. IT's an invalid key") + except ProxyException as exc: + print(exc.message) + assert ( + exc.message + == "Authentication Error, Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: my_token" + ) + + # Test case: User passes empty string API Key + try: + await user_api_key_auth(request, api_key="") + pytest.fail(f"This should have failed!. IT's an invalid key") + except ProxyException as exc: + print(exc.message) + assert ( + exc.message + == "Authentication Error, Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: " + ) + + +@pytest.mark.asyncio +async def test_user_api_key_auth_without_master_key(prisma_client): + # if master key is not set, expect all calls to go through + try: + from litellm.proxy.proxy_server import ProxyException + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", None) + setattr( + litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True} + ) + await litellm.proxy.proxy_server.prisma_client.connect() + + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + # Test case: No API Key passed in + + await user_api_key_auth(request, api_key=None) + await user_api_key_auth(request, api_key="my_token") + await user_api_key_auth(request, api_key="") + await user_api_key_auth(request, api_key="Bearer " + "1234") + except Exception as e: + print("Got Exception", e) + pytest.fail(f"Got exception {e}") diff --git a/litellm/tests/test_proxy_custom_auth.py b/litellm/tests/test_proxy_custom_auth.py index b6b833e17..55ab45624 100644 --- a/litellm/tests/test_proxy_custom_auth.py +++ b/litellm/tests/test_proxy_custom_auth.py @@ -65,3 +65,30 @@ def test_custom_auth(client): assert e.code == 401 assert e.message == "Authentication Error, Failed custom auth" pass + + +def test_custom_auth_bearer(client): + try: + # Your test data + test_data = { + "model": "openai-model", + "messages": [ + {"role": "user", "content": "hi"}, + ], + "max_tokens": 10, + } + # Your bearer token + token = os.getenv("PROXY_MASTER_KEY") + + headers = {"Authorization": f"WITHOUT BEAR Er {token}"} + response = client.post("/chat/completions", json=test_data, headers=headers) + pytest.fail("LiteLLM Proxy test failed. This request should have been rejected") + except Exception as e: + print(vars(e)) + print("got an exception") + assert e.code == 401 + assert ( + e.message + == "Authentication Error, CustomAuth - Malformed API Key passed in. Ensure Key has `Bearer` prefix" + ) + pass