diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index d4e5834f2..86d842ed9 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1,4 +1,5 @@ from pydantic import BaseModel, Extra, Field, root_validator, Json, validator +from dataclasses import fields import enum from typing import Optional, List, Union, Dict, Literal, Any from datetime import datetime @@ -37,9 +38,96 @@ class LiteLLMBase(BaseModel): protected_namespaces = () +class LiteLLMRoutes(enum.Enum): + openai_routes: List = [ # chat completions + "/openai/deployments/{model}/chat/completions", + "/chat/completions", + "/v1/chat/completions", + # completions + "/openai/deployments/{model}/completions", + "/completions", + "/v1/completions", + # embeddings + "/openai/deployments/{model}/embeddings", + "/embeddings", + "/v1/embeddings", + # image generation + "/images/generations", + "/v1/images/generations", + # audio transcription + "/audio/transcriptions", + "/v1/audio/transcriptions", + # moderations + "/moderations", + "/v1/moderations", + # models + "/models", + "/v1/models", + ] + + info_routes: List = ["/key/info", "/team/info", "/user/info", "/model/info"] + + management_routes: List = [ # key + "/key/generate", + "/key/update", + "/key/delete", + "/key/info", + # user + "/user/new", + "/user/update", + "/user/delete", + "/user/info", + # team + "/team/new", + "/team/update", + "/team/delete", + "/team/info", + # model + "/model/new", + "/model/update", + "/model/delete", + "/model/info", + ] + + class LiteLLMProxyRoles(LiteLLMBase): - proxy_admin: str = "litellm_proxy_admin" - proxy_user: str = "litellm_user" + """ + A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth. + + Attributes: + - admin_jwt_scope: The JWT scope required for proxy admin roles. + - admin_allowed_routes: list of allowed routes for proxy admin roles. + - team_jwt_scope: The JWT scope required for proxy team roles. 
+ - team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`. + - team_allowed_routes: list of allowed routes for proxy team roles. + - end_user_id_jwt_field: Default - `sub`. The field in the JWT token that stores the end-user ID. Turn this off by setting to `None`. Enables end-user cost tracking. + + See `auth_checks.py` for the specific routes + """ + + admin_jwt_scope: str = "litellm_proxy_admin" + admin_allowed_routes: List[ + Literal["openai_routes", "info_routes", "management_routes"] + ] = ["management_routes"] + team_jwt_scope: str = "litellm_team" + team_id_jwt_field: str = "client_id" + team_allowed_routes: List[ + Literal["openai_routes", "info_routes", "management_routes"] + ] = ["openai_routes", "info_routes"] + end_user_id_jwt_field: Optional[str] = "sub" + + def __init__(self, **kwargs: Any) -> None: + # get the attribute names for this Pydantic model + allowed_keys = self.__annotations__.keys() + + invalid_keys = set(kwargs.keys()) - allowed_keys + + if invalid_keys: + raise ValueError( + f"Invalid arguments provided: {', '.join(invalid_keys)}. Allowed arguments are: {', '.join(allowed_keys)}." + ) + + super().__init__(**kwargs) class LiteLLMPromptInjectionParams(LiteLLMBase): diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 1c16381ad..f1ef5ca00 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -8,15 +8,23 @@ Run checks for: 2. If user is in budget 3. 
If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget """ -from litellm.proxy._types import LiteLLM_UserTable, LiteLLM_EndUserTable +from litellm.proxy._types import ( + LiteLLM_UserTable, + LiteLLM_EndUserTable, + LiteLLMProxyRoles, + LiteLLM_TeamTable, + LiteLLMRoutes, +) from typing import Optional, Literal from litellm.proxy.utils import PrismaClient from litellm.caching import DualCache +all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value + def common_checks( request_body: dict, - user_object: LiteLLM_UserTable, + team_object: LiteLLM_TeamTable, end_user_object: Optional[LiteLLM_EndUserTable], ) -> bool: """ @@ -30,19 +38,20 @@ def common_checks( # 1. If user can call model if ( _model is not None - and len(user_object.models) > 0 - and _model not in user_object.models + and len(team_object.models) > 0 + and _model not in team_object.models ): raise Exception( - f"User={user_object.user_id} not allowed to call model={_model}. Allowed user models = {user_object.models}" + f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}" ) - # 2. If user is in budget + # 2. If team is in budget if ( - user_object.max_budget is not None - and user_object.spend > user_object.max_budget + team_object.max_budget is not None + and team_object.spend is not None + and team_object.spend > team_object.max_budget ): raise Exception( - f"User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_object.max_budget}" + f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}" ) # 3. 
If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget if end_user_object is not None and end_user_object.litellm_budget_table is not None: @@ -54,52 +63,79 @@ def common_checks( return True +def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool: + for allowed_route in allowed_routes: + if ( + allowed_route == LiteLLMRoutes.openai_routes.name + and user_route in LiteLLMRoutes.openai_routes.value + ): + return True + elif ( + allowed_route == LiteLLMRoutes.info_routes.name + and user_route in LiteLLMRoutes.info_routes.value + ): + return True + elif ( + allowed_route == LiteLLMRoutes.management_routes.name + and user_route in LiteLLMRoutes.management_routes.value + ): + return True + elif allowed_route == user_route: + return True + return False + + def allowed_routes_check( - user_role: Literal["proxy_admin", "app_owner"], - route: str, - allowed_routes: Optional[list] = None, + user_role: Literal["proxy_admin", "team"], + user_route: str, + litellm_proxy_roles: LiteLLMProxyRoles, ) -> bool: """ Check if user -> not admin - allowed to access these routes """ - openai_routes = [ - # chat completions - "/openai/deployments/{model}/chat/completions", - "/chat/completions", - "/v1/chat/completions", - # completions - # embeddings - "/openai/deployments/{model}/embeddings", - "/embeddings", - "/v1/embeddings", - # image generation - "/images/generations", - "/v1/images/generations", - # audio transcription - "/audio/transcriptions", - "/v1/audio/transcriptions", - # moderations - "/moderations", - "/v1/moderations", - # models - "/models", - "/v1/models", - ] - info_routes = ["/key/info", "/team/info", "/user/info", "/model/info"] - default_routes = openai_routes + info_routes + if user_role == "proxy_admin": - return True - elif user_role == "app_owner": - if allowed_routes is None: - if route in default_routes: # check default routes - return True - elif route in allowed_routes: - return True - else: - return False + if 
litellm_proxy_roles.admin_allowed_routes is None: + is_allowed = _allowed_routes_check( + user_route=user_route, allowed_routes=["management_routes"] + ) + return is_allowed + elif litellm_proxy_roles.admin_allowed_routes is not None: + is_allowed = _allowed_routes_check( + user_route=user_route, + allowed_routes=litellm_proxy_roles.admin_allowed_routes, + ) + return is_allowed + + elif user_role == "team": + if litellm_proxy_roles.team_allowed_routes is None: + """ + By default allow a team to call openai + info routes + """ + is_allowed = _allowed_routes_check( + user_route=user_route, allowed_routes=["openai_routes", "info_routes"] + ) + return is_allowed + elif litellm_proxy_roles.team_allowed_routes is not None: + is_allowed = _allowed_routes_check( + user_route=user_route, + allowed_routes=litellm_proxy_roles.team_allowed_routes, + ) + return is_allowed return False +def get_actual_routes(allowed_routes: list) -> list: + actual_routes: list = [] + for route_name in allowed_routes: + try: + route_value = LiteLLMRoutes[route_name].value + actual_routes = actual_routes + route_value + except KeyError: + actual_routes.append(route_name) + return actual_routes + + async def get_end_user_object( end_user_id: Optional[str], prisma_client: Optional[PrismaClient], @@ -135,3 +171,75 @@ async def get_end_user_object( return LiteLLM_EndUserTable(**response.dict()) except Exception as e: # if end-user not in db return None + + +async def get_user_object(user_id: str, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache) -> LiteLLM_UserTable: + """ + - Check if user id in proxy User Table + - if valid, return LiteLLM_UserTable object with defined limits + - if not, then raise an error + """ + if prisma_client is None: + raise Exception( + "No DB Connected. 
See - https://docs.litellm.ai/docs/proxy/virtual_keys" + ) + + # check if in cache + cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id) + if cached_user_obj is not None: + if isinstance(cached_user_obj, dict): + return LiteLLM_UserTable(**cached_user_obj) + elif isinstance(cached_user_obj, LiteLLM_UserTable): + return cached_user_obj + # else, check db + try: + response = await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": user_id} + ) + + if response is None: + raise Exception + + return LiteLLM_UserTable(**response.dict()) + except Exception as e: + raise Exception( + f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call." + ) + + +async def get_team_object( + team_id: str, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, +) -> LiteLLM_TeamTable: + """ + - Check if team id in proxy Team Table + - if valid, return LiteLLM_TeamTable object with defined limits + - if not, then raise an error + """ + if prisma_client is None: + raise Exception( + "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys" + ) + + # check if in cache + cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id) + if cached_team_obj is not None: + if isinstance(cached_team_obj, dict): + return LiteLLM_TeamTable(**cached_team_obj) + elif isinstance(cached_team_obj, LiteLLM_TeamTable): + return cached_team_obj + # else, check db + try: + response = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id} + ) + + if response is None: + raise Exception + + return LiteLLM_TeamTable(**response.dict()) + except Exception as e: + raise Exception( + f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call." 
+ ) diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index b636c8813..ec7f75562 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -81,57 +81,27 @@ class JWTHandler: return len(parts) == 3 def is_admin(self, scopes: list) -> bool: - if self.litellm_proxy_roles.proxy_admin in scopes: + if self.litellm_proxy_roles.admin_jwt_scope in scopes: return True return False - def get_user_id(self, token: dict, default_value: str) -> str: + def get_end_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: try: - user_id = token["sub"] + if self.litellm_proxy_roles.end_user_id_jwt_field is not None: + user_id = token[self.litellm_proxy_roles.end_user_id_jwt_field] + else: + user_id = default_value except KeyError: user_id = default_value return user_id def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: try: - team_id = token["client_id"] + team_id = token[self.litellm_proxy_roles.team_id_jwt_field] except KeyError: team_id = default_value return team_id - async def get_user_object(self, user_id: str) -> LiteLLM_UserTable: - """ - - Check if user id in proxy User Table - - if valid, return LiteLLM_UserTable object with defined limits - - if not, then raise an error - """ - if self.prisma_client is None: - raise Exception( - "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys" - ) - - # check if in cache - cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id) - if cached_user_obj is not None: - if isinstance(cached_user_obj, dict): - return LiteLLM_UserTable(**cached_user_obj) - elif isinstance(cached_user_obj, LiteLLM_UserTable): - return cached_user_obj - # else, check db - try: - response = await self.prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_id} - ) - - if response is None: - raise Exception - - return LiteLLM_UserTable(**response.dict()) - except Exception as e: - raise Exception( - f"User doesn't exist in db. 
User={user_id}. Create user via `/user/new` call." - ) - def get_scopes(self, token: dict) -> list: try: if isinstance(token["scope"], str): diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 48712a864..f405630f8 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -113,7 +113,10 @@ from litellm.proxy.hooks.prompt_injection_detection import ( from litellm.proxy.auth.auth_checks import ( common_checks, get_end_user_object, + get_team_object, + get_user_object, allowed_routes_check, + get_actual_routes, ) try: @@ -369,71 +372,93 @@ async def user_api_key_auth( scopes = jwt_handler.get_scopes(token=valid_token) # check if admin is_admin = jwt_handler.is_admin(scopes=scopes) - # get user id - user_id = jwt_handler.get_user_id( - token=valid_token, default_value=litellm_proxy_admin_name + # if admin return + if is_admin: + # check allowed admin routes + is_allowed = allowed_routes_check( + user_role="proxy_admin", + user_route=route, + litellm_proxy_roles=jwt_handler.litellm_proxy_roles, + ) + if is_allowed: + return UserAPIKeyAuth() + else: + allowed_routes = ( + jwt_handler.litellm_proxy_roles.admin_allowed_routes + ) + actual_routes = get_actual_routes(allowed_routes=allowed_routes) + raise Exception( + f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}" + ) + # get team id + team_id = jwt_handler.get_team_id(token=valid_token, default_value=None) + + if team_id is None: + raise Exception( + f"No team id passed in. 
Field checked in jwt token - '{jwt_handler.litellm_proxy_roles.team_id_jwt_field}'" + ) + # check allowed team routes + is_allowed = allowed_routes_check( + user_role="team", + user_route=route, + litellm_proxy_roles=jwt_handler.litellm_proxy_roles, + ) + if is_allowed == False: + allowed_routes = jwt_handler.litellm_proxy_roles.team_allowed_routes + actual_routes = get_actual_routes(allowed_routes=allowed_routes) + raise Exception( + f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}" + ) + + # check if team in db + team_object = await get_team_object( + team_id=team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, ) - end_user_object = None + # common checks + # allow request + # get the request body request_data = await _read_request_body(request=request) - # get user obj from cache/db -> run for admin too. Ensures, admin client id in db. - user_object = await jwt_handler.get_user_object(user_id=user_id) - if ( - request_data.get("user", None) - and request_data["user"] != user_object.user_id - ): + + end_user_object = None + end_user_id = jwt_handler.get_end_user_id( + token=valid_token, default_value=None + ) + if end_user_id is not None: # get the end-user object end_user_object = await get_end_user_object( - end_user_id=request_data["user"], + end_user_id=end_user_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, ) # save the end-user object to cache await user_api_key_cache.async_set_cache( - key=request_data["user"], value=end_user_object + key=end_user_id, value=end_user_object ) # run through common checks _ = common_checks( request_body=request_data, - user_object=user_object, + team_object=team_object, end_user_object=end_user_object, ) # save user object in cache await user_api_key_cache.async_set_cache( - key=user_object.user_id, value=user_object + key=team_object.team_id, value=team_object + ) + + # return UserAPIKeyAuth object + return UserAPIKeyAuth( + 
api_key=None, + team_id=team_object.team_id, + tpm_limit=team_object.tpm_limit, + rpm_limit=team_object.rpm_limit, + models=team_object.models, + user_role="app_owner", ) - # if admin return - if is_admin: - return UserAPIKeyAuth( - api_key=api_key, - user_role="proxy_admin", - user_id=user_id, - ) - else: - is_allowed = allowed_routes_check( - user_role="app_owner", - route=route, - allowed_routes=general_settings.get("allowed_routes", None), - ) - if is_allowed: - # return UserAPIKeyAuth object - return UserAPIKeyAuth( - api_key=None, - user_id=user_object.user_id, - tpm_limit=user_object.tpm_limit, - rpm_limit=user_object.rpm_limit, - models=user_object.models, - user_role="app_owner", - ) - else: - raise HTTPException( - status_code=401, - detail={ - "error": f"User={user_object.user_id} not allowed to access this route={route}." - }, - ) #### ELSE #### if master_key is None: if isinstance(api_key, str): @@ -2709,12 +2734,16 @@ async def startup_event(): proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made ## JWT AUTH ## + if general_settings.get("litellm_proxy_roles", None) is not None: + litellm_proxy_roles = LiteLLMProxyRoles( + **general_settings["litellm_proxy_roles"] + ) + else: + litellm_proxy_roles = LiteLLMProxyRoles() jwt_handler.update_environment( prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, - litellm_proxy_roles=LiteLLMProxyRoles( - **general_settings.get("litellm_proxy_roles", {}) - ), + litellm_proxy_roles=litellm_proxy_roles, ) if use_background_health_checks: