mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
add test for test_check_valid_ip_sent_with_x_forwarded_for
This commit is contained in:
parent
b94c982ec9
commit
d0e6ca659f
2 changed files with 32 additions and 3 deletions
|
@ -1221,7 +1221,7 @@ def _check_valid_ip(
|
||||||
|
|
||||||
# if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
|
# if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
|
||||||
client_ip = None
|
client_ip = None
|
||||||
if use_x_forwarded_for is True:
|
if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
|
||||||
client_ip = request.headers["x-forwarded-for"]
|
client_ip = request.headers["x-forwarded-for"]
|
||||||
elif request.client is not None:
|
elif request.client is not None:
|
||||||
client_ip = request.client.host
|
client_ip = request.client.host
|
||||||
|
|
|
@ -7,7 +7,7 @@ import sys
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
from typing import List, Optional
|
from typing import Dict, List, Optional
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -16,9 +16,10 @@ import litellm
|
||||||
|
|
||||||
|
|
||||||
class Request:
|
class Request:
|
||||||
def __init__(self, client_ip: Optional[str] = None):
|
def __init__(self, client_ip: Optional[str] = None, headers: Optional[dict] = None):
|
||||||
self.client = MagicMock()
|
self.client = MagicMock()
|
||||||
self.client.host = client_ip
|
self.client.host = client_ip
|
||||||
|
self.headers: Dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -46,6 +47,34 @@ def test_check_valid_ip(
|
||||||
assert _check_valid_ip(allowed_ips, request) == expected_result # type: ignore
|
assert _check_valid_ip(allowed_ips, request) == expected_result # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
# test x-forwarder for is used when user has opted in
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"allowed_ips, client_ip, expected_result",
|
||||||
|
[
|
||||||
|
(None, "127.0.0.1", True), # No IP restrictions, should be allowed
|
||||||
|
(["127.0.0.1"], "127.0.0.1", True), # IP in allowed list
|
||||||
|
(["192.168.1.1"], "127.0.0.1", False), # IP not in allowed list
|
||||||
|
([], "127.0.0.1", False), # Empty allowed list, no IP should be allowed
|
||||||
|
(["192.168.1.1", "10.0.0.1"], "10.0.0.1", True), # IP in allowed list
|
||||||
|
(
|
||||||
|
["192.168.1.1"],
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
), # Request with no client IP should not be allowed
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_check_valid_ip_sent_with_x_forwarded_for(
|
||||||
|
allowed_ips: Optional[List[str]], client_ip: Optional[str], expected_result: bool
|
||||||
|
):
|
||||||
|
from litellm.proxy.auth.user_api_key_auth import _check_valid_ip
|
||||||
|
|
||||||
|
request = Request(client_ip, headers={"X-Forwarded-For": client_ip})
|
||||||
|
|
||||||
|
assert _check_valid_ip(allowed_ips, request, use_x_forwarded_for=True) == expected_result # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_check_blocked_team():
|
async def test_check_blocked_team():
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue