add test for test_check_valid_ip_sent_with_x_forwarded_for

This commit is contained in:
Ishaan Jaff 2024-08-13 15:54:53 -07:00
parent b94c982ec9
commit d0e6ca659f
2 changed files with 32 additions and 3 deletions

View file

@ -1221,7 +1221,7 @@ def _check_valid_ip(
# if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for # if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
client_ip = None client_ip = None
if use_x_forwarded_for is True: if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
client_ip = request.headers["x-forwarded-for"] client_ip = request.headers["x-forwarded-for"]
elif request.client is not None: elif request.client is not None:
client_ip = request.client.host client_ip = request.client.host

View file

@ -7,7 +7,7 @@ import sys
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
from typing import List, Optional from typing import Dict, List, Optional
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
@ -16,9 +16,10 @@ import litellm
class Request: class Request:
def __init__(self, client_ip: Optional[str] = None): def __init__(self, client_ip: Optional[str] = None, headers: Optional[dict] = None):
self.client = MagicMock() self.client = MagicMock()
self.client.host = client_ip self.client.host = client_ip
self.headers: Dict[str, str] = {}
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -46,6 +47,34 @@ def test_check_valid_ip(
assert _check_valid_ip(allowed_ips, request) == expected_result # type: ignore assert _check_valid_ip(allowed_ips, request) == expected_result # type: ignore
# test x-forwarder for is used when user has opted in
@pytest.mark.parametrize(
"allowed_ips, client_ip, expected_result",
[
(None, "127.0.0.1", True), # No IP restrictions, should be allowed
(["127.0.0.1"], "127.0.0.1", True), # IP in allowed list
(["192.168.1.1"], "127.0.0.1", False), # IP not in allowed list
([], "127.0.0.1", False), # Empty allowed list, no IP should be allowed
(["192.168.1.1", "10.0.0.1"], "10.0.0.1", True), # IP in allowed list
(
["192.168.1.1"],
None,
False,
), # Request with no client IP should not be allowed
],
)
def test_check_valid_ip_sent_with_x_forwarded_for(
allowed_ips: Optional[List[str]], client_ip: Optional[str], expected_result: bool
):
from litellm.proxy.auth.user_api_key_auth import _check_valid_ip
request = Request(client_ip, headers={"X-Forwarded-For": client_ip})
assert _check_valid_ip(allowed_ips, request, use_x_forwarded_for=True) == expected_result # type: ignore
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_blocked_team(): async def test_check_blocked_team():
""" """