diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index a6e8b9bb7..aff312e35 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -139,7 +139,8 @@ class LiteLLM_JWTAuth(LiteLLMBase):
     team_allowed_routes: List[
         Literal["openai_routes", "info_routes", "management_routes"]
     ] = ["openai_routes", "info_routes"]
-    end_user_id_jwt_field: Optional[str] = "sub"
+    user_id_jwt_field: Optional[str] = None
+    end_user_id_jwt_field: Optional[str] = None
     public_key_ttl: float = 600
 
     def __init__(self, **kwargs: Any) -> None:
diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index 8cfa5587b..b25ae86a1 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -26,6 +26,7 @@ all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes
 def common_checks(
     request_body: dict,
     team_object: LiteLLM_TeamTable,
+    user_object: Optional[LiteLLM_UserTable],
     end_user_object: Optional[LiteLLM_EndUserTable],
     global_proxy_spend: Optional[float],
     general_settings: dict,
@@ -37,7 +38,8 @@ def common_checks(
     1. If team is blocked
     2. If team can call model
     3. If team is in budget
-    4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
-    5. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints
-    6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
+    4. If user passed in (JWT or key.user_id) - is in budget
+    5. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget
+    6. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints
+    7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
     """
@@ -69,14 +71,20 @@ def common_checks(
         raise Exception(
             f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}"
         )
-    # 4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
+    if user_object is not None and user_object.max_budget is not None:
+        user_budget = user_object.max_budget
+        if user_object.spend > user_budget:
+            raise Exception(
+                f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}"
+            )
+    # 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
     if end_user_object is not None and end_user_object.litellm_budget_table is not None:
         end_user_budget = end_user_object.litellm_budget_table.max_budget
         if end_user_budget is not None and end_user_object.spend > end_user_budget:
             raise Exception(
                 f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}"
             )
-    # 5. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
+    # 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
     if (
         general_settings.get("enforce_user_param", None) is not None
         and general_settings["enforce_user_param"] == True
@@ -85,7 +93,7 @@ def common_checks(
             raise Exception(
                 f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
             )
-    # 6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
+    # 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
     if litellm.max_budget > 0 and global_proxy_spend is not None:
         if global_proxy_spend > litellm.max_budget:
             raise Exception(
@@ -204,19 +212,24 @@ async def get_end_user_object(
     return None
 
 
-async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
+async def get_user_object(
+    user_id: str,
+    prisma_client: Optional[PrismaClient],
+    user_api_key_cache: DualCache,
+) -> Optional[LiteLLM_UserTable]:
     """
     - Check if user id in proxy User Table
     - if valid, return LiteLLM_UserTable object with defined limits
-    - if not, then raise an error
+    - if not, return None
     """
-    if self.prisma_client is None:
-        raise Exception(
-            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
-        )
+    if prisma_client is None:
+        raise Exception("No db connected")
+
+    if user_id is None:
+        return None
 
     # check if in cache
-    cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id)
+    cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
     if cached_user_obj is not None:
         if isinstance(cached_user_obj, dict):
             return LiteLLM_UserTable(**cached_user_obj)
@@ -224,7 +237,7 @@ async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
             return cached_user_obj
     # else, check db
     try:
-        response = await self.prisma_client.db.litellm_usertable.find_unique(
+        response = await prisma_client.db.litellm_usertable.find_unique(
             where={"user_id": user_id}
         )
 
@@ -232,10 +245,8 @@ async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
             raise Exception
 
         return LiteLLM_UserTable(**response.dict())
-    except Exception as e:
-        raise Exception(
-            f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
-        )
+    except Exception as e:  # if user not in db
+        return None
 
 
 async def get_team_object(
diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py
index f9f7a6904..76042ec68 100644
--- a/litellm/proxy/auth/handle_jwt.py
+++ b/litellm/proxy/auth/handle_jwt.py
@@ -74,6 +74,16 @@ class JWTHandler:
             team_id = default_value
         return team_id
 
+    def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
+        try:
+            if self.litellm_jwtauth.user_id_jwt_field is not None:
+                user_id = token[self.litellm_jwtauth.user_id_jwt_field]
+            else:
+                user_id = None
+        except KeyError:
+            user_id = default_value
+        return user_id
+
     def get_scopes(self, token: dict) -> list:
         try:
             if isinstance(token["scope"], str):
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index a95070411..a322c32c3 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -422,12 +422,21 @@ async def user_api_key_auth(
                     user_api_key_cache=user_api_key_cache,
                 )
 
-                # common checks
-                # allow request
-
-                # get the request body
-                request_data = await _read_request_body(request=request)
-
+                # [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
+                user_object = None
+                user_id = jwt_handler.get_user_id(token=valid_token, default_value=None)
+                if user_id is not None:
+                    # get the user object
+                    user_object = await get_user_object(
+                        user_id=user_id,
+                        prisma_client=prisma_client,
+                        user_api_key_cache=user_api_key_cache,
+                    )
+                    # save the user object to cache
+                    await user_api_key_cache.async_set_cache(
+                        key=user_id, value=user_object
+                    )
+                # [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
                 end_user_object = None
                 end_user_id = jwt_handler.get_end_user_id(
                     token=valid_token, default_value=None
@@ -445,7 +454,6 @@ async def user_api_key_auth(
                     )
 
                 global_proxy_spend = None
-
                 if litellm.max_budget > 0:  # user set proxy max budget
                     # check cache
                     global_proxy_spend = await user_api_key_cache.async_get_cache(
@@ -480,16 +488,20 @@ async def user_api_key_auth(
                             )
                         )
 
+                # get the request body
+                request_data = await _read_request_body(request=request)
+
                 # run through common checks
                 _ = common_checks(
                     request_body=request_data,
                     team_object=team_object,
+                    user_object=user_object,
                     end_user_object=end_user_object,
                     general_settings=general_settings,
                     global_proxy_spend=global_proxy_spend,
                     route=route,
                 )
-                # save user object in cache
+                # save team object in cache
                 await user_api_key_cache.async_set_cache(
                     key=team_object.team_id, value=team_object
                 )
@@ -954,6 +966,7 @@ async def user_api_key_auth(
             _ = common_checks(
                 request_body=request_data,
                 team_object=_team_obj,
+                user_object=None,
                 end_user_object=_end_user_object,
                 general_settings=general_settings,
                 global_proxy_spend=global_proxy_spend,