mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
refactor(user_api_key_auth.py): move is_route_allowed to inside common_checks
ensures consistent behaviour inside api key + jwt routes
This commit is contained in:
parent
a23a7e1486
commit
c7f42747bf
5 changed files with 167 additions and 155 deletions
|
@ -14,7 +14,7 @@ import time
|
|||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast
|
||||
|
||||
from fastapi import status
|
||||
from fastapi import Request, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
|
@ -74,6 +74,7 @@ async def common_checks(
|
|||
llm_router: Optional[Router],
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
valid_token: Optional[UserAPIKeyAuth],
|
||||
request: Request,
|
||||
) -> bool:
|
||||
"""
|
||||
Common checks across jwt + key-based auth.
|
||||
|
@ -198,9 +199,134 @@ async def common_checks(
|
|||
user_object=user_object, route=route, request_body=request_body
|
||||
)
|
||||
|
||||
token_team = getattr(valid_token, "team_id", None)
|
||||
token_type: Literal["ui", "api"] = (
|
||||
"ui" if token_team is not None and token_team == "litellm-dashboard" else "api"
|
||||
)
|
||||
_is_route_allowed = _is_allowed_route(
|
||||
route=route,
|
||||
token_type=token_type,
|
||||
user_obj=user_object,
|
||||
request=request,
|
||||
request_data=request_body,
|
||||
valid_token=valid_token,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _is_ui_route(
|
||||
route: str,
|
||||
user_obj: Optional[LiteLLM_UserTable] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
- Check if the route is a UI used route
|
||||
"""
|
||||
# this token is only used for managing the ui
|
||||
allowed_routes = LiteLLMRoutes.ui_routes.value
|
||||
# check if the current route startswith any of the allowed routes
|
||||
if (
|
||||
route is not None
|
||||
and isinstance(route, str)
|
||||
and any(route.startswith(allowed_route) for allowed_route in allowed_routes)
|
||||
):
|
||||
# Do something if the current route starts with any of the allowed routes
|
||||
return True
|
||||
elif any(
|
||||
RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
|
||||
for allowed_route in allowed_routes
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _get_user_role(
|
||||
user_obj: Optional[LiteLLM_UserTable],
|
||||
) -> Optional[LitellmUserRoles]:
|
||||
if user_obj is None:
|
||||
return None
|
||||
|
||||
_user = user_obj
|
||||
|
||||
_user_role = _user.user_role
|
||||
try:
|
||||
role = LitellmUserRoles(_user_role)
|
||||
except ValueError:
|
||||
return LitellmUserRoles.INTERNAL_USER
|
||||
|
||||
return role
|
||||
|
||||
|
||||
def _is_api_route_allowed(
|
||||
route: str,
|
||||
request: Request,
|
||||
request_data: dict,
|
||||
valid_token: Optional[UserAPIKeyAuth],
|
||||
user_obj: Optional[LiteLLM_UserTable] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
- Route b/w api token check and normal token check
|
||||
"""
|
||||
_user_role = _get_user_role(user_obj=user_obj)
|
||||
|
||||
if valid_token is None:
|
||||
raise Exception("Invalid proxy server token passed. valid_token=None.")
|
||||
|
||||
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
|
||||
RouteChecks.non_proxy_admin_allowed_routes_check(
|
||||
user_obj=user_obj,
|
||||
_user_role=_user_role,
|
||||
route=route,
|
||||
request=request,
|
||||
request_data=request_data,
|
||||
valid_token=valid_token,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]):
|
||||
if user_obj is None:
|
||||
return False
|
||||
|
||||
if (
|
||||
user_obj.user_role is not None
|
||||
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||
):
|
||||
return True
|
||||
|
||||
if (
|
||||
user_obj.user_role is not None
|
||||
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _is_allowed_route(
|
||||
route: str,
|
||||
token_type: Literal["ui", "api"],
|
||||
request: Request,
|
||||
request_data: dict,
|
||||
valid_token: Optional[UserAPIKeyAuth],
|
||||
user_obj: Optional[LiteLLM_UserTable] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
- Route b/w ui token check and normal token check
|
||||
"""
|
||||
|
||||
if token_type == "ui" and _is_ui_route(route=route, user_obj=user_obj):
|
||||
return True
|
||||
else:
|
||||
return _is_api_route_allowed(
|
||||
route=route,
|
||||
request=request,
|
||||
request_data=request_data,
|
||||
valid_token=valid_token,
|
||||
user_obj=user_obj,
|
||||
)
|
||||
|
||||
|
||||
def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
|
||||
"""
|
||||
Return if a user is allowed to access route. Helper function for `allowed_routes_check`.
|
||||
|
|
|
@ -321,6 +321,7 @@ async def check_if_request_size_is_safe(request: Request) -> bool:
|
|||
from litellm.proxy.proxy_server import general_settings, premium_user
|
||||
|
||||
max_request_size_mb = general_settings.get("max_request_size_mb", None)
|
||||
|
||||
if max_request_size_mb is not None:
|
||||
# Check if premium user
|
||||
if premium_user is not True:
|
||||
|
|
|
@ -24,7 +24,6 @@ class RouteChecks:
|
|||
route: str,
|
||||
request: Request,
|
||||
valid_token: UserAPIKeyAuth,
|
||||
api_key: str,
|
||||
request_data: dict,
|
||||
):
|
||||
"""
|
||||
|
|
|
@ -25,7 +25,10 @@ from litellm.litellm_core_utils.dd_tracing import tracer
|
|||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
_cache_key_object,
|
||||
_get_user_role,
|
||||
_handle_failed_db_connection_for_get_key_object,
|
||||
_is_allowed_route,
|
||||
_is_user_proxy_admin,
|
||||
_virtual_key_max_budget_check,
|
||||
_virtual_key_soft_budget_check,
|
||||
can_key_call_model,
|
||||
|
@ -98,86 +101,6 @@ def _get_bearer_token(
|
|||
return api_key
|
||||
|
||||
|
||||
def _is_ui_route(
|
||||
route: str,
|
||||
user_obj: Optional[LiteLLM_UserTable] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
- Check if the route is a UI used route
|
||||
"""
|
||||
# this token is only used for managing the ui
|
||||
allowed_routes = LiteLLMRoutes.ui_routes.value
|
||||
# check if the current route startswith any of the allowed routes
|
||||
if (
|
||||
route is not None
|
||||
and isinstance(route, str)
|
||||
and any(route.startswith(allowed_route) for allowed_route in allowed_routes)
|
||||
):
|
||||
# Do something if the current route starts with any of the allowed routes
|
||||
return True
|
||||
elif any(
|
||||
RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
|
||||
for allowed_route in allowed_routes
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_api_route_allowed(
|
||||
route: str,
|
||||
request: Request,
|
||||
request_data: dict,
|
||||
api_key: str,
|
||||
valid_token: Optional[UserAPIKeyAuth],
|
||||
user_obj: Optional[LiteLLM_UserTable] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
- Route b/w api token check and normal token check
|
||||
"""
|
||||
_user_role = _get_user_role(user_obj=user_obj)
|
||||
|
||||
if valid_token is None:
|
||||
raise Exception("Invalid proxy server token passed. valid_token=None.")
|
||||
|
||||
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
|
||||
RouteChecks.non_proxy_admin_allowed_routes_check(
|
||||
user_obj=user_obj,
|
||||
_user_role=_user_role,
|
||||
route=route,
|
||||
request=request,
|
||||
request_data=request_data,
|
||||
api_key=api_key,
|
||||
valid_token=valid_token,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def _is_allowed_route(
|
||||
route: str,
|
||||
token_type: Literal["ui", "api"],
|
||||
request: Request,
|
||||
request_data: dict,
|
||||
api_key: str,
|
||||
valid_token: Optional[UserAPIKeyAuth],
|
||||
user_obj: Optional[LiteLLM_UserTable] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
- Route b/w ui token check and normal token check
|
||||
"""
|
||||
|
||||
if token_type == "ui" and _is_ui_route(route=route, user_obj=user_obj):
|
||||
return True
|
||||
else:
|
||||
return _is_api_route_allowed(
|
||||
route=route,
|
||||
request=request,
|
||||
request_data=request_data,
|
||||
api_key=api_key,
|
||||
valid_token=valid_token,
|
||||
user_obj=user_obj,
|
||||
)
|
||||
|
||||
|
||||
async def user_api_key_auth_websocket(websocket: WebSocket):
|
||||
# Accept the WebSocket connection
|
||||
|
||||
|
@ -328,6 +251,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
parent_otel_span: Optional[Span] = None
|
||||
start_time = datetime.now()
|
||||
route: str = get_request_route(request=request)
|
||||
|
||||
try:
|
||||
|
||||
# get the request body
|
||||
|
@ -470,22 +394,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
# run through common checks
|
||||
_ = await common_checks(
|
||||
request_body=request_data,
|
||||
team_object=team_object,
|
||||
user_object=user_object,
|
||||
end_user_object=end_user_object,
|
||||
general_settings=general_settings,
|
||||
global_proxy_spend=global_proxy_spend,
|
||||
route=route,
|
||||
llm_router=llm_router,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
valid_token=None,
|
||||
)
|
||||
|
||||
# return UserAPIKeyAuth object
|
||||
return UserAPIKeyAuth(
|
||||
valid_token = UserAPIKeyAuth(
|
||||
api_key=None,
|
||||
team_id=team_id,
|
||||
team_tpm_limit=(
|
||||
|
@ -501,6 +411,23 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
parent_otel_span=parent_otel_span,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
# run through common checks
|
||||
_ = await common_checks(
|
||||
request=request,
|
||||
request_body=request_data,
|
||||
team_object=team_object,
|
||||
user_object=user_object,
|
||||
end_user_object=end_user_object,
|
||||
general_settings=general_settings,
|
||||
global_proxy_spend=global_proxy_spend,
|
||||
route=route,
|
||||
llm_router=llm_router,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
valid_token=valid_token,
|
||||
)
|
||||
|
||||
# return UserAPIKeyAuth object
|
||||
return valid_token
|
||||
|
||||
#### ELSE ####
|
||||
## CHECK PASS-THROUGH ENDPOINTS ##
|
||||
|
@ -1038,6 +965,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
)
|
||||
)
|
||||
_ = await common_checks(
|
||||
request=request,
|
||||
request_body=request_data,
|
||||
team_object=_team_obj,
|
||||
user_object=user_obj,
|
||||
|
@ -1075,23 +1003,6 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
# check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions
|
||||
# sso/login, ui/login, /key functions and /user functions
|
||||
# this will never be allowed to call /chat/completions
|
||||
token_team = getattr(valid_token, "team_id", None)
|
||||
token_type: Literal["ui", "api"] = (
|
||||
"ui"
|
||||
if token_team is not None and token_team == "litellm-dashboard"
|
||||
else "api"
|
||||
)
|
||||
_is_route_allowed = _is_allowed_route(
|
||||
route=route,
|
||||
token_type=token_type,
|
||||
user_obj=user_obj,
|
||||
request=request,
|
||||
request_data=request_data,
|
||||
api_key=api_key,
|
||||
valid_token=valid_token,
|
||||
)
|
||||
if not _is_route_allowed:
|
||||
raise HTTPException(401, detail="Invalid route for UI token")
|
||||
|
||||
if valid_token is None:
|
||||
# No token was found when looking up in the DB
|
||||
|
@ -1242,42 +1153,6 @@ async def _return_user_api_key_auth_obj(
|
|||
return UserAPIKeyAuth(**user_api_key_kwargs)
|
||||
|
||||
|
||||
def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]):
|
||||
if user_obj is None:
|
||||
return False
|
||||
|
||||
if (
|
||||
user_obj.user_role is not None
|
||||
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||
):
|
||||
return True
|
||||
|
||||
if (
|
||||
user_obj.user_role is not None
|
||||
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _get_user_role(
|
||||
user_obj: Optional[LiteLLM_UserTable],
|
||||
) -> Optional[LitellmUserRoles]:
|
||||
if user_obj is None:
|
||||
return None
|
||||
|
||||
_user = user_obj
|
||||
|
||||
_user_role = _user.user_role
|
||||
try:
|
||||
role = LitellmUserRoles(_user_role)
|
||||
except ValueError:
|
||||
return LitellmUserRoles.INTERNAL_USER
|
||||
|
||||
return role
|
||||
|
||||
|
||||
def get_api_key_from_custom_header(
|
||||
request: Request, custom_litellm_key_header_name: str
|
||||
) -> str:
|
||||
|
|
|
@ -4,6 +4,9 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
import litellm.proxy
|
||||
import litellm.proxy.proxy_server
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
@ -950,7 +953,7 @@ def test_get_model_from_request(route, request_data, expected_model):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jwt_non_admin_team_route_access():
|
||||
async def test_jwt_non_admin_team_route_access(monkeypatch):
|
||||
"""
|
||||
Test that a non-admin JWT user cannot access team management routes
|
||||
"""
|
||||
|
@ -958,6 +961,8 @@ async def test_jwt_non_admin_team_route_access():
|
|||
from starlette.datastructures import URL
|
||||
from unittest.mock import patch
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
import json
|
||||
from litellm.proxy._types import ProxyException
|
||||
|
||||
mock_jwt_response = {
|
||||
"is_proxy_admin": False,
|
||||
|
@ -973,9 +978,15 @@ async def test_jwt_non_admin_team_route_access():
|
|||
}
|
||||
|
||||
# Create request
|
||||
request = Request(scope={"type": "http"})
|
||||
request = Request(
|
||||
scope={"type": "http", "headers": [("Authorization", "Bearer fake.jwt.token")]}
|
||||
)
|
||||
request._url = URL(url="/team/new")
|
||||
|
||||
monkeypatch.setattr(
|
||||
litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True}
|
||||
)
|
||||
|
||||
# Mock JWTAuthManager.auth_builder
|
||||
with patch(
|
||||
"litellm.proxy.auth.handle_jwt.JWTAuthManager.auth_builder",
|
||||
|
@ -986,6 +997,6 @@ async def test_jwt_non_admin_team_route_access():
|
|||
pytest.fail(
|
||||
"Expected this call to fail. Non-admin user should not access team routes."
|
||||
)
|
||||
except HTTPException as e:
|
||||
assert e.status_code == 403
|
||||
assert "Unauthorized" in str(e.detail)
|
||||
except ProxyException as e:
|
||||
print("e", e)
|
||||
assert "Only proxy admin can be used to generate" in str(e.message)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue