forked from phoenix/litellm-mirror
feat(user_api_key_auth.py): allow restricting calls by IP address
Allows admin to restrict which IP addresses can make calls to the proxy
This commit is contained in:
parent
95739c3778
commit
f3cc57bc6f
2 changed files with 78 additions and 0 deletions
|
@ -136,6 +136,19 @@ async def user_api_key_auth(
|
|||
enable_jwt_auth: true
|
||||
```
|
||||
"""
|
||||
|
||||
### FILTER IP ADDRESS ###
|
||||
|
||||
is_valid_ip = _check_valid_ip(
|
||||
allowed_ips=general_settings.get("allowed_ips", None), request=request
|
||||
)
|
||||
|
||||
if not is_valid_ip:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access forbidden: IP address not allowed.",
|
||||
)
|
||||
|
||||
route: str = request.url.path
|
||||
|
||||
if (
|
||||
|
@ -1208,3 +1221,22 @@ def _get_user_role(user_id_information: Optional[list]):
|
|||
|
||||
_user = user_id_information[0]
|
||||
return _user.get("user_role")
|
||||
|
||||
|
||||
def _check_valid_ip(allowed_ips: Optional[List[str]], request: Request) -> bool:
|
||||
"""
|
||||
Returns if ip is allowed or not
|
||||
"""
|
||||
if allowed_ips is None: # if not set, assume true
|
||||
return True
|
||||
|
||||
if request.client is not None:
|
||||
client_ip = request.client.host
|
||||
else:
|
||||
client_ip = None
|
||||
|
||||
# Check if IP address is allowed
|
||||
if client_ip not in allowed_ips:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
|
46
litellm/tests/test_user_api_key_auth.py
Normal file
46
litellm/tests/test_user_api_key_auth.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
# What is this?
|
||||
## Unit tests for user_api_key_auth helper functions
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from typing import List, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
class Request:
|
||||
def __init__(self, client_ip: Optional[str] = None):
|
||||
self.client = MagicMock()
|
||||
self.client.host = client_ip
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"allowed_ips, client_ip, expected_result",
|
||||
[
|
||||
(None, "127.0.0.1", True), # No IP restrictions, should be allowed
|
||||
(["127.0.0.1"], "127.0.0.1", True), # IP in allowed list
|
||||
(["192.168.1.1"], "127.0.0.1", False), # IP not in allowed list
|
||||
([], "127.0.0.1", False), # Empty allowed list, no IP should be allowed
|
||||
(["192.168.1.1", "10.0.0.1"], "10.0.0.1", True), # IP in allowed list
|
||||
(
|
||||
["192.168.1.1"],
|
||||
None,
|
||||
False,
|
||||
), # Request with no client IP should not be allowed
|
||||
],
|
||||
)
|
||||
def test_check_valid_ip(
|
||||
allowed_ips: Optional[List[str]], client_ip: Optional[str], expected_result: bool
|
||||
):
|
||||
from litellm.proxy.auth.user_api_key_auth import _check_valid_ip
|
||||
|
||||
request = Request(client_ip)
|
||||
|
||||
assert _check_valid_ip(allowed_ips, request) == expected_result # type: ignore
|
Loading…
Add table
Add a link
Reference in a new issue