mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
return detailed error message on check_valid_ip
This commit is contained in:
parent
9924ceac1c
commit
acb31c0acd
2 changed files with 9 additions and 9 deletions
|
@ -12,7 +12,7 @@ import json
|
||||||
import secrets
|
import secrets
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Optional
|
from typing import Optional, Tuple
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
|
@ -123,7 +123,7 @@ async def user_api_key_auth(
|
||||||
# Check 2. FILTER IP ADDRESS
|
# Check 2. FILTER IP ADDRESS
|
||||||
await check_if_request_size_is_safe(request=request)
|
await check_if_request_size_is_safe(request=request)
|
||||||
|
|
||||||
is_valid_ip = _check_valid_ip(
|
is_valid_ip, passed_in_ip = _check_valid_ip(
|
||||||
allowed_ips=general_settings.get("allowed_ips", None),
|
allowed_ips=general_settings.get("allowed_ips", None),
|
||||||
use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
|
use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
|
||||||
request=request,
|
request=request,
|
||||||
|
@ -132,7 +132,7 @@ async def user_api_key_auth(
|
||||||
if not is_valid_ip:
|
if not is_valid_ip:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="Access forbidden: IP address not allowed.",
|
detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
|
||||||
)
|
)
|
||||||
|
|
||||||
pass_through_endpoints: Optional[List[dict]] = general_settings.get(
|
pass_through_endpoints: Optional[List[dict]] = general_settings.get(
|
||||||
|
@ -1212,12 +1212,12 @@ def _check_valid_ip(
|
||||||
allowed_ips: Optional[List[str]],
|
allowed_ips: Optional[List[str]],
|
||||||
request: Request,
|
request: Request,
|
||||||
use_x_forwarded_for: Optional[bool] = False,
|
use_x_forwarded_for: Optional[bool] = False,
|
||||||
) -> bool:
|
) -> Tuple[bool, Optional[str]]:
|
||||||
"""
|
"""
|
||||||
Returns if ip is allowed or not
|
Returns if ip is allowed or not
|
||||||
"""
|
"""
|
||||||
if allowed_ips is None: # if not set, assume true
|
if allowed_ips is None: # if not set, assume true
|
||||||
return True
|
return True, None
|
||||||
|
|
||||||
# if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
|
# if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
|
||||||
client_ip = None
|
client_ip = None
|
||||||
|
@ -1228,9 +1228,9 @@ def _check_valid_ip(
|
||||||
|
|
||||||
# Check if IP address is allowed
|
# Check if IP address is allowed
|
||||||
if client_ip not in allowed_ips:
|
if client_ip not in allowed_ips:
|
||||||
return False
|
return False, client_ip
|
||||||
|
|
||||||
return True
|
return True, client_ip
|
||||||
|
|
||||||
|
|
||||||
def get_api_key_from_custom_header(
|
def get_api_key_from_custom_header(
|
||||||
|
|
|
@ -44,7 +44,7 @@ def test_check_valid_ip(
|
||||||
|
|
||||||
request = Request(client_ip)
|
request = Request(client_ip)
|
||||||
|
|
||||||
assert _check_valid_ip(allowed_ips, request) == expected_result # type: ignore
|
assert _check_valid_ip(allowed_ips, request)[0] == expected_result # type: ignore
|
||||||
|
|
||||||
|
|
||||||
# test x-forwarder for is used when user has opted in
|
# test x-forwarder for is used when user has opted in
|
||||||
|
@ -72,7 +72,7 @@ def test_check_valid_ip_sent_with_x_forwarded_for(
|
||||||
|
|
||||||
request = Request(client_ip, headers={"X-Forwarded-For": client_ip})
|
request = Request(client_ip, headers={"X-Forwarded-For": client_ip})
|
||||||
|
|
||||||
assert _check_valid_ip(allowed_ips, request, use_x_forwarded_for=True) == expected_result # type: ignore
|
assert _check_valid_ip(allowed_ips, request, use_x_forwarded_for=True)[0] == expected_result # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue