mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
refactor(user_api_key_auth.py): move is_route_allowed to inside common_checks
ensures consistent behaviour inside api key + jwt routes
This commit is contained in:
parent
a23a7e1486
commit
c7f42747bf
5 changed files with 167 additions and 155 deletions
|
@ -14,7 +14,7 @@ import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast
|
||||||
|
|
||||||
from fastapi import status
|
from fastapi import Request, status
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
@ -74,6 +74,7 @@ async def common_checks(
|
||||||
llm_router: Optional[Router],
|
llm_router: Optional[Router],
|
||||||
proxy_logging_obj: ProxyLogging,
|
proxy_logging_obj: ProxyLogging,
|
||||||
valid_token: Optional[UserAPIKeyAuth],
|
valid_token: Optional[UserAPIKeyAuth],
|
||||||
|
request: Request,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Common checks across jwt + key-based auth.
|
Common checks across jwt + key-based auth.
|
||||||
|
@ -198,9 +199,134 @@ async def common_checks(
|
||||||
user_object=user_object, route=route, request_body=request_body
|
user_object=user_object, route=route, request_body=request_body
|
||||||
)
|
)
|
||||||
|
|
||||||
|
token_team = getattr(valid_token, "team_id", None)
|
||||||
|
token_type: Literal["ui", "api"] = (
|
||||||
|
"ui" if token_team is not None and token_team == "litellm-dashboard" else "api"
|
||||||
|
)
|
||||||
|
_is_route_allowed = _is_allowed_route(
|
||||||
|
route=route,
|
||||||
|
token_type=token_type,
|
||||||
|
user_obj=user_object,
|
||||||
|
request=request,
|
||||||
|
request_data=request_body,
|
||||||
|
valid_token=valid_token,
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _is_ui_route(
|
||||||
|
route: str,
|
||||||
|
user_obj: Optional[LiteLLM_UserTable] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
- Check if the route is a UI used route
|
||||||
|
"""
|
||||||
|
# this token is only used for managing the ui
|
||||||
|
allowed_routes = LiteLLMRoutes.ui_routes.value
|
||||||
|
# check if the current route startswith any of the allowed routes
|
||||||
|
if (
|
||||||
|
route is not None
|
||||||
|
and isinstance(route, str)
|
||||||
|
and any(route.startswith(allowed_route) for allowed_route in allowed_routes)
|
||||||
|
):
|
||||||
|
# Do something if the current route starts with any of the allowed routes
|
||||||
|
return True
|
||||||
|
elif any(
|
||||||
|
RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
|
||||||
|
for allowed_route in allowed_routes
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user_role(
|
||||||
|
user_obj: Optional[LiteLLM_UserTable],
|
||||||
|
) -> Optional[LitellmUserRoles]:
|
||||||
|
if user_obj is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
_user = user_obj
|
||||||
|
|
||||||
|
_user_role = _user.user_role
|
||||||
|
try:
|
||||||
|
role = LitellmUserRoles(_user_role)
|
||||||
|
except ValueError:
|
||||||
|
return LitellmUserRoles.INTERNAL_USER
|
||||||
|
|
||||||
|
return role
|
||||||
|
|
||||||
|
|
||||||
|
def _is_api_route_allowed(
|
||||||
|
route: str,
|
||||||
|
request: Request,
|
||||||
|
request_data: dict,
|
||||||
|
valid_token: Optional[UserAPIKeyAuth],
|
||||||
|
user_obj: Optional[LiteLLM_UserTable] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
- Route b/w api token check and normal token check
|
||||||
|
"""
|
||||||
|
_user_role = _get_user_role(user_obj=user_obj)
|
||||||
|
|
||||||
|
if valid_token is None:
|
||||||
|
raise Exception("Invalid proxy server token passed. valid_token=None.")
|
||||||
|
|
||||||
|
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
|
||||||
|
RouteChecks.non_proxy_admin_allowed_routes_check(
|
||||||
|
user_obj=user_obj,
|
||||||
|
_user_role=_user_role,
|
||||||
|
route=route,
|
||||||
|
request=request,
|
||||||
|
request_data=request_data,
|
||||||
|
valid_token=valid_token,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]):
|
||||||
|
if user_obj is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if (
|
||||||
|
user_obj.user_role is not None
|
||||||
|
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
if (
|
||||||
|
user_obj.user_role is not None
|
||||||
|
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_allowed_route(
|
||||||
|
route: str,
|
||||||
|
token_type: Literal["ui", "api"],
|
||||||
|
request: Request,
|
||||||
|
request_data: dict,
|
||||||
|
valid_token: Optional[UserAPIKeyAuth],
|
||||||
|
user_obj: Optional[LiteLLM_UserTable] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
- Route b/w ui token check and normal token check
|
||||||
|
"""
|
||||||
|
|
||||||
|
if token_type == "ui" and _is_ui_route(route=route, user_obj=user_obj):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return _is_api_route_allowed(
|
||||||
|
route=route,
|
||||||
|
request=request,
|
||||||
|
request_data=request_data,
|
||||||
|
valid_token=valid_token,
|
||||||
|
user_obj=user_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
|
def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
|
||||||
"""
|
"""
|
||||||
Return if a user is allowed to access route. Helper function for `allowed_routes_check`.
|
Return if a user is allowed to access route. Helper function for `allowed_routes_check`.
|
||||||
|
|
|
@ -321,6 +321,7 @@ async def check_if_request_size_is_safe(request: Request) -> bool:
|
||||||
from litellm.proxy.proxy_server import general_settings, premium_user
|
from litellm.proxy.proxy_server import general_settings, premium_user
|
||||||
|
|
||||||
max_request_size_mb = general_settings.get("max_request_size_mb", None)
|
max_request_size_mb = general_settings.get("max_request_size_mb", None)
|
||||||
|
|
||||||
if max_request_size_mb is not None:
|
if max_request_size_mb is not None:
|
||||||
# Check if premium user
|
# Check if premium user
|
||||||
if premium_user is not True:
|
if premium_user is not True:
|
||||||
|
|
|
@ -24,7 +24,6 @@ class RouteChecks:
|
||||||
route: str,
|
route: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
valid_token: UserAPIKeyAuth,
|
valid_token: UserAPIKeyAuth,
|
||||||
api_key: str,
|
|
||||||
request_data: dict,
|
request_data: dict,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -25,7 +25,10 @@ from litellm.litellm_core_utils.dd_tracing import tracer
|
||||||
from litellm.proxy._types import *
|
from litellm.proxy._types import *
|
||||||
from litellm.proxy.auth.auth_checks import (
|
from litellm.proxy.auth.auth_checks import (
|
||||||
_cache_key_object,
|
_cache_key_object,
|
||||||
|
_get_user_role,
|
||||||
_handle_failed_db_connection_for_get_key_object,
|
_handle_failed_db_connection_for_get_key_object,
|
||||||
|
_is_allowed_route,
|
||||||
|
_is_user_proxy_admin,
|
||||||
_virtual_key_max_budget_check,
|
_virtual_key_max_budget_check,
|
||||||
_virtual_key_soft_budget_check,
|
_virtual_key_soft_budget_check,
|
||||||
can_key_call_model,
|
can_key_call_model,
|
||||||
|
@ -98,86 +101,6 @@ def _get_bearer_token(
|
||||||
return api_key
|
return api_key
|
||||||
|
|
||||||
|
|
||||||
def _is_ui_route(
|
|
||||||
route: str,
|
|
||||||
user_obj: Optional[LiteLLM_UserTable] = None,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
- Check if the route is a UI used route
|
|
||||||
"""
|
|
||||||
# this token is only used for managing the ui
|
|
||||||
allowed_routes = LiteLLMRoutes.ui_routes.value
|
|
||||||
# check if the current route startswith any of the allowed routes
|
|
||||||
if (
|
|
||||||
route is not None
|
|
||||||
and isinstance(route, str)
|
|
||||||
and any(route.startswith(allowed_route) for allowed_route in allowed_routes)
|
|
||||||
):
|
|
||||||
# Do something if the current route starts with any of the allowed routes
|
|
||||||
return True
|
|
||||||
elif any(
|
|
||||||
RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
|
|
||||||
for allowed_route in allowed_routes
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_api_route_allowed(
|
|
||||||
route: str,
|
|
||||||
request: Request,
|
|
||||||
request_data: dict,
|
|
||||||
api_key: str,
|
|
||||||
valid_token: Optional[UserAPIKeyAuth],
|
|
||||||
user_obj: Optional[LiteLLM_UserTable] = None,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
- Route b/w api token check and normal token check
|
|
||||||
"""
|
|
||||||
_user_role = _get_user_role(user_obj=user_obj)
|
|
||||||
|
|
||||||
if valid_token is None:
|
|
||||||
raise Exception("Invalid proxy server token passed. valid_token=None.")
|
|
||||||
|
|
||||||
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
|
|
||||||
RouteChecks.non_proxy_admin_allowed_routes_check(
|
|
||||||
user_obj=user_obj,
|
|
||||||
_user_role=_user_role,
|
|
||||||
route=route,
|
|
||||||
request=request,
|
|
||||||
request_data=request_data,
|
|
||||||
api_key=api_key,
|
|
||||||
valid_token=valid_token,
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def _is_allowed_route(
|
|
||||||
route: str,
|
|
||||||
token_type: Literal["ui", "api"],
|
|
||||||
request: Request,
|
|
||||||
request_data: dict,
|
|
||||||
api_key: str,
|
|
||||||
valid_token: Optional[UserAPIKeyAuth],
|
|
||||||
user_obj: Optional[LiteLLM_UserTable] = None,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
- Route b/w ui token check and normal token check
|
|
||||||
"""
|
|
||||||
|
|
||||||
if token_type == "ui" and _is_ui_route(route=route, user_obj=user_obj):
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return _is_api_route_allowed(
|
|
||||||
route=route,
|
|
||||||
request=request,
|
|
||||||
request_data=request_data,
|
|
||||||
api_key=api_key,
|
|
||||||
valid_token=valid_token,
|
|
||||||
user_obj=user_obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def user_api_key_auth_websocket(websocket: WebSocket):
|
async def user_api_key_auth_websocket(websocket: WebSocket):
|
||||||
# Accept the WebSocket connection
|
# Accept the WebSocket connection
|
||||||
|
|
||||||
|
@ -328,6 +251,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
parent_otel_span: Optional[Span] = None
|
parent_otel_span: Optional[Span] = None
|
||||||
start_time = datetime.now()
|
start_time = datetime.now()
|
||||||
route: str = get_request_route(request=request)
|
route: str = get_request_route(request=request)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# get the request body
|
# get the request body
|
||||||
|
@ -470,22 +394,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
parent_otel_span=parent_otel_span,
|
parent_otel_span=parent_otel_span,
|
||||||
)
|
)
|
||||||
# run through common checks
|
|
||||||
_ = await common_checks(
|
|
||||||
request_body=request_data,
|
|
||||||
team_object=team_object,
|
|
||||||
user_object=user_object,
|
|
||||||
end_user_object=end_user_object,
|
|
||||||
general_settings=general_settings,
|
|
||||||
global_proxy_spend=global_proxy_spend,
|
|
||||||
route=route,
|
|
||||||
llm_router=llm_router,
|
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
|
||||||
valid_token=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# return UserAPIKeyAuth object
|
valid_token = UserAPIKeyAuth(
|
||||||
return UserAPIKeyAuth(
|
|
||||||
api_key=None,
|
api_key=None,
|
||||||
team_id=team_id,
|
team_id=team_id,
|
||||||
team_tpm_limit=(
|
team_tpm_limit=(
|
||||||
|
@ -501,6 +411,23 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
parent_otel_span=parent_otel_span,
|
parent_otel_span=parent_otel_span,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
)
|
)
|
||||||
|
# run through common checks
|
||||||
|
_ = await common_checks(
|
||||||
|
request=request,
|
||||||
|
request_body=request_data,
|
||||||
|
team_object=team_object,
|
||||||
|
user_object=user_object,
|
||||||
|
end_user_object=end_user_object,
|
||||||
|
general_settings=general_settings,
|
||||||
|
global_proxy_spend=global_proxy_spend,
|
||||||
|
route=route,
|
||||||
|
llm_router=llm_router,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
valid_token=valid_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
# return UserAPIKeyAuth object
|
||||||
|
return valid_token
|
||||||
|
|
||||||
#### ELSE ####
|
#### ELSE ####
|
||||||
## CHECK PASS-THROUGH ENDPOINTS ##
|
## CHECK PASS-THROUGH ENDPOINTS ##
|
||||||
|
@ -1038,6 +965,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
_ = await common_checks(
|
_ = await common_checks(
|
||||||
|
request=request,
|
||||||
request_body=request_data,
|
request_body=request_data,
|
||||||
team_object=_team_obj,
|
team_object=_team_obj,
|
||||||
user_object=user_obj,
|
user_object=user_obj,
|
||||||
|
@ -1075,23 +1003,6 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
# check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions
|
# check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions
|
||||||
# sso/login, ui/login, /key functions and /user functions
|
# sso/login, ui/login, /key functions and /user functions
|
||||||
# this will never be allowed to call /chat/completions
|
# this will never be allowed to call /chat/completions
|
||||||
token_team = getattr(valid_token, "team_id", None)
|
|
||||||
token_type: Literal["ui", "api"] = (
|
|
||||||
"ui"
|
|
||||||
if token_team is not None and token_team == "litellm-dashboard"
|
|
||||||
else "api"
|
|
||||||
)
|
|
||||||
_is_route_allowed = _is_allowed_route(
|
|
||||||
route=route,
|
|
||||||
token_type=token_type,
|
|
||||||
user_obj=user_obj,
|
|
||||||
request=request,
|
|
||||||
request_data=request_data,
|
|
||||||
api_key=api_key,
|
|
||||||
valid_token=valid_token,
|
|
||||||
)
|
|
||||||
if not _is_route_allowed:
|
|
||||||
raise HTTPException(401, detail="Invalid route for UI token")
|
|
||||||
|
|
||||||
if valid_token is None:
|
if valid_token is None:
|
||||||
# No token was found when looking up in the DB
|
# No token was found when looking up in the DB
|
||||||
|
@ -1242,42 +1153,6 @@ async def _return_user_api_key_auth_obj(
|
||||||
return UserAPIKeyAuth(**user_api_key_kwargs)
|
return UserAPIKeyAuth(**user_api_key_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]):
|
|
||||||
if user_obj is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if (
|
|
||||||
user_obj.user_role is not None
|
|
||||||
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
|
|
||||||
if (
|
|
||||||
user_obj.user_role is not None
|
|
||||||
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _get_user_role(
|
|
||||||
user_obj: Optional[LiteLLM_UserTable],
|
|
||||||
) -> Optional[LitellmUserRoles]:
|
|
||||||
if user_obj is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
_user = user_obj
|
|
||||||
|
|
||||||
_user_role = _user.user_role
|
|
||||||
try:
|
|
||||||
role = LitellmUserRoles(_user_role)
|
|
||||||
except ValueError:
|
|
||||||
return LitellmUserRoles.INTERNAL_USER
|
|
||||||
|
|
||||||
return role
|
|
||||||
|
|
||||||
|
|
||||||
def get_api_key_from_custom_header(
|
def get_api_key_from_custom_header(
|
||||||
request: Request, custom_litellm_key_header_name: str
|
request: Request, custom_litellm_key_header_name: str
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
|
@ -4,6 +4,9 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import litellm.proxy
|
||||||
|
import litellm.proxy.proxy_server
|
||||||
|
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
|
@ -950,7 +953,7 @@ def test_get_model_from_request(route, request_data, expected_model):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_jwt_non_admin_team_route_access():
|
async def test_jwt_non_admin_team_route_access(monkeypatch):
|
||||||
"""
|
"""
|
||||||
Test that a non-admin JWT user cannot access team management routes
|
Test that a non-admin JWT user cannot access team management routes
|
||||||
"""
|
"""
|
||||||
|
@ -958,6 +961,8 @@ async def test_jwt_non_admin_team_route_access():
|
||||||
from starlette.datastructures import URL
|
from starlette.datastructures import URL
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
import json
|
||||||
|
from litellm.proxy._types import ProxyException
|
||||||
|
|
||||||
mock_jwt_response = {
|
mock_jwt_response = {
|
||||||
"is_proxy_admin": False,
|
"is_proxy_admin": False,
|
||||||
|
@ -973,9 +978,15 @@ async def test_jwt_non_admin_team_route_access():
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create request
|
# Create request
|
||||||
request = Request(scope={"type": "http"})
|
request = Request(
|
||||||
|
scope={"type": "http", "headers": [("Authorization", "Bearer fake.jwt.token")]}
|
||||||
|
)
|
||||||
request._url = URL(url="/team/new")
|
request._url = URL(url="/team/new")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True}
|
||||||
|
)
|
||||||
|
|
||||||
# Mock JWTAuthManager.auth_builder
|
# Mock JWTAuthManager.auth_builder
|
||||||
with patch(
|
with patch(
|
||||||
"litellm.proxy.auth.handle_jwt.JWTAuthManager.auth_builder",
|
"litellm.proxy.auth.handle_jwt.JWTAuthManager.auth_builder",
|
||||||
|
@ -986,6 +997,6 @@ async def test_jwt_non_admin_team_route_access():
|
||||||
pytest.fail(
|
pytest.fail(
|
||||||
"Expected this call to fail. Non-admin user should not access team routes."
|
"Expected this call to fail. Non-admin user should not access team routes."
|
||||||
)
|
)
|
||||||
except HTTPException as e:
|
except ProxyException as e:
|
||||||
assert e.status_code == 403
|
print("e", e)
|
||||||
assert "Unauthorized" in str(e.detail)
|
assert "Only proxy admin can be used to generate" in str(e.message)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue