mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
allow setting allowed routes on proxy
This commit is contained in:
parent
36ce43ed95
commit
253ef5f995
4 changed files with 122 additions and 74 deletions
|
@ -40,22 +40,6 @@ else:
|
||||||
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
|
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
|
||||||
|
|
||||||
|
|
||||||
def is_request_body_safe(request_body: dict) -> bool:
|
|
||||||
"""
|
|
||||||
Check if the request body is safe.
|
|
||||||
|
|
||||||
A malicious user can set the api_base to their own domain and invoke POST /chat/completions to intercept and steal the OpenAI API key.
|
|
||||||
Relevant issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997
|
|
||||||
"""
|
|
||||||
banned_params = ["api_base", "base_url"]
|
|
||||||
|
|
||||||
for param in banned_params:
|
|
||||||
if param in request_body:
|
|
||||||
raise ValueError(f"BadRequest: {param} is not allowed in request body")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def common_checks(
|
def common_checks(
|
||||||
request_body: dict,
|
request_body: dict,
|
||||||
team_object: Optional[LiteLLM_TeamTable],
|
team_object: Optional[LiteLLM_TeamTable],
|
||||||
|
|
|
@ -1,13 +1,123 @@
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import HTTPException, Request, status
|
||||||
|
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.proxy._types import *
|
from litellm.proxy._types import *
|
||||||
|
|
||||||
|
|
||||||
|
def _get_request_ip_address(
|
||||||
|
request: Request, use_x_forwarded_for: Optional[bool] = False
|
||||||
|
) -> Optional[str]:
|
||||||
|
|
||||||
|
client_ip = None
|
||||||
|
if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
|
||||||
|
client_ip = request.headers["x-forwarded-for"]
|
||||||
|
elif request.client is not None:
|
||||||
|
client_ip = request.client.host
|
||||||
|
else:
|
||||||
|
client_ip = ""
|
||||||
|
|
||||||
|
return client_ip
|
||||||
|
|
||||||
|
|
||||||
|
def _check_valid_ip(
|
||||||
|
allowed_ips: Optional[List[str]],
|
||||||
|
request: Request,
|
||||||
|
use_x_forwarded_for: Optional[bool] = False,
|
||||||
|
) -> Tuple[bool, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Returns if ip is allowed or not
|
||||||
|
"""
|
||||||
|
if allowed_ips is None: # if not set, assume true
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
# if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
|
||||||
|
client_ip = _get_request_ip_address(
|
||||||
|
request=request, use_x_forwarded_for=use_x_forwarded_for
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if IP address is allowed
|
||||||
|
if client_ip not in allowed_ips:
|
||||||
|
return False, client_ip
|
||||||
|
|
||||||
|
return True, client_ip
|
||||||
|
|
||||||
|
|
||||||
|
def is_request_body_safe(request_body: dict) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the request body is safe.
|
||||||
|
|
||||||
|
A malicious user can set the api_base to their own domain and invoke POST /chat/completions to intercept and steal the OpenAI API key.
|
||||||
|
Relevant issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997
|
||||||
|
"""
|
||||||
|
banned_params = ["api_base", "base_url"]
|
||||||
|
|
||||||
|
for param in banned_params:
|
||||||
|
if param in request_body:
|
||||||
|
raise ValueError(f"BadRequest: {param} is not allowed in request body")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def pre_db_read_auth_checks(
|
||||||
|
request: Request,
|
||||||
|
request_data: dict,
|
||||||
|
route: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
1. Checks if request size is under max_request_size_mb (if set)
|
||||||
|
2. Check if request body is safe (example user has not set api_base in request body)
|
||||||
|
3. Check if IP address is allowed (if set)
|
||||||
|
4. Check if request route is an allowed route on the proxy (if set)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- True
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
- HTTPException if request fails initial auth checks
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import general_settings, premium_user
|
||||||
|
|
||||||
|
# Check 1. request size
|
||||||
|
await check_if_request_size_is_safe(request=request)
|
||||||
|
|
||||||
|
# Check 2. Request body is safe
|
||||||
|
is_request_body_safe(request_body=request_data)
|
||||||
|
|
||||||
|
# Check 3. Check if IP address is allowed
|
||||||
|
is_valid_ip, passed_in_ip = _check_valid_ip(
|
||||||
|
allowed_ips=general_settings.get("allowed_ips", None),
|
||||||
|
use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_valid_ip:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check 4. Check if request route is an allowed route on the proxy
|
||||||
|
if "allowed_routes" in general_settings:
|
||||||
|
_allowed_routes = general_settings["allowed_routes"]
|
||||||
|
if premium_user is not True:
|
||||||
|
verbose_proxy_logger.error(
|
||||||
|
f"Trying to set allowed_routes. This is an Enterprise feature. {CommonProxyErrors.not_premium_user.value}"
|
||||||
|
)
|
||||||
|
if route not in _allowed_routes:
|
||||||
|
verbose_proxy_logger.error(
|
||||||
|
f"Route {route} not in allowed_routes={_allowed_routes}"
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=f"Access forbidden: Route {route} not allowed",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def route_in_additonal_public_routes(current_route: str):
|
def route_in_additonal_public_routes(current_route: str):
|
||||||
"""
|
"""
|
||||||
Helper to check if the user defined public_routes on config.yaml
|
Helper to check if the user defined public_routes on config.yaml
|
||||||
|
|
|
@ -54,14 +54,15 @@ from litellm.proxy.auth.auth_checks import (
|
||||||
get_org_object,
|
get_org_object,
|
||||||
get_team_object,
|
get_team_object,
|
||||||
get_user_object,
|
get_user_object,
|
||||||
is_request_body_safe,
|
|
||||||
log_to_opentelemetry,
|
log_to_opentelemetry,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.auth_utils import (
|
from litellm.proxy.auth.auth_utils import (
|
||||||
|
_get_request_ip_address,
|
||||||
check_if_request_size_is_safe,
|
check_if_request_size_is_safe,
|
||||||
get_request_route,
|
get_request_route,
|
||||||
is_llm_api_route,
|
is_llm_api_route,
|
||||||
is_pass_through_provider_route,
|
is_pass_through_provider_route,
|
||||||
|
pre_db_read_auth_checks,
|
||||||
route_in_additonal_public_routes,
|
route_in_additonal_public_routes,
|
||||||
should_run_auth_on_pass_through_provider_route,
|
should_run_auth_on_pass_through_provider_route,
|
||||||
)
|
)
|
||||||
|
@ -128,25 +129,11 @@ async def user_api_key_auth(
|
||||||
route: str = get_request_route(request=request)
|
route: str = get_request_route(request=request)
|
||||||
# get the request body
|
# get the request body
|
||||||
request_data = await _read_request_body(request=request)
|
request_data = await _read_request_body(request=request)
|
||||||
is_request_body_safe(request_body=request_data)
|
await pre_db_read_auth_checks(
|
||||||
|
request_data=request_data,
|
||||||
### LiteLLM Enterprise Security Checks
|
|
||||||
# Check 1. Check if request size is under max_request_size_mb
|
|
||||||
# Check 2. FILTER IP ADDRESS
|
|
||||||
await check_if_request_size_is_safe(request=request)
|
|
||||||
|
|
||||||
is_valid_ip, passed_in_ip = _check_valid_ip(
|
|
||||||
allowed_ips=general_settings.get("allowed_ips", None),
|
|
||||||
use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
|
|
||||||
request=request,
|
request=request,
|
||||||
|
route=route,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not is_valid_ip:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
|
|
||||||
)
|
|
||||||
|
|
||||||
pass_through_endpoints: Optional[List[dict]] = general_settings.get(
|
pass_through_endpoints: Optional[List[dict]] = general_settings.get(
|
||||||
"pass_through_endpoints", None
|
"pass_through_endpoints", None
|
||||||
)
|
)
|
||||||
|
@ -200,6 +187,7 @@ async def user_api_key_auth(
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
######## Route Checks Before Reading DB / Cache for "token" ################
|
||||||
if (
|
if (
|
||||||
route in LiteLLMRoutes.public_routes.value
|
route in LiteLLMRoutes.public_routes.value
|
||||||
or route_in_additonal_public_routes(current_route=route)
|
or route_in_additonal_public_routes(current_route=route)
|
||||||
|
@ -211,6 +199,9 @@ async def user_api_key_auth(
|
||||||
return UserAPIKeyAuth(
|
return UserAPIKeyAuth(
|
||||||
user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||||
)
|
)
|
||||||
|
|
||||||
|
########## End of Route Checks Before Reading DB / Cache for "token" ########
|
||||||
|
|
||||||
if general_settings.get("enable_oauth2_auth", False) is True:
|
if general_settings.get("enable_oauth2_auth", False) is True:
|
||||||
# return UserAPIKeyAuth object
|
# return UserAPIKeyAuth object
|
||||||
# helper to check if the api_key is a valid oauth2 token
|
# helper to check if the api_key is a valid oauth2 token
|
||||||
|
@ -1282,44 +1273,6 @@ def _get_user_role(
|
||||||
return role
|
return role
|
||||||
|
|
||||||
|
|
||||||
def _get_request_ip_address(
|
|
||||||
request: Request, use_x_forwarded_for: Optional[bool] = False
|
|
||||||
) -> Optional[str]:
|
|
||||||
|
|
||||||
client_ip = None
|
|
||||||
if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
|
|
||||||
client_ip = request.headers["x-forwarded-for"]
|
|
||||||
elif request.client is not None:
|
|
||||||
client_ip = request.client.host
|
|
||||||
else:
|
|
||||||
client_ip = ""
|
|
||||||
|
|
||||||
return client_ip
|
|
||||||
|
|
||||||
|
|
||||||
def _check_valid_ip(
|
|
||||||
allowed_ips: Optional[List[str]],
|
|
||||||
request: Request,
|
|
||||||
use_x_forwarded_for: Optional[bool] = False,
|
|
||||||
) -> Tuple[bool, Optional[str]]:
|
|
||||||
"""
|
|
||||||
Returns if ip is allowed or not
|
|
||||||
"""
|
|
||||||
if allowed_ips is None: # if not set, assume true
|
|
||||||
return True, None
|
|
||||||
|
|
||||||
# if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
|
|
||||||
client_ip = _get_request_ip_address(
|
|
||||||
request=request, use_x_forwarded_for=use_x_forwarded_for
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if IP address is allowed
|
|
||||||
if client_ip not in allowed_ips:
|
|
||||||
return False, client_ip
|
|
||||||
|
|
||||||
return True, client_ip
|
|
||||||
|
|
||||||
|
|
||||||
def get_api_key_from_custom_header(
|
def get_api_key_from_custom_header(
|
||||||
request: Request, custom_litellm_key_header_name: str
|
request: Request, custom_litellm_key_header_name: str
|
||||||
):
|
):
|
||||||
|
|
|
@ -12,4 +12,5 @@ litellm_settings:
|
||||||
|
|
||||||
general_settings:
|
general_settings:
|
||||||
master_key: sk-1234
|
master_key: sk-1234
|
||||||
custom_auth: example_config_yaml.custom_auth_basic.user_api_key_auth
|
custom_auth: example_config_yaml.custom_auth_basic.user_api_key_auth
|
||||||
|
allowed_routes: []
|
Loading…
Add table
Add a link
Reference in a new issue