refactor(user_api_key_auth.py): refactor to replace user_id_information list with pydantic user_obj

Allows using the 'get_user_object' function in user_api_key_auth, keeping it consistent across jwt-auth and key-auth
This commit is contained in:
Krrish Dholakia 2024-08-07 15:33:55 -07:00
parent ff373663a3
commit f76261af35
4 changed files with 42 additions and 79 deletions

View file

@ -1338,6 +1338,7 @@ class LiteLLM_UserTable(LiteLLMBase):
models: list = [] models: list = []
tpm_limit: Optional[int] = None tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None rpm_limit: Optional[int] = None
user_role: Optional[str] = None
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod

View file

@ -581,7 +581,6 @@ async def user_api_key_auth(
"allowed_model_region" "allowed_model_region"
) )
user_id_information: Optional[List] = None
if valid_token is not None: if valid_token is not None:
user_obj: Optional[LiteLLM_UserTable] = None user_obj: Optional[LiteLLM_UserTable] = None
# Got Valid Token from Cache, DB # Got Valid Token from Cache, DB
@ -661,15 +660,6 @@ async def user_api_key_auth(
parent_otel_span=parent_otel_span, parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj, proxy_logging_obj=proxy_logging_obj,
) )
# value = user_api_key_cache.get_cache(key=id)
if user_obj is not None:
if user_id_information is None:
user_id_information = []
user_id_information.append(user_obj.model_dump())
verbose_proxy_logger.debug(
f"user_id_information: {user_id_information}"
)
# Check 3. Check if user is in their team budget # Check 3. Check if user is in their team budget
if valid_token.team_member_spend is not None: if valid_token.team_member_spend is not None:
@ -742,11 +732,8 @@ async def user_api_key_auth(
user_email: Optional[str] = None user_email: Optional[str] = None
# Check if the token has any user id information # Check if the token has any user id information
if user_id_information is not None and len(user_id_information) > 0: if user_obj is not None:
specific_user_id_information = user_id_information[0] user_email = user_obj.user_email
_user_email = specific_user_id_information.get("user_email", None)
if _user_email is not None:
user_email = str(_user_email)
call_info = CallInfo( call_info = CallInfo(
token=valid_token.token, token=valid_token.token,
@ -920,9 +907,9 @@ async def user_api_key_auth(
if _end_user_object is not None: if _end_user_object is not None:
valid_token_dict.update(end_user_params) valid_token_dict.update(end_user_params)
_user_role = _get_user_role(user_id_information=user_id_information) _user_role = _get_user_role(user_obj=user_obj)
if not _is_user_proxy_admin(user_id_information): # if non-admin if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
if is_llm_api_route(route=route): if is_llm_api_route(route=route):
pass pass
elif is_llm_api_route(route=request["route"].name): elif is_llm_api_route(route=request["route"].name):
@ -1004,14 +991,9 @@ async def user_api_key_auth(
else: else:
user_role = "unknown" user_role = "unknown"
user_id = "unknown" user_id = "unknown"
if ( if user_obj is not None:
user_id_information is not None user_role = user_obj.user_role or "unknown"
and isinstance(user_id_information, list) user_id = user_obj.user_id or "unknown"
and len(user_id_information) > 0
):
_user = user_id_information[0]
user_role = _user.get("user_role", "unknown")
user_id = _user.get("user_id", "unknown")
raise Exception( raise Exception(
f"Only proxy admin can be used to generate, delete, update info for new keys/users/teams. Route={route}. Your role={user_role}. Your user_id={user_id}" f"Only proxy admin can be used to generate, delete, update info for new keys/users/teams. Route={route}. Your role={user_role}. Your user_id={user_id}"
) )
@ -1057,9 +1039,7 @@ async def user_api_key_auth(
# Do something if the current route starts with any of the allowed routes # Do something if the current route starts with any of the allowed routes
pass pass
else: else:
if user_id_information is not None and _is_user_proxy_admin( if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
user_id_information
):
return UserAPIKeyAuth( return UserAPIKeyAuth(
api_key=api_key, api_key=api_key,
user_role=LitellmUserRoles.PROXY_ADMIN, user_role=LitellmUserRoles.PROXY_ADMIN,
@ -1085,7 +1065,7 @@ async def user_api_key_auth(
raise Exception("Invalid proxy server token passed") raise Exception("Invalid proxy server token passed")
if valid_token_dict is not None: if valid_token_dict is not None:
return _return_user_api_key_auth_obj( return _return_user_api_key_auth_obj(
user_id_information=user_id_information, user_obj=user_obj,
api_key=api_key, api_key=api_key,
parent_otel_span=parent_otel_span, parent_otel_span=parent_otel_span,
valid_token_dict=valid_token_dict, valid_token_dict=valid_token_dict,
@ -1132,17 +1112,16 @@ async def user_api_key_auth(
def _return_user_api_key_auth_obj( def _return_user_api_key_auth_obj(
user_id_information: Optional[list], user_obj: Optional[LiteLLM_UserTable],
api_key: str, api_key: str,
parent_otel_span: Optional[Span], parent_otel_span: Optional[Span],
valid_token_dict: dict, valid_token_dict: dict,
route: str, route: str,
) -> UserAPIKeyAuth: ) -> UserAPIKeyAuth:
retrieved_user_role = ( retrieved_user_role = (
_get_user_role(user_id_information=user_id_information) _get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
or LitellmUserRoles.INTERNAL_USER
) )
if user_id_information is not None and _is_user_proxy_admin(user_id_information): if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
return UserAPIKeyAuth( return UserAPIKeyAuth(
api_key=api_key, api_key=api_key,
user_role=LitellmUserRoles.PROXY_ADMIN, user_role=LitellmUserRoles.PROXY_ADMIN,
@ -1183,30 +1162,19 @@ def _has_user_setup_sso():
return sso_setup return sso_setup
def _is_user_proxy_admin(user_id_information: Optional[list]): def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]):
if user_id_information is None: if user_obj is None:
return False return False
if len(user_id_information) == 0 or user_id_information[0] is None:
return False
_user = user_id_information[0]
if ( if (
_user.get("user_role", None) is not None user_obj.user_role is not None
and _user.get("user_role") == LitellmUserRoles.PROXY_ADMIN.value and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
): ):
return True return True
# if user_id_information contains litellm-proxy-budget
# get first user_id that is not litellm-proxy-budget
for user in user_id_information:
if user.get("user_id") != "litellm-proxy-budget":
_user = user
break
if ( if (
_user.get("user_role", None) is not None user_obj.user_role is not None
and _user.get("user_role") == LitellmUserRoles.PROXY_ADMIN.value and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
): ):
return True return True
@ -1214,29 +1182,20 @@ def _is_user_proxy_admin(user_id_information: Optional[list]):
def _get_user_role( def _get_user_role(
user_id_information: Optional[list], user_obj: Optional[LiteLLM_UserTable],
) -> Optional[ ) -> Optional[LitellmUserRoles]:
Literal[ if user_obj is None:
LitellmUserRoles.PROXY_ADMIN,
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
LitellmUserRoles.INTERNAL_USER,
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
LitellmUserRoles.TEAM,
LitellmUserRoles.CUSTOMER,
]
]:
if user_id_information is None:
return None return None
if len(user_id_information) == 0 or user_id_information[0] is None: _user = user_obj
return None
_user = user_id_information[0] _user_role = _user.user_role
try:
role = LitellmUserRoles(_user_role)
except ValueError:
return LitellmUserRoles.INTERNAL_USER
_user_role = _user.get("user_role") return role
if _user_role in list(LitellmUserRoles.__annotations__.keys()):
return _user_role
return LitellmUserRoles.INTERNAL_USER
def _check_valid_ip(allowed_ips: Optional[List[str]], request: Request) -> bool: def _check_valid_ip(allowed_ips: Optional[List[str]], request: Request) -> bool:

View file

@ -2555,6 +2555,7 @@ async def test_update_user_role(prisma_client):
await asyncio.sleep(2) await asyncio.sleep(2)
# use generated key to auth in # use generated key to auth in
print("\n\nMAKING NEW REQUEST WITH UPDATED USER ROLE\n\n")
result = await user_api_key_auth(request=request, api_key=api_key) result = await user_api_key_auth(request=request, api_key=api_key)
print("result from user auth with new key", result) print("result from user auth with new key", result)

View file

@ -96,26 +96,28 @@ async def test_check_blocked_team():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"user_role", ["app_user", "internal_user", "proxy_admin_viewer"] "user_role, expected_role",
[
("app_user", "internal_user"),
("internal_user", "internal_user"),
("proxy_admin_viewer", "proxy_admin_viewer"),
],
) )
def test_returned_user_api_key_auth(user_role): def test_returned_user_api_key_auth(user_role, expected_role):
from litellm.proxy._types import LitellmUserRoles from litellm.proxy._types import LiteLLM_UserTable, LitellmUserRoles
from litellm.proxy.auth.user_api_key_auth import _return_user_api_key_auth_obj from litellm.proxy.auth.user_api_key_auth import _return_user_api_key_auth_obj
user_id_information = [{"user_role": user_role}]
new_obj = _return_user_api_key_auth_obj( new_obj = _return_user_api_key_auth_obj(
user_id_information, user_obj=LiteLLM_UserTable(
user_role=user_role, user_id="", max_budget=None, user_email=""
),
api_key="hello-world", api_key="hello-world",
parent_otel_span=None, parent_otel_span=None,
valid_token_dict={}, valid_token_dict={},
route="/chat/completion", route="/chat/completion",
) )
if user_role in list(LitellmUserRoles.__annotations__.keys()): assert new_obj.user_role == expected_role
assert new_obj.user_role == user_role
else:
assert new_obj.user_role == "internal_user"
@pytest.mark.parametrize("key_ownership", ["user_key", "team_key"]) @pytest.mark.parametrize("key_ownership", ["user_key", "team_key"])