From aae2ba208d5a4744c38eaa0ec57631157bf02ac6 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 3 Sep 2024 14:17:53 -0700 Subject: [PATCH] add test for allowed routes --- litellm/tests/test_user_api_key_auth.py | 43 +++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/litellm/tests/test_user_api_key_auth.py b/litellm/tests/test_user_api_key_auth.py index 5a292bb4a9..5f94c5d238 100644 --- a/litellm/tests/test_user_api_key_auth.py +++ b/litellm/tests/test_user_api_key_auth.py @@ -42,7 +42,7 @@ class Request: def test_check_valid_ip( allowed_ips: Optional[List[str]], client_ip: Optional[str], expected_result: bool ): - from litellm.proxy.auth.user_api_key_auth import _check_valid_ip + from litellm.proxy.auth.auth_utils import _check_valid_ip request = Request(client_ip) @@ -70,7 +70,7 @@ def test_check_valid_ip( def test_check_valid_ip_sent_with_x_forwarded_for( allowed_ips: Optional[List[str]], client_ip: Optional[str], expected_result: bool ): - from litellm.proxy.auth.user_api_key_auth import _check_valid_ip + from litellm.proxy.auth.auth_utils import _check_valid_ip request = Request(client_ip, headers={"X-Forwarded-For": client_ip}) @@ -246,3 +246,42 @@ async def test_user_api_key_auth_fails_with_prohibited_params(prohibited_param): error_message = str(e.message) print("error message=", error_message) assert "is not allowed in request body" in error_message + + +@pytest.mark.asyncio() +@pytest.mark.parametrize( + "route, should_raise_error", + [ + ("/embeddings", False), + ("/chat/completions", True), + ("/completions", True), + ("/models", True), + ("/v1/embeddings", True), + ], +) +async def test_auth_with_allowed_routes(route, should_raise_error): + # Setup + user_key = "sk-1234" + + general_settings = {"allowed_routes": ["/embeddings"]} + from fastapi import Request + + from litellm.proxy import proxy_server + + setattr(proxy_server, "master_key", "sk-1234") + setattr(proxy_server, "general_settings", general_settings) + + request = Request(scope={"type": "http"}) + request._url = URL(url=route) + + if should_raise_error: + try: + await user_api_key_auth(request=request, api_key="Bearer " + user_key) + pytest.fail("Expected this call to fail. User is over limit.") + except Exception as e: + print("error str=", str(e.message)) + error_str = str(e.message) + assert "Route" in error_str and "not allowed" in error_str + pass + else: + await user_api_key_auth(request=request, api_key="Bearer " + user_key)