fix(proxy_server.py): prevent non-admins from creating new keys

This commit is contained in:
Krrish Dholakia 2024-04-16 11:21:38 -07:00
parent 459c0e38d7
commit ffd3a96fcf
4 changed files with 50 additions and 61 deletions

View file

@ -10,9 +10,9 @@
"supports_function_calling": true
},
"gpt-4-turbo-preview": {
"max_tokens": 4096,
"max_input_tokens": 8192,
"max_output_tokens": 4096,
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003,
"litellm_provider": "openai",

View file

@ -103,6 +103,26 @@ class LiteLLMRoutes(enum.Enum):
"/model/info",
]
spend_tracking_routes: List = [
# spend
"/spend/keys",
"/spend/users",
"/spend/tags",
"/spend/calculate",
"/spend/logs",
]
global_spend_tracking_routes: List = [
# global spend
"/global/spend/logs",
"/global/spend",
"/global/spend/keys",
"/global/spend/teams",
"/global/spend/end_users",
"/global/spend/models",
"/global/predict/spend/logs",
]
public_routes: List = [
"/routes",
"/",
@ -114,6 +134,18 @@ class LiteLLMRoutes(enum.Enum):
]
# class LiteLLMAllowedRoutes(LiteLLMBase):
# """
# Defines allowed routes based on key type.
# Types = ["admin", "team", "user", "unmapped"]
# """
# admin_allowed_routes: List[
# Literal["openai_routes", "info_routes", "management_routes", "spend_tracking_routes", "global_spend_tracking_routes"]
# ] = ["management_routes"]
class LiteLLM_JWTAuth(LiteLLMBase):
"""
A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.

View file

@ -104,6 +104,13 @@ def common_checks(
def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
"""
Return if a user is allowed to access route. Helper function for `allowed_routes_check`.
Parameters:
- user_route: str - the route the user is trying to call
- allowed_routes: List[str|LiteLLMRoutes] - the list of allowed routes for the user.
"""
for allowed_route in allowed_routes:
if (
allowed_route == LiteLLMRoutes.openai_routes.name
@ -126,7 +133,7 @@ def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
def allowed_routes_check(
user_role: Literal["proxy_admin", "team"],
user_role: Literal["proxy_admin", "team", "user"],
user_route: str,
litellm_proxy_roles: LiteLLM_JWTAuth,
) -> bool:

View file

@ -625,6 +625,7 @@ async def user_api_key_auth(
# 6. If token spend per model is under budget per model
# 7. If token spend is under team budget
# 8. If team spend is under team budget
request_data = await _read_request_body(
request=request
) # request data, used across all checks. Making this easily available
@ -1009,23 +1010,9 @@ async def user_api_key_auth(
db=custom_db_client,
)
)
if (
(
route.startswith("/key/")
or route.startswith("/user/")
or route.startswith("/model/")
or route.startswith("/spend/")
)
and (not is_master_key_valid)
and (not _is_user_proxy_admin(user_id_information))
):
allow_user_auth = False
if (
general_settings.get("allow_user_auth", False) == True
or _has_user_setup_sso() == True
):
allow_user_auth = True # user can create and delete their own keys
# enters this block when allow_user_auth is set to False
if route in LiteLLMRoutes.info_routes.value and (
not _is_user_proxy_admin(user_id_information)
): # check if user allowed to call an info route
if route == "/key/info":
# check if user can access this route
query_params = request.query_params
@ -1050,47 +1037,12 @@ async def user_api_key_auth(
status_code=status.HTTP_403_FORBIDDEN,
detail="key not allowed to access this user's info",
)
elif route == "/user/update":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="only proxy admin can update user settings. Tried calling `/user/update`",
)
elif route == "/model/info":
# /model/info just shows models user has access to
pass
elif route == "/user/request_model":
pass # this allows any user to request a model through the UI
elif allow_user_auth == True and route == "/key/generate":
pass
elif allow_user_auth == True and route == "/key/delete":
pass
elif route == "/spend/logs":
# check if user can access this route
# user can only access this route if
# - api_key they need logs for has the same user_id as the one used for auth
query_params = request.query_params
if query_params.get("api_key") is not None:
api_key = query_params.get("api_key")
token_info = await prisma_client.get_data(
token=api_key, table_name="key", query_type="find_unique"
)
if secrets.compare_digest(
token_info.user_id, valid_token.user_id
):
pass
elif query_params.get("user_id") is not None:
user_id = query_params.get("user_id")
# check if user id == token.user_id
if secrets.compare_digest(user_id, valid_token.user_id):
pass
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="user not allowed to access this key's info",
)
else:
raise Exception(
f"Only master key can be used to generate, delete, update or get info for new keys/users. Value of allow_user_auth={allow_user_auth}"
f"Only master key can be used to generate, delete, update info for new keys/users."
)
# check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions
@ -2463,7 +2415,7 @@ class ProxyConfig:
if m.model_info is not None and isinstance(m.model_info, dict):
if "id" not in m.model_info:
m.model_info["id"] = m.model_id
combined_id_list.append(m.model_info)
combined_id_list.append(m.model_id)
else:
combined_id_list.append(m.model_id)
## CONFIG MODELS ##
@ -8147,8 +8099,6 @@ async def auth_callback(request: Request):
algorithm="HS256",
)
litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token
# if a user has logged in they should be allowed to create keys - this ensures that it's set to True
general_settings["allow_user_auth"] = True
return RedirectResponse(url=litellm_dashboard_ui)
@ -8860,7 +8810,7 @@ async def get_routes():
@router.get("/token/generate", dependencies=[Depends(user_api_key_auth)])
async def token_generate():
"""
Test endpoint. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
Test endpoint. Admin-only access. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
"""
# Initialize AuthJWTSSO with your OpenID Provider configuration
from fastapi_sso import AuthJWTSSO