diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 286857eb0..3fa0ef51f 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1338,6 +1338,7 @@ class LiteLLM_UserTable(LiteLLMBase): models: list = [] tpm_limit: Optional[int] = None rpm_limit: Optional[int] = None + user_role: Optional[str] = None @model_validator(mode="before") @classmethod diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 115c90bb6..ca1a1a787 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -581,7 +581,6 @@ async def user_api_key_auth( "allowed_model_region" ) - user_id_information: Optional[List] = None if valid_token is not None: user_obj: Optional[LiteLLM_UserTable] = None # Got Valid Token from Cache, DB @@ -661,15 +660,6 @@ async def user_api_key_auth( parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, ) - # value = user_api_key_cache.get_cache(key=id) - if user_obj is not None: - if user_id_information is None: - user_id_information = [] - user_id_information.append(user_obj.model_dump()) - - verbose_proxy_logger.debug( - f"user_id_information: {user_id_information}" - ) # Check 3. Check if user is in their team budget if valid_token.team_member_spend is not None: @@ -742,11 +732,8 @@ async def user_api_key_auth( user_email: Optional[str] = None # Check if the token has any user id information - if user_id_information is not None and len(user_id_information) > 0: - specific_user_id_information = user_id_information[0] - _user_email = specific_user_id_information.get("user_email", None) - if _user_email is not None: - user_email = str(_user_email) + if user_obj is not None: + user_email = user_obj.user_email call_info = CallInfo( token=valid_token.token, @@ -920,9 +907,9 @@ async def user_api_key_auth( if _end_user_object is not None: valid_token_dict.update(end_user_params) - _user_role = _get_user_role(user_id_information=user_id_information) + _user_role = _get_user_role(user_obj=user_obj) - if not _is_user_proxy_admin(user_id_information): # if non-admin + if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin if is_llm_api_route(route=route): pass elif is_llm_api_route(route=request["route"].name): @@ -1004,14 +991,9 @@ async def user_api_key_auth( else: user_role = "unknown" user_id = "unknown" - if ( - user_id_information is not None - and isinstance(user_id_information, list) - and len(user_id_information) > 0 - ): - _user = user_id_information[0] - user_role = _user.get("user_role", "unknown") - user_id = _user.get("user_id", "unknown") + if user_obj is not None: + user_role = user_obj.user_role or "unknown" + user_id = user_obj.user_id or "unknown" raise Exception( f"Only proxy admin can be used to generate, delete, update info for new keys/users/teams. Route={route}. Your role={user_role}. Your user_id={user_id}" ) @@ -1057,9 +1039,7 @@ async def user_api_key_auth( # Do something if the current route starts with any of the allowed routes pass else: - if user_id_information is not None and _is_user_proxy_admin( - user_id_information - ): + if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj): return UserAPIKeyAuth( api_key=api_key, user_role=LitellmUserRoles.PROXY_ADMIN, @@ -1085,7 +1065,7 @@ async def user_api_key_auth( raise Exception("Invalid proxy server token passed") if valid_token_dict is not None: return _return_user_api_key_auth_obj( - user_id_information=user_id_information, + user_obj=user_obj, api_key=api_key, parent_otel_span=parent_otel_span, valid_token_dict=valid_token_dict, @@ -1132,17 +1112,16 @@ async def user_api_key_auth( def _return_user_api_key_auth_obj( - user_id_information: Optional[list], + user_obj: Optional[LiteLLM_UserTable], api_key: str, parent_otel_span: Optional[Span], valid_token_dict: dict, route: str, ) -> UserAPIKeyAuth: retrieved_user_role = ( - _get_user_role(user_id_information=user_id_information) - or LitellmUserRoles.INTERNAL_USER + _get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER ) - if user_id_information is not None and _is_user_proxy_admin(user_id_information): + if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj): return UserAPIKeyAuth( api_key=api_key, user_role=LitellmUserRoles.PROXY_ADMIN, @@ -1183,30 +1162,19 @@ def _has_user_setup_sso(): return sso_setup -def _is_user_proxy_admin(user_id_information: Optional[list]): - if user_id_information is None: +def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]): + if user_obj is None: return False - if len(user_id_information) == 0 or user_id_information[0] is None: - return False - - _user = user_id_information[0] if ( - _user.get("user_role", None) is not None - and _user.get("user_role") == LitellmUserRoles.PROXY_ADMIN.value + user_obj.user_role is not None + and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value ): return True - # if user_id_information contains litellm-proxy-budget - # get first user_id that is not litellm-proxy-budget - for user in user_id_information: - if user.get("user_id") != "litellm-proxy-budget": - _user = user - break - if ( - _user.get("user_role", None) is not None - and _user.get("user_role") == LitellmUserRoles.PROXY_ADMIN.value + user_obj.user_role is not None + and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value ): return True @@ -1214,29 +1182,20 @@ def _is_user_proxy_admin(user_id_information: Optional[list]): def _get_user_role( - user_id_information: Optional[list], -) -> Optional[ - Literal[ - LitellmUserRoles.PROXY_ADMIN, - LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, - LitellmUserRoles.INTERNAL_USER, - LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, - LitellmUserRoles.TEAM, - LitellmUserRoles.CUSTOMER, - ] -]: - if user_id_information is None: + user_obj: Optional[LiteLLM_UserTable], +) -> Optional[LitellmUserRoles]: + if user_obj is None: return None - if len(user_id_information) == 0 or user_id_information[0] is None: - return None + _user = user_obj - _user = user_id_information[0] + _user_role = _user.user_role + try: + role = LitellmUserRoles(_user_role) + except ValueError: + return LitellmUserRoles.INTERNAL_USER - _user_role = _user.get("user_role") - if _user_role in list(LitellmUserRoles.__annotations__.keys()): - return _user_role - return LitellmUserRoles.INTERNAL_USER + return role def _check_valid_ip(allowed_ips: Optional[List[str]], request: Request) -> bool: diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 12204ec06..93110ffb6 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -2555,6 +2555,7 @@ async def test_update_user_role(prisma_client): await asyncio.sleep(2) # use generated key to auth in + print("\n\nMAKING NEW REQUEST WITH UPDATED USER ROLE\n\n") result = await user_api_key_auth(request=request, api_key=api_key) print("result from user auth with new key", result) diff --git a/litellm/tests/test_user_api_key_auth.py b/litellm/tests/test_user_api_key_auth.py index e8f5a8e08..33f055b37 100644 --- a/litellm/tests/test_user_api_key_auth.py +++ b/litellm/tests/test_user_api_key_auth.py @@ -96,26 +96,28 @@ async def test_check_blocked_team(): @pytest.mark.parametrize( - "user_role", ["app_user", "internal_user", "proxy_admin_viewer"] + "user_role, expected_role", + [ + ("app_user", "internal_user"), + ("internal_user", "internal_user"), + ("proxy_admin_viewer", "proxy_admin_viewer"), + ], ) -def test_returned_user_api_key_auth(user_role): - from litellm.proxy._types import LitellmUserRoles +def test_returned_user_api_key_auth(user_role, expected_role): + from litellm.proxy._types import LiteLLM_UserTable, LitellmUserRoles from litellm.proxy.auth.user_api_key_auth import _return_user_api_key_auth_obj - user_id_information = [{"user_role": user_role}] - new_obj = _return_user_api_key_auth_obj( - user_id_information, + user_obj=LiteLLM_UserTable( + user_role=user_role, user_id="", max_budget=None, user_email="" + ), api_key="hello-world", parent_otel_span=None, valid_token_dict={}, route="/chat/completion", ) - if user_role in list(LitellmUserRoles.__annotations__.keys()): - assert new_obj.user_role == user_role - else: - assert new_obj.user_role == "internal_user" + assert new_obj.user_role == expected_role @pytest.mark.parametrize("key_ownership", ["user_key", "team_key"])