return detailed error message on check_valid_ip

This commit is contained in:
Ishaan Jaff 2024-08-13 21:29:21 -07:00
parent 9924ceac1c
commit acb31c0acd
2 changed files with 9 additions and 9 deletions

View file

@ -12,7 +12,7 @@ import json
import secrets import secrets
import traceback import traceback
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Optional from typing import Optional, Tuple
from uuid import uuid4 from uuid import uuid4
import fastapi import fastapi
@ -123,7 +123,7 @@ async def user_api_key_auth(
# Check 2. FILTER IP ADDRESS # Check 2. FILTER IP ADDRESS
await check_if_request_size_is_safe(request=request) await check_if_request_size_is_safe(request=request)
is_valid_ip = _check_valid_ip( is_valid_ip, passed_in_ip = _check_valid_ip(
allowed_ips=general_settings.get("allowed_ips", None), allowed_ips=general_settings.get("allowed_ips", None),
use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False), use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
request=request, request=request,
@ -132,7 +132,7 @@ async def user_api_key_auth(
if not is_valid_ip: if not is_valid_ip:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Access forbidden: IP address not allowed.", detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
) )
pass_through_endpoints: Optional[List[dict]] = general_settings.get( pass_through_endpoints: Optional[List[dict]] = general_settings.get(
@ -1212,12 +1212,12 @@ def _check_valid_ip(
allowed_ips: Optional[List[str]], allowed_ips: Optional[List[str]],
request: Request, request: Request,
use_x_forwarded_for: Optional[bool] = False, use_x_forwarded_for: Optional[bool] = False,
) -> bool: ) -> Tuple[bool, Optional[str]]:
""" """
Returns if ip is allowed or not Returns if ip is allowed or not
""" """
if allowed_ips is None: # if not set, assume true if allowed_ips is None: # if not set, assume true
return True return True, None
# if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for # if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
client_ip = None client_ip = None
@ -1228,9 +1228,9 @@ def _check_valid_ip(
# Check if IP address is allowed # Check if IP address is allowed
if client_ip not in allowed_ips: if client_ip not in allowed_ips:
return False return False, client_ip
return True return True, client_ip
def get_api_key_from_custom_header( def get_api_key_from_custom_header(

View file

@ -44,7 +44,7 @@ def test_check_valid_ip(
request = Request(client_ip) request = Request(client_ip)
assert _check_valid_ip(allowed_ips, request) == expected_result # type: ignore assert _check_valid_ip(allowed_ips, request)[0] == expected_result # type: ignore
# test x-forwarder for is used when user has opted in # test x-forwarder for is used when user has opted in
@ -72,7 +72,7 @@ def test_check_valid_ip_sent_with_x_forwarded_for(
request = Request(client_ip, headers={"X-Forwarded-For": client_ip}) request = Request(client_ip, headers={"X-Forwarded-For": client_ip})
assert _check_valid_ip(allowed_ips, request, use_x_forwarded_for=True) == expected_result # type: ignore assert _check_valid_ip(allowed_ips, request, use_x_forwarded_for=True)[0] == expected_result # type: ignore
@pytest.mark.asyncio @pytest.mark.asyncio