From d0e6ca659fb92aae5427ba7da771b5164fb86e07 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 13 Aug 2024 15:54:53 -0700 Subject: [PATCH] add test for test_check_valid_ip_sent_with_x_forwarded_for --- litellm/proxy/auth/user_api_key_auth.py | 2 +- litellm/tests/test_user_api_key_auth.py | 33 +++++++++++++++++++++++-- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 5df9674172..9bbbc1a430 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -1221,7 +1221,7 @@ def _check_valid_ip( # if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for client_ip = None - if use_x_forwarded_for is True: + if use_x_forwarded_for is True and "x-forwarded-for" in request.headers: client_ip = request.headers["x-forwarded-for"] elif request.client is not None: client_ip = request.client.host diff --git a/litellm/tests/test_user_api_key_auth.py b/litellm/tests/test_user_api_key_auth.py index 33f055b37d..ad057ee572 100644 --- a/litellm/tests/test_user_api_key_auth.py +++ b/litellm/tests/test_user_api_key_auth.py @@ -7,7 +7,7 @@ import sys sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -from typing import List, Optional +from typing import Dict, List, Optional from unittest.mock import MagicMock import pytest @@ -16,9 +16,10 @@ import litellm class Request: - def __init__(self, client_ip: Optional[str] = None): + def __init__(self, client_ip: Optional[str] = None, headers: Optional[dict] = None): self.client = MagicMock() self.client.host = client_ip + self.headers: Dict[str, str] = {} @pytest.mark.parametrize( @@ -46,6 +47,34 @@ def test_check_valid_ip( assert _check_valid_ip(allowed_ips, request) == expected_result # type: ignore +# test x-forwarder for is used when user has opted in + + +@pytest.mark.parametrize( + "allowed_ips, client_ip, expected_result", + [ + (None, "127.0.0.1", True), # No IP restrictions, should be allowed + (["127.0.0.1"], "127.0.0.1", True), # IP in allowed list + (["192.168.1.1"], "127.0.0.1", False), # IP not in allowed list + ([], "127.0.0.1", False), # Empty allowed list, no IP should be allowed + (["192.168.1.1", "10.0.0.1"], "10.0.0.1", True), # IP in allowed list + ( + ["192.168.1.1"], + None, + False, + ), # Request with no client IP should not be allowed + ], +) +def test_check_valid_ip_sent_with_x_forwarded_for( + allowed_ips: Optional[List[str]], client_ip: Optional[str], expected_result: bool +): + from litellm.proxy.auth.user_api_key_auth import _check_valid_ip + + request = Request(client_ip, headers={"X-Forwarded-For": client_ip}) + + assert _check_valid_ip(allowed_ips, request, use_x_forwarded_for=True) == expected_result # type: ignore + + @pytest.mark.asyncio async def test_check_blocked_team(): """