mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
Merge pull request #9454 from BerriAI/litellm_dev_03_21_2025_p3
Fix route check for non-proxy admins on jwt auth
This commit is contained in:
commit
950edd76b3
10 changed files with 276 additions and 162 deletions
|
@ -15,4 +15,12 @@ router_settings:
|
||||||
redis_password: os.environ/REDIS_PASSWORD
|
redis_password: os.environ/REDIS_PASSWORD
|
||||||
redis_port: os.environ/REDIS_PORT
|
redis_port: os.environ/REDIS_PORT
|
||||||
|
|
||||||
|
general_settings:
|
||||||
|
enable_jwt_auth: True
|
||||||
|
litellm_jwtauth:
|
||||||
|
admin_jwt_scope: "ai.admin"
|
||||||
|
# team_id_jwt_field: "client_id" # 👈 CAN BE ANY FIELD
|
||||||
|
user_id_jwt_field: "sub" # 👈 CAN BE ANY FIELD
|
||||||
|
org_id_jwt_field: "org_id" # 👈 CAN BE ANY FIELD
|
||||||
|
end_user_id_jwt_field: "customer_id" # 👈 CAN BE ANY FIELD
|
||||||
|
user_id_upsert: True
|
|
@ -1631,7 +1631,7 @@ class LiteLLM_UserTable(LiteLLMPydanticObjectBase):
|
||||||
|
|
||||||
class LiteLLM_UserTableFiltered(BaseModel): # done to avoid exposing sensitive data
|
class LiteLLM_UserTableFiltered(BaseModel): # done to avoid exposing sensitive data
|
||||||
user_id: str
|
user_id: str
|
||||||
user_email: str
|
user_email: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class LiteLLM_UserTableWithKeyCount(LiteLLM_UserTable):
|
class LiteLLM_UserTableWithKeyCount(LiteLLM_UserTable):
|
||||||
|
|
|
@ -14,7 +14,7 @@ import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast
|
||||||
|
|
||||||
from fastapi import status
|
from fastapi import Request, status
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
@ -74,6 +74,7 @@ async def common_checks(
|
||||||
llm_router: Optional[Router],
|
llm_router: Optional[Router],
|
||||||
proxy_logging_obj: ProxyLogging,
|
proxy_logging_obj: ProxyLogging,
|
||||||
valid_token: Optional[UserAPIKeyAuth],
|
valid_token: Optional[UserAPIKeyAuth],
|
||||||
|
request: Request,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Common checks across jwt + key-based auth.
|
Common checks across jwt + key-based auth.
|
||||||
|
@ -198,9 +199,134 @@ async def common_checks(
|
||||||
user_object=user_object, route=route, request_body=request_body
|
user_object=user_object, route=route, request_body=request_body
|
||||||
)
|
)
|
||||||
|
|
||||||
|
token_team = getattr(valid_token, "team_id", None)
|
||||||
|
token_type: Literal["ui", "api"] = (
|
||||||
|
"ui" if token_team is not None and token_team == "litellm-dashboard" else "api"
|
||||||
|
)
|
||||||
|
_is_route_allowed = _is_allowed_route(
|
||||||
|
route=route,
|
||||||
|
token_type=token_type,
|
||||||
|
user_obj=user_object,
|
||||||
|
request=request,
|
||||||
|
request_data=request_body,
|
||||||
|
valid_token=valid_token,
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _is_ui_route(
|
||||||
|
route: str,
|
||||||
|
user_obj: Optional[LiteLLM_UserTable] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
- Check if the route is a UI used route
|
||||||
|
"""
|
||||||
|
# this token is only used for managing the ui
|
||||||
|
allowed_routes = LiteLLMRoutes.ui_routes.value
|
||||||
|
# check if the current route startswith any of the allowed routes
|
||||||
|
if (
|
||||||
|
route is not None
|
||||||
|
and isinstance(route, str)
|
||||||
|
and any(route.startswith(allowed_route) for allowed_route in allowed_routes)
|
||||||
|
):
|
||||||
|
# Do something if the current route starts with any of the allowed routes
|
||||||
|
return True
|
||||||
|
elif any(
|
||||||
|
RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
|
||||||
|
for allowed_route in allowed_routes
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user_role(
|
||||||
|
user_obj: Optional[LiteLLM_UserTable],
|
||||||
|
) -> Optional[LitellmUserRoles]:
|
||||||
|
if user_obj is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
_user = user_obj
|
||||||
|
|
||||||
|
_user_role = _user.user_role
|
||||||
|
try:
|
||||||
|
role = LitellmUserRoles(_user_role)
|
||||||
|
except ValueError:
|
||||||
|
return LitellmUserRoles.INTERNAL_USER
|
||||||
|
|
||||||
|
return role
|
||||||
|
|
||||||
|
|
||||||
|
def _is_api_route_allowed(
|
||||||
|
route: str,
|
||||||
|
request: Request,
|
||||||
|
request_data: dict,
|
||||||
|
valid_token: Optional[UserAPIKeyAuth],
|
||||||
|
user_obj: Optional[LiteLLM_UserTable] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
- Route b/w api token check and normal token check
|
||||||
|
"""
|
||||||
|
_user_role = _get_user_role(user_obj=user_obj)
|
||||||
|
|
||||||
|
if valid_token is None:
|
||||||
|
raise Exception("Invalid proxy server token passed. valid_token=None.")
|
||||||
|
|
||||||
|
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
|
||||||
|
RouteChecks.non_proxy_admin_allowed_routes_check(
|
||||||
|
user_obj=user_obj,
|
||||||
|
_user_role=_user_role,
|
||||||
|
route=route,
|
||||||
|
request=request,
|
||||||
|
request_data=request_data,
|
||||||
|
valid_token=valid_token,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]):
|
||||||
|
if user_obj is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if (
|
||||||
|
user_obj.user_role is not None
|
||||||
|
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
if (
|
||||||
|
user_obj.user_role is not None
|
||||||
|
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_allowed_route(
|
||||||
|
route: str,
|
||||||
|
token_type: Literal["ui", "api"],
|
||||||
|
request: Request,
|
||||||
|
request_data: dict,
|
||||||
|
valid_token: Optional[UserAPIKeyAuth],
|
||||||
|
user_obj: Optional[LiteLLM_UserTable] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
- Route b/w ui token check and normal token check
|
||||||
|
"""
|
||||||
|
|
||||||
|
if token_type == "ui" and _is_ui_route(route=route, user_obj=user_obj):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return _is_api_route_allowed(
|
||||||
|
route=route,
|
||||||
|
request=request,
|
||||||
|
request_data=request_data,
|
||||||
|
valid_token=valid_token,
|
||||||
|
user_obj=user_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
|
def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
|
||||||
"""
|
"""
|
||||||
Return if a user is allowed to access route. Helper function for `allowed_routes_check`.
|
Return if a user is allowed to access route. Helper function for `allowed_routes_check`.
|
||||||
|
|
|
@ -321,6 +321,7 @@ async def check_if_request_size_is_safe(request: Request) -> bool:
|
||||||
from litellm.proxy.proxy_server import general_settings, premium_user
|
from litellm.proxy.proxy_server import general_settings, premium_user
|
||||||
|
|
||||||
max_request_size_mb = general_settings.get("max_request_size_mb", None)
|
max_request_size_mb = general_settings.get("max_request_size_mb", None)
|
||||||
|
|
||||||
if max_request_size_mb is not None:
|
if max_request_size_mb is not None:
|
||||||
# Check if premium user
|
# Check if premium user
|
||||||
if premium_user is not True:
|
if premium_user is not True:
|
||||||
|
|
|
@ -24,7 +24,6 @@ class RouteChecks:
|
||||||
route: str,
|
route: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
valid_token: UserAPIKeyAuth,
|
valid_token: UserAPIKeyAuth,
|
||||||
api_key: str,
|
|
||||||
request_data: dict,
|
request_data: dict,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -25,7 +25,9 @@ from litellm.litellm_core_utils.dd_tracing import tracer
|
||||||
from litellm.proxy._types import *
|
from litellm.proxy._types import *
|
||||||
from litellm.proxy.auth.auth_checks import (
|
from litellm.proxy.auth.auth_checks import (
|
||||||
_cache_key_object,
|
_cache_key_object,
|
||||||
|
_get_user_role,
|
||||||
_handle_failed_db_connection_for_get_key_object,
|
_handle_failed_db_connection_for_get_key_object,
|
||||||
|
_is_user_proxy_admin,
|
||||||
_virtual_key_max_budget_check,
|
_virtual_key_max_budget_check,
|
||||||
_virtual_key_soft_budget_check,
|
_virtual_key_soft_budget_check,
|
||||||
can_key_call_model,
|
can_key_call_model,
|
||||||
|
@ -48,7 +50,6 @@ from litellm.proxy.auth.auth_utils import (
|
||||||
from litellm.proxy.auth.handle_jwt import JWTAuthManager, JWTHandler
|
from litellm.proxy.auth.handle_jwt import JWTAuthManager, JWTHandler
|
||||||
from litellm.proxy.auth.oauth2_check import check_oauth2_token
|
from litellm.proxy.auth.oauth2_check import check_oauth2_token
|
||||||
from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
|
from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
|
||||||
from litellm.proxy.auth.route_checks import RouteChecks
|
|
||||||
from litellm.proxy.auth.service_account_checks import service_account_checks
|
from litellm.proxy.auth.service_account_checks import service_account_checks
|
||||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||||
|
@ -98,86 +99,6 @@ def _get_bearer_token(
|
||||||
return api_key
|
return api_key
|
||||||
|
|
||||||
|
|
||||||
def _is_ui_route(
|
|
||||||
route: str,
|
|
||||||
user_obj: Optional[LiteLLM_UserTable] = None,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
- Check if the route is a UI used route
|
|
||||||
"""
|
|
||||||
# this token is only used for managing the ui
|
|
||||||
allowed_routes = LiteLLMRoutes.ui_routes.value
|
|
||||||
# check if the current route startswith any of the allowed routes
|
|
||||||
if (
|
|
||||||
route is not None
|
|
||||||
and isinstance(route, str)
|
|
||||||
and any(route.startswith(allowed_route) for allowed_route in allowed_routes)
|
|
||||||
):
|
|
||||||
# Do something if the current route starts with any of the allowed routes
|
|
||||||
return True
|
|
||||||
elif any(
|
|
||||||
RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
|
|
||||||
for allowed_route in allowed_routes
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_api_route_allowed(
|
|
||||||
route: str,
|
|
||||||
request: Request,
|
|
||||||
request_data: dict,
|
|
||||||
api_key: str,
|
|
||||||
valid_token: Optional[UserAPIKeyAuth],
|
|
||||||
user_obj: Optional[LiteLLM_UserTable] = None,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
- Route b/w api token check and normal token check
|
|
||||||
"""
|
|
||||||
_user_role = _get_user_role(user_obj=user_obj)
|
|
||||||
|
|
||||||
if valid_token is None:
|
|
||||||
raise Exception("Invalid proxy server token passed. valid_token=None.")
|
|
||||||
|
|
||||||
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
|
|
||||||
RouteChecks.non_proxy_admin_allowed_routes_check(
|
|
||||||
user_obj=user_obj,
|
|
||||||
_user_role=_user_role,
|
|
||||||
route=route,
|
|
||||||
request=request,
|
|
||||||
request_data=request_data,
|
|
||||||
api_key=api_key,
|
|
||||||
valid_token=valid_token,
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def _is_allowed_route(
|
|
||||||
route: str,
|
|
||||||
token_type: Literal["ui", "api"],
|
|
||||||
request: Request,
|
|
||||||
request_data: dict,
|
|
||||||
api_key: str,
|
|
||||||
valid_token: Optional[UserAPIKeyAuth],
|
|
||||||
user_obj: Optional[LiteLLM_UserTable] = None,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
- Route b/w ui token check and normal token check
|
|
||||||
"""
|
|
||||||
|
|
||||||
if token_type == "ui" and _is_ui_route(route=route, user_obj=user_obj):
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return _is_api_route_allowed(
|
|
||||||
route=route,
|
|
||||||
request=request,
|
|
||||||
request_data=request_data,
|
|
||||||
api_key=api_key,
|
|
||||||
valid_token=valid_token,
|
|
||||||
user_obj=user_obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def user_api_key_auth_websocket(websocket: WebSocket):
|
async def user_api_key_auth_websocket(websocket: WebSocket):
|
||||||
# Accept the WebSocket connection
|
# Accept the WebSocket connection
|
||||||
|
|
||||||
|
@ -328,6 +249,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
parent_otel_span: Optional[Span] = None
|
parent_otel_span: Optional[Span] = None
|
||||||
start_time = datetime.now()
|
start_time = datetime.now()
|
||||||
route: str = get_request_route(request=request)
|
route: str = get_request_route(request=request)
|
||||||
|
valid_token: Optional[UserAPIKeyAuth] = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# get the request body
|
# get the request body
|
||||||
|
@ -470,22 +393,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
parent_otel_span=parent_otel_span,
|
parent_otel_span=parent_otel_span,
|
||||||
)
|
)
|
||||||
# run through common checks
|
|
||||||
_ = await common_checks(
|
|
||||||
request_body=request_data,
|
|
||||||
team_object=team_object,
|
|
||||||
user_object=user_object,
|
|
||||||
end_user_object=end_user_object,
|
|
||||||
general_settings=general_settings,
|
|
||||||
global_proxy_spend=global_proxy_spend,
|
|
||||||
route=route,
|
|
||||||
llm_router=llm_router,
|
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
|
||||||
valid_token=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# return UserAPIKeyAuth object
|
valid_token = UserAPIKeyAuth(
|
||||||
return UserAPIKeyAuth(
|
|
||||||
api_key=None,
|
api_key=None,
|
||||||
team_id=team_id,
|
team_id=team_id,
|
||||||
team_tpm_limit=(
|
team_tpm_limit=(
|
||||||
|
@ -501,6 +410,23 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
parent_otel_span=parent_otel_span,
|
parent_otel_span=parent_otel_span,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
)
|
)
|
||||||
|
# run through common checks
|
||||||
|
_ = await common_checks(
|
||||||
|
request=request,
|
||||||
|
request_body=request_data,
|
||||||
|
team_object=team_object,
|
||||||
|
user_object=user_object,
|
||||||
|
end_user_object=end_user_object,
|
||||||
|
general_settings=general_settings,
|
||||||
|
global_proxy_spend=global_proxy_spend,
|
||||||
|
route=route,
|
||||||
|
llm_router=llm_router,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
valid_token=valid_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
# return UserAPIKeyAuth object
|
||||||
|
return cast(UserAPIKeyAuth, valid_token)
|
||||||
|
|
||||||
#### ELSE ####
|
#### ELSE ####
|
||||||
## CHECK PASS-THROUGH ENDPOINTS ##
|
## CHECK PASS-THROUGH ENDPOINTS ##
|
||||||
|
@ -1038,6 +964,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
_ = await common_checks(
|
_ = await common_checks(
|
||||||
|
request=request,
|
||||||
request_body=request_data,
|
request_body=request_data,
|
||||||
team_object=_team_obj,
|
team_object=_team_obj,
|
||||||
user_object=user_obj,
|
user_object=user_obj,
|
||||||
|
@ -1075,23 +1002,6 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
# check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions
|
# check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions
|
||||||
# sso/login, ui/login, /key functions and /user functions
|
# sso/login, ui/login, /key functions and /user functions
|
||||||
# this will never be allowed to call /chat/completions
|
# this will never be allowed to call /chat/completions
|
||||||
token_team = getattr(valid_token, "team_id", None)
|
|
||||||
token_type: Literal["ui", "api"] = (
|
|
||||||
"ui"
|
|
||||||
if token_team is not None and token_team == "litellm-dashboard"
|
|
||||||
else "api"
|
|
||||||
)
|
|
||||||
_is_route_allowed = _is_allowed_route(
|
|
||||||
route=route,
|
|
||||||
token_type=token_type,
|
|
||||||
user_obj=user_obj,
|
|
||||||
request=request,
|
|
||||||
request_data=request_data,
|
|
||||||
api_key=api_key,
|
|
||||||
valid_token=valid_token,
|
|
||||||
)
|
|
||||||
if not _is_route_allowed:
|
|
||||||
raise HTTPException(401, detail="Invalid route for UI token")
|
|
||||||
|
|
||||||
if valid_token is None:
|
if valid_token is None:
|
||||||
# No token was found when looking up in the DB
|
# No token was found when looking up in the DB
|
||||||
|
@ -1242,42 +1152,6 @@ async def _return_user_api_key_auth_obj(
|
||||||
return UserAPIKeyAuth(**user_api_key_kwargs)
|
return UserAPIKeyAuth(**user_api_key_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]):
|
|
||||||
if user_obj is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if (
|
|
||||||
user_obj.user_role is not None
|
|
||||||
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
|
|
||||||
if (
|
|
||||||
user_obj.user_role is not None
|
|
||||||
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _get_user_role(
|
|
||||||
user_obj: Optional[LiteLLM_UserTable],
|
|
||||||
) -> Optional[LitellmUserRoles]:
|
|
||||||
if user_obj is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
_user = user_obj
|
|
||||||
|
|
||||||
_user_role = _user.user_role
|
|
||||||
try:
|
|
||||||
role = LitellmUserRoles(_user_role)
|
|
||||||
except ValueError:
|
|
||||||
return LitellmUserRoles.INTERNAL_USER
|
|
||||||
|
|
||||||
return role
|
|
||||||
|
|
||||||
|
|
||||||
def get_api_key_from_custom_header(
|
def get_api_key_from_custom_header(
|
||||||
request: Request, custom_litellm_key_header_name: str
|
request: Request, custom_litellm_key_header_name: str
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
|
@ -1240,4 +1240,5 @@ async def ui_view_users(
|
||||||
return [LiteLLM_UserTableFiltered(**user.model_dump()) for user in users]
|
return [LiteLLM_UserTableFiltered(**user.model_dump()) for user in users]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.exception(f"Error searching users: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail=f"Error searching users: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Error searching users: {str(e)}")
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
from litellm.proxy._types import LiteLLM_UserTableFiltered, UserAPIKeyAuth
|
||||||
|
from litellm.proxy.management_endpoints.internal_user_endpoints import ui_view_users
|
||||||
|
from litellm.proxy.proxy_server import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ui_view_users_with_null_email(mocker, caplog):
|
||||||
|
"""
|
||||||
|
Test that /user/filter/ui endpoint returns users even when they have null email fields
|
||||||
|
"""
|
||||||
|
# Mock the prisma client
|
||||||
|
mock_prisma_client = mocker.MagicMock()
|
||||||
|
|
||||||
|
# Create mock user data with null email
|
||||||
|
mock_user = mocker.MagicMock()
|
||||||
|
mock_user.model_dump.return_value = {
|
||||||
|
"user_id": "test-user-null-email",
|
||||||
|
"user_email": None,
|
||||||
|
"user_role": "proxy_admin",
|
||||||
|
"created_at": "2024-01-01T00:00:00Z",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Setup the mock find_many response
|
||||||
|
# Setup the mock find_many response as an async function
|
||||||
|
async def mock_find_many(*args, **kwargs):
|
||||||
|
return [mock_user]
|
||||||
|
|
||||||
|
mock_prisma_client.db.litellm_usertable.find_many = mock_find_many
|
||||||
|
|
||||||
|
# Patch the prisma client import in the endpoint
|
||||||
|
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
|
||||||
|
|
||||||
|
# Call ui_view_users function directly
|
||||||
|
response = await ui_view_users(
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(user_id="test_user"),
|
||||||
|
user_id="test_user",
|
||||||
|
user_email=None,
|
||||||
|
page=1,
|
||||||
|
page_size=50,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response == [
|
||||||
|
LiteLLM_UserTableFiltered(user_id="test-user-null-email", user_email=None)
|
||||||
|
]
|
|
@ -165,7 +165,6 @@ def test_llm_api_route(route_checks):
|
||||||
route="/v1/chat/completions",
|
route="/v1/chat/completions",
|
||||||
request=MockRequest(),
|
request=MockRequest(),
|
||||||
valid_token=UserAPIKeyAuth(api_key="test_key"),
|
valid_token=UserAPIKeyAuth(api_key="test_key"),
|
||||||
api_key="test_key",
|
|
||||||
request_data={},
|
request_data={},
|
||||||
)
|
)
|
||||||
is None
|
is None
|
||||||
|
@ -183,7 +182,6 @@ def test_key_info_route_allowed(route_checks):
|
||||||
route="/key/info",
|
route="/key/info",
|
||||||
request=MockRequest(query_params={"key": "test_key"}),
|
request=MockRequest(query_params={"key": "test_key"}),
|
||||||
valid_token=UserAPIKeyAuth(api_key="test_key"),
|
valid_token=UserAPIKeyAuth(api_key="test_key"),
|
||||||
api_key="test_key",
|
|
||||||
request_data={},
|
request_data={},
|
||||||
)
|
)
|
||||||
is None
|
is None
|
||||||
|
@ -201,7 +199,6 @@ def test_user_info_route_allowed(route_checks):
|
||||||
route="/user/info",
|
route="/user/info",
|
||||||
request=MockRequest(query_params={"user_id": "test_user"}),
|
request=MockRequest(query_params={"user_id": "test_user"}),
|
||||||
valid_token=UserAPIKeyAuth(api_key="test_key", user_id="test_user"),
|
valid_token=UserAPIKeyAuth(api_key="test_key", user_id="test_user"),
|
||||||
api_key="test_key",
|
|
||||||
request_data={},
|
request_data={},
|
||||||
)
|
)
|
||||||
is None
|
is None
|
||||||
|
@ -219,7 +216,6 @@ def test_user_info_route_forbidden(route_checks):
|
||||||
route="/user/info",
|
route="/user/info",
|
||||||
request=MockRequest(query_params={"user_id": "wrong_user"}),
|
request=MockRequest(query_params={"user_id": "wrong_user"}),
|
||||||
valid_token=UserAPIKeyAuth(api_key="test_key", user_id="test_user"),
|
valid_token=UserAPIKeyAuth(api_key="test_key", user_id="test_user"),
|
||||||
api_key="test_key",
|
|
||||||
request_data={},
|
request_data={},
|
||||||
)
|
)
|
||||||
assert exc_info.value.status_code == 403
|
assert exc_info.value.status_code == 403
|
||||||
|
|
|
@ -4,6 +4,9 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import litellm.proxy
|
||||||
|
import litellm.proxy.proxy_server
|
||||||
|
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
|
@ -329,7 +332,7 @@ async def test_auth_with_allowed_routes(route, should_raise_error):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_is_ui_route_allowed(route, user_role, expected_result):
|
def test_is_ui_route_allowed(route, user_role, expected_result):
|
||||||
from litellm.proxy.auth.user_api_key_auth import _is_ui_route
|
from litellm.proxy.auth.auth_checks import _is_ui_route
|
||||||
from litellm.proxy._types import LiteLLM_UserTable
|
from litellm.proxy._types import LiteLLM_UserTable
|
||||||
|
|
||||||
user_obj = LiteLLM_UserTable(
|
user_obj = LiteLLM_UserTable(
|
||||||
|
@ -367,7 +370,7 @@ def test_is_ui_route_allowed(route, user_role, expected_result):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_is_api_route_allowed(route, user_role, expected_result):
|
def test_is_api_route_allowed(route, user_role, expected_result):
|
||||||
from litellm.proxy.auth.user_api_key_auth import _is_api_route_allowed
|
from litellm.proxy.auth.auth_checks import _is_api_route_allowed
|
||||||
from litellm.proxy._types import LiteLLM_UserTable
|
from litellm.proxy._types import LiteLLM_UserTable
|
||||||
|
|
||||||
user_obj = LiteLLM_UserTable(
|
user_obj = LiteLLM_UserTable(
|
||||||
|
@ -635,7 +638,7 @@ async def test_soft_budget_alert():
|
||||||
|
|
||||||
|
|
||||||
def test_is_allowed_route():
|
def test_is_allowed_route():
|
||||||
from litellm.proxy.auth.user_api_key_auth import _is_allowed_route
|
from litellm.proxy.auth.auth_checks import _is_allowed_route
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
@ -646,7 +649,6 @@ def test_is_allowed_route():
|
||||||
"token_type": "api",
|
"token_type": "api",
|
||||||
"request": request,
|
"request": request,
|
||||||
"request_data": {"input": ["hello world"], "model": "embedding-small"},
|
"request_data": {"input": ["hello world"], "model": "embedding-small"},
|
||||||
"api_key": "9644159bc181998825c44c788b1526341ed2e825d1b6f562e23173759e14bb86",
|
|
||||||
"valid_token": UserAPIKeyAuth(
|
"valid_token": UserAPIKeyAuth(
|
||||||
token="9644159bc181998825c44c788b1526341ed2e825d1b6f562e23173759e14bb86",
|
token="9644159bc181998825c44c788b1526341ed2e825d1b6f562e23173759e14bb86",
|
||||||
key_name="sk-...CJjQ",
|
key_name="sk-...CJjQ",
|
||||||
|
@ -734,7 +736,7 @@ def test_is_allowed_route():
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_is_user_proxy_admin(user_obj, expected_result):
|
def test_is_user_proxy_admin(user_obj, expected_result):
|
||||||
from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin
|
from litellm.proxy.auth.auth_checks import _is_user_proxy_admin
|
||||||
|
|
||||||
assert _is_user_proxy_admin(user_obj) == expected_result
|
assert _is_user_proxy_admin(user_obj) == expected_result
|
||||||
|
|
||||||
|
@ -947,3 +949,53 @@ def test_get_model_from_request(route, request_data, expected_model):
|
||||||
from litellm.proxy.auth.user_api_key_auth import get_model_from_request
|
from litellm.proxy.auth.user_api_key_auth import get_model_from_request
|
||||||
|
|
||||||
assert get_model_from_request(request_data, route) == expected_model
|
assert get_model_from_request(request_data, route) == expected_model
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_jwt_non_admin_team_route_access(monkeypatch):
|
||||||
|
"""
|
||||||
|
Test that a non-admin JWT user cannot access team management routes
|
||||||
|
"""
|
||||||
|
from fastapi import Request, HTTPException
|
||||||
|
from starlette.datastructures import URL
|
||||||
|
from unittest.mock import patch
|
||||||
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
import json
|
||||||
|
from litellm.proxy._types import ProxyException
|
||||||
|
|
||||||
|
mock_jwt_response = {
|
||||||
|
"is_proxy_admin": False,
|
||||||
|
"team_id": None,
|
||||||
|
"team_object": None,
|
||||||
|
"user_id": None,
|
||||||
|
"user_object": None,
|
||||||
|
"org_id": None,
|
||||||
|
"org_object": None,
|
||||||
|
"end_user_id": None,
|
||||||
|
"end_user_object": None,
|
||||||
|
"token": "eyJhbGciOiJSUzI1NiIsInR5cCIgOiAiSldUIiwia2lkIiA6ICJmR09YQTNhbHFObjByRzJ6OHJQT1FLZVVMSWxCNDFnVWl4VDJ5WE1QVG1ZIn0.eyJleHAiOjE3NDI2MDAzODIsImlhdCI6MTc0MjYwMDA4MiwianRpIjoiODRhNjZmZjAtMTE5OC00YmRkLTk1NzAtNWZhMjNhZjYxMmQyIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo4MDgwL3JlYWxtcy9saXRlbGxtLXJlYWxtIiwiYXVkIjoiYWNjb3VudCIsInN1YiI6ImZmMGZjOGNiLWUyMjktNDkyYy05NzYwLWNlYzVhMDYxNmI2MyIsInR5cCI6IkJlYXJlciIsImF6cCI6ImxpdGVsbG0tdGVzdC1jbGllbnQtaWQiLCJzaWQiOiI4MTYwNjIxOC0yNmZmLTQwMjAtOWQxNy05Zjc0YmFlNTBkODUiLCJhY3IiOiIxIiwiYWxsb3dlZC1vcmlnaW5zIjpbImh0dHA6Ly9sb2NhbGhvc3Q6NDAwMC8qIl0sInJlYWxtX2FjY2VzcyI6eyJyb2xlcyI6WyJvZmZsaW5lX2FjY2VzcyIsImRlZmF1bHQtcm9sZXMtbGl0ZWxsbS1yZWFsbSIsInVtYV9hdXRob3JpemF0aW9uIl19LCJyZXNvdXJjZV9hY2Nlc3MiOnsiYWNjb3VudCI6eyJyb2xlcyI6WyJtYW5hZ2UtYWNjb3VudCIsIm1hbmFnZS1hY2NvdW50LWxpbmtzIiwidmlldy1wcm9maWxlIl19fSwic2NvcGUiOiJwcm9maWxlIGdyb3Vwcy1zY29wZSBlbWFpbCBsaXRlbGxtLmFwaS5jb25zdW1lciIsImVtYWlsX3ZlcmlmaWVkIjp0cnVlLCJuYW1lIjoiS3Jpc2ggRGhvbGFraWEiLCJncm91cHMiOlsiL28zX21pbmlfYWNjZXNzIl0sInByZWZlcnJlZF91c2VybmFtZSI6ImtycmlzaGRoMiIsImdpdmVuX25hbWUiOiJLcmlzaCIsImZhbWlseV9uYW1lIjoiRGhvbGFraWEiLCJlbWFpbCI6ImtycmlzaGRob2xha2lhMkBnbWFpbC5jb20ifQ.Fu2ErZhnfez-bhn_XmjkywcFdZHcFUSvzIzfdNiEowdA0soLmCyqf9731amP6m68shd9qk11e0mQhxFIAIxZPojViC1Csc9TBXLRRQ8ESMd6gPIj-DBkKVkQSZLJ1uibsh4Oo2RViGtqWVcEt32T8U_xhGdtdzNkJ8qy_e0fdNDsUnhmSaTQvmZJYarW0roIrkC-zYZrX3fftzbQfavSu9eqdfPf6wUttIrkaWThWUuORy-xaeZfSmvsGbEg027hh6QwlChiZTSF8R6bRxoqfPN3ZaGFFgbBXNRYZA_eYi2IevhIwJHi_r4o1UvtKAJyfPefm-M6hCfkN_6da4zsog",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create request
|
||||||
|
request = Request(
|
||||||
|
scope={"type": "http", "headers": [("Authorization", "Bearer fake.jwt.token")]}
|
||||||
|
)
|
||||||
|
request._url = URL(url="/team/new")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock JWTAuthManager.auth_builder
|
||||||
|
with patch(
|
||||||
|
"litellm.proxy.auth.handle_jwt.JWTAuthManager.auth_builder",
|
||||||
|
return_value=mock_jwt_response,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
await user_api_key_auth(request=request, api_key="Bearer fake.jwt.token")
|
||||||
|
pytest.fail(
|
||||||
|
"Expected this call to fail. Non-admin user should not access team routes."
|
||||||
|
)
|
||||||
|
except ProxyException as e:
|
||||||
|
print("e", e)
|
||||||
|
assert "Only proxy admin can be used to generate" in str(e.message)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue