mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
add test for allowed routes
This commit is contained in:
parent
253ef5f995
commit
aae2ba208d
1 changed files with 41 additions and 2 deletions
|
@ -42,7 +42,7 @@ class Request:
|
|||
def test_check_valid_ip(
|
||||
allowed_ips: Optional[List[str]], client_ip: Optional[str], expected_result: bool
|
||||
):
|
||||
from litellm.proxy.auth.user_api_key_auth import _check_valid_ip
|
||||
from litellm.proxy.auth.auth_utils import _check_valid_ip
|
||||
|
||||
request = Request(client_ip)
|
||||
|
||||
|
@ -70,7 +70,7 @@ def test_check_valid_ip(
|
|||
def test_check_valid_ip_sent_with_x_forwarded_for(
|
||||
allowed_ips: Optional[List[str]], client_ip: Optional[str], expected_result: bool
|
||||
):
|
||||
from litellm.proxy.auth.user_api_key_auth import _check_valid_ip
|
||||
from litellm.proxy.auth.auth_utils import _check_valid_ip
|
||||
|
||||
request = Request(client_ip, headers={"X-Forwarded-For": client_ip})
|
||||
|
||||
|
@ -246,3 +246,42 @@ async def test_user_api_key_auth_fails_with_prohibited_params(prohibited_param):
|
|||
error_message = str(e.message)
|
||||
print("error message=", error_message)
|
||||
assert "is not allowed in request body" in error_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.parametrize(
|
||||
"route, should_raise_error",
|
||||
[
|
||||
("/embeddings", False),
|
||||
("/chat/completions", True),
|
||||
("/completions", True),
|
||||
("/models", True),
|
||||
("/v1/embeddings", True),
|
||||
],
|
||||
)
|
||||
async def test_auth_with_allowed_routes(route, should_raise_error):
|
||||
# Setup
|
||||
user_key = "sk-1234"
|
||||
|
||||
general_settings = {"allowed_routes": ["/embeddings"]}
|
||||
from fastapi import Request
|
||||
|
||||
from litellm.proxy import proxy_server
|
||||
|
||||
setattr(proxy_server, "master_key", "sk-1234")
|
||||
setattr(proxy_server, "general_settings", general_settings)
|
||||
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = URL(url=route)
|
||||
|
||||
if should_raise_error:
|
||||
try:
|
||||
await user_api_key_auth(request=request, api_key="Bearer " + user_key)
|
||||
pytest.fail("Expected this call to fail. User is over limit.")
|
||||
except Exception as e:
|
||||
print("error str=", str(e.message))
|
||||
error_str = str(e.message)
|
||||
assert "Route" in error_str and "not allowed" in error_str
|
||||
pass
|
||||
else:
|
||||
await user_api_key_auth(request=request, api_key="Bearer " + user_key)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue