From f3cc57bc6f49b5deb80f4c80b803ba3fcbe0242c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 8 Jul 2024 15:58:15 -0700 Subject: [PATCH 1/3] feat(user_api_key_auth.py): allow restricting calls by IP address Allows admin to restrict which IP addresses can make calls to the proxy --- litellm/proxy/auth/user_api_key_auth.py | 32 +++++++++++++++++ litellm/tests/test_user_api_key_auth.py | 46 +++++++++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 litellm/tests/test_user_api_key_auth.py diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index bc774816f..c3c6f9182 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -136,6 +136,19 @@ async def user_api_key_auth( enable_jwt_auth: true ``` """ + + ### FILTER IP ADDRESS ### + + is_valid_ip = _check_valid_ip( + allowed_ips=general_settings.get("allowed_ips", None), request=request + ) + + if not is_valid_ip: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access forbidden: IP address not allowed.", + ) + route: str = request.url.path if ( @@ -1208,3 +1221,22 @@ def _get_user_role(user_id_information: Optional[list]): _user = user_id_information[0] return _user.get("user_role") + + +def _check_valid_ip(allowed_ips: Optional[List[str]], request: Request) -> bool: + """ + Returns if ip is allowed or not + """ + if allowed_ips is None: # if not set, assume true + return True + + if request.client is not None: + client_ip = request.client.host + else: + client_ip = None + + # Check if IP address is allowed + if client_ip not in allowed_ips: + return False + + return True diff --git a/litellm/tests/test_user_api_key_auth.py b/litellm/tests/test_user_api_key_auth.py new file mode 100644 index 000000000..8d3a0af53 --- /dev/null +++ b/litellm/tests/test_user_api_key_auth.py @@ -0,0 +1,46 @@ +# What is this? +## Unit tests for user_api_key_auth helper functions + +import os +import sys + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +from typing import List, Optional +from unittest.mock import MagicMock + +import pytest + +import litellm + + +class Request: + def __init__(self, client_ip: Optional[str] = None): + self.client = MagicMock() + self.client.host = client_ip + + +@pytest.mark.parametrize( + "allowed_ips, client_ip, expected_result", + [ + (None, "127.0.0.1", True), # No IP restrictions, should be allowed + (["127.0.0.1"], "127.0.0.1", True), # IP in allowed list + (["192.168.1.1"], "127.0.0.1", False), # IP not in allowed list + ([], "127.0.0.1", False), # Empty allowed list, no IP should be allowed + (["192.168.1.1", "10.0.0.1"], "10.0.0.1", True), # IP in allowed list + ( + ["192.168.1.1"], + None, + False, + ), # Request with no client IP should not be allowed + ], +) +def test_check_valid_ip( + allowed_ips: Optional[List[str]], client_ip: Optional[str], expected_result: bool +): + from litellm.proxy.auth.user_api_key_auth import _check_valid_ip + + request = Request(client_ip) + + assert _check_valid_ip(allowed_ips, request) == expected_result # type: ignore From 3045a2d9b3e0418f2cd42b8d13de68b9b48275f2 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 8 Jul 2024 16:04:44 -0700 Subject: [PATCH 2/3] fix(proxy_server.py): add license protection for 'allowed_ip' address feature --- litellm/proxy/_new_secret_config.yaml | 1 + litellm/proxy/proxy_server.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 1c0b70d79..956234f85 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -19,3 +19,4 @@ model_list: general_settings: alerting: ["slack"] alerting_threshold: 10 + allowed_ips: ["192.168.1.1"] \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 08994a957..58da69858 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1690,6 +1690,12 @@ class ProxyConfig: ui_access_mode = general_settings.get( "ui_access_mode", "all" ) # can be either ["admin_only" or "all"] + ### ALLOWED IP ### + allowed_ips = general_settings.get("allowed_ips", None) + if allowed_ips is not None and premium_user is False: + raise ValueError( + "allowed_ips is an Enterprise Feature. Please add a valid LITELLM_LICENSE to your envionment." + ) ## BUDGET RESCHEDULER ## proxy_budget_rescheduler_min_time = general_settings.get( "proxy_budget_rescheduler_min_time", proxy_budget_rescheduler_min_time From 381347082dbde57d7b60a7fdb1ff8b74b6746c42 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 8 Jul 2024 16:39:05 -0700 Subject: [PATCH 3/3] fix(azure.py): improve error handling for azure image gen responses --- litellm/llms/azure.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index f905d17a3..a2928cf20 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -1108,7 +1108,10 @@ class AzureChatCompletion(BaseLLM): "api-key": api_key, }, ) - operation_location_url = response.headers["operation-location"] + if "operation-location" in response.headers: + operation_location_url = response.headers["operation-location"] + else: + raise AzureOpenAIError(status_code=500, message=response.text) response = await async_handler.get( url=operation_location_url, headers={ @@ -1220,7 +1223,10 @@ class AzureChatCompletion(BaseLLM): "api-key": api_key, }, ) - operation_location_url = response.headers["operation-location"] + if "operation-location" in response.headers: + operation_location_url = response.headers["operation-location"] + else: + raise AzureOpenAIError(status_code=500, message=response.text) response = sync_handler.get( url=operation_location_url, headers={