add test for allowed routes

This commit is contained in:
Ishaan Jaff 2024-09-03 14:17:53 -07:00
parent 253ef5f995
commit aae2ba208d

View file

@ -42,7 +42,7 @@ class Request:
def test_check_valid_ip( def test_check_valid_ip(
allowed_ips: Optional[List[str]], client_ip: Optional[str], expected_result: bool allowed_ips: Optional[List[str]], client_ip: Optional[str], expected_result: bool
): ):
from litellm.proxy.auth.user_api_key_auth import _check_valid_ip from litellm.proxy.auth.auth_utils import _check_valid_ip
request = Request(client_ip) request = Request(client_ip)
@ -70,7 +70,7 @@ def test_check_valid_ip(
def test_check_valid_ip_sent_with_x_forwarded_for( def test_check_valid_ip_sent_with_x_forwarded_for(
allowed_ips: Optional[List[str]], client_ip: Optional[str], expected_result: bool allowed_ips: Optional[List[str]], client_ip: Optional[str], expected_result: bool
): ):
from litellm.proxy.auth.user_api_key_auth import _check_valid_ip from litellm.proxy.auth.auth_utils import _check_valid_ip
request = Request(client_ip, headers={"X-Forwarded-For": client_ip}) request = Request(client_ip, headers={"X-Forwarded-For": client_ip})
@ -246,3 +246,42 @@ async def test_user_api_key_auth_fails_with_prohibited_params(prohibited_param):
error_message = str(e.message) error_message = str(e.message)
print("error message=", error_message) print("error message=", error_message)
assert "is not allowed in request body" in error_message assert "is not allowed in request body" in error_message
@pytest.mark.asyncio()
@pytest.mark.parametrize(
"route, should_raise_error",
[
("/embeddings", False),
("/chat/completions", True),
("/completions", True),
("/models", True),
("/v1/embeddings", True),
],
)
async def test_auth_with_allowed_routes(route, should_raise_error):
# Setup
user_key = "sk-1234"
general_settings = {"allowed_routes": ["/embeddings"]}
from fastapi import Request
from litellm.proxy import proxy_server
setattr(proxy_server, "master_key", "sk-1234")
setattr(proxy_server, "general_settings", general_settings)
request = Request(scope={"type": "http"})
request._url = URL(url=route)
if should_raise_error:
try:
await user_api_key_auth(request=request, api_key="Bearer " + user_key)
pytest.fail("Expected this call to fail. User is over limit.")
except Exception as e:
print("error str=", str(e.message))
error_str = str(e.message)
assert "Route" in error_str and "not allowed" in error_str
pass
else:
await user_api_key_auth(request=request, api_key="Bearer " + user_key)