forked from phoenix/litellm-mirror
Merge pull request #4615 from BerriAI/litellm_user_api_key_auth
Enable `allowed_ip's` for proxy
This commit is contained in:
commit
a986413df3
5 changed files with 93 additions and 2 deletions
|
@ -1108,7 +1108,10 @@ class AzureChatCompletion(BaseLLM):
|
||||||
"api-key": api_key,
|
"api-key": api_key,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
if "operation-location" in response.headers:
|
||||||
operation_location_url = response.headers["operation-location"]
|
operation_location_url = response.headers["operation-location"]
|
||||||
|
else:
|
||||||
|
raise AzureOpenAIError(status_code=500, message=response.text)
|
||||||
response = await async_handler.get(
|
response = await async_handler.get(
|
||||||
url=operation_location_url,
|
url=operation_location_url,
|
||||||
headers={
|
headers={
|
||||||
|
@ -1220,7 +1223,10 @@ class AzureChatCompletion(BaseLLM):
|
||||||
"api-key": api_key,
|
"api-key": api_key,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
if "operation-location" in response.headers:
|
||||||
operation_location_url = response.headers["operation-location"]
|
operation_location_url = response.headers["operation-location"]
|
||||||
|
else:
|
||||||
|
raise AzureOpenAIError(status_code=500, message=response.text)
|
||||||
response = sync_handler.get(
|
response = sync_handler.get(
|
||||||
url=operation_location_url,
|
url=operation_location_url,
|
||||||
headers={
|
headers={
|
||||||
|
|
|
@ -19,3 +19,4 @@ model_list:
|
||||||
general_settings:
|
general_settings:
|
||||||
alerting: ["slack"]
|
alerting: ["slack"]
|
||||||
alerting_threshold: 10
|
alerting_threshold: 10
|
||||||
|
allowed_ips: ["192.168.1.1"]
|
|
@ -136,6 +136,19 @@ async def user_api_key_auth(
|
||||||
enable_jwt_auth: true
|
enable_jwt_auth: true
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
### FILTER IP ADDRESS ###
|
||||||
|
|
||||||
|
is_valid_ip = _check_valid_ip(
|
||||||
|
allowed_ips=general_settings.get("allowed_ips", None), request=request
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_valid_ip:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Access forbidden: IP address not allowed.",
|
||||||
|
)
|
||||||
|
|
||||||
route: str = request.url.path
|
route: str = request.url.path
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -1208,3 +1221,22 @@ def _get_user_role(user_id_information: Optional[list]):
|
||||||
|
|
||||||
_user = user_id_information[0]
|
_user = user_id_information[0]
|
||||||
return _user.get("user_role")
|
return _user.get("user_role")
|
||||||
|
|
||||||
|
|
||||||
|
def _check_valid_ip(allowed_ips: Optional[List[str]], request: Request) -> bool:
|
||||||
|
"""
|
||||||
|
Returns if ip is allowed or not
|
||||||
|
"""
|
||||||
|
if allowed_ips is None: # if not set, assume true
|
||||||
|
return True
|
||||||
|
|
||||||
|
if request.client is not None:
|
||||||
|
client_ip = request.client.host
|
||||||
|
else:
|
||||||
|
client_ip = None
|
||||||
|
|
||||||
|
# Check if IP address is allowed
|
||||||
|
if client_ip not in allowed_ips:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
|
@ -1690,6 +1690,12 @@ class ProxyConfig:
|
||||||
ui_access_mode = general_settings.get(
|
ui_access_mode = general_settings.get(
|
||||||
"ui_access_mode", "all"
|
"ui_access_mode", "all"
|
||||||
) # can be either ["admin_only" or "all"]
|
) # can be either ["admin_only" or "all"]
|
||||||
|
### ALLOWED IP ###
|
||||||
|
allowed_ips = general_settings.get("allowed_ips", None)
|
||||||
|
if allowed_ips is not None and premium_user is False:
|
||||||
|
raise ValueError(
|
||||||
|
"allowed_ips is an Enterprise Feature. Please add a valid LITELLM_LICENSE to your envionment."
|
||||||
|
)
|
||||||
## BUDGET RESCHEDULER ##
|
## BUDGET RESCHEDULER ##
|
||||||
proxy_budget_rescheduler_min_time = general_settings.get(
|
proxy_budget_rescheduler_min_time = general_settings.get(
|
||||||
"proxy_budget_rescheduler_min_time", proxy_budget_rescheduler_min_time
|
"proxy_budget_rescheduler_min_time", proxy_budget_rescheduler_min_time
|
||||||
|
|
46
litellm/tests/test_user_api_key_auth.py
Normal file
46
litellm/tests/test_user_api_key_auth.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
# What is this?
|
||||||
|
## Unit tests for user_api_key_auth helper functions
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
from typing import List, Optional
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
|
||||||
|
class Request:
|
||||||
|
def __init__(self, client_ip: Optional[str] = None):
|
||||||
|
self.client = MagicMock()
|
||||||
|
self.client.host = client_ip
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"allowed_ips, client_ip, expected_result",
|
||||||
|
[
|
||||||
|
(None, "127.0.0.1", True), # No IP restrictions, should be allowed
|
||||||
|
(["127.0.0.1"], "127.0.0.1", True), # IP in allowed list
|
||||||
|
(["192.168.1.1"], "127.0.0.1", False), # IP not in allowed list
|
||||||
|
([], "127.0.0.1", False), # Empty allowed list, no IP should be allowed
|
||||||
|
(["192.168.1.1", "10.0.0.1"], "10.0.0.1", True), # IP in allowed list
|
||||||
|
(
|
||||||
|
["192.168.1.1"],
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
), # Request with no client IP should not be allowed
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_check_valid_ip(
|
||||||
|
allowed_ips: Optional[List[str]], client_ip: Optional[str], expected_result: bool
|
||||||
|
):
|
||||||
|
from litellm.proxy.auth.user_api_key_auth import _check_valid_ip
|
||||||
|
|
||||||
|
request = Request(client_ip)
|
||||||
|
|
||||||
|
assert _check_valid_ip(allowed_ips, request) == expected_result # type: ignore
|
Loading…
Add table
Add a link
Reference in a new issue