refactor(user_api_key_auth.py): refactor to replace user_id_information list with pydantic user_obj

Allows using the 'get_user_object' function in user_api_key_auth, keeping it consistent across jwt-auth and key-auth
This commit is contained in:
Krrish Dholakia 2024-08-07 15:33:55 -07:00
parent ff373663a3
commit f76261af35
4 changed files with 42 additions and 79 deletions

View file

@ -581,7 +581,6 @@ async def user_api_key_auth(
"allowed_model_region"
)
user_id_information: Optional[List] = None
if valid_token is not None:
user_obj: Optional[LiteLLM_UserTable] = None
# Got Valid Token from Cache, DB
@ -661,15 +660,6 @@ async def user_api_key_auth(
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# value = user_api_key_cache.get_cache(key=id)
if user_obj is not None:
if user_id_information is None:
user_id_information = []
user_id_information.append(user_obj.model_dump())
verbose_proxy_logger.debug(
f"user_id_information: {user_id_information}"
)
# Check 3. Check if user is in their team budget
if valid_token.team_member_spend is not None:
@ -742,11 +732,8 @@ async def user_api_key_auth(
user_email: Optional[str] = None
# Check if the token has any user id information
if user_id_information is not None and len(user_id_information) > 0:
specific_user_id_information = user_id_information[0]
_user_email = specific_user_id_information.get("user_email", None)
if _user_email is not None:
user_email = str(_user_email)
if user_obj is not None:
user_email = user_obj.user_email
call_info = CallInfo(
token=valid_token.token,
@ -920,9 +907,9 @@ async def user_api_key_auth(
if _end_user_object is not None:
valid_token_dict.update(end_user_params)
_user_role = _get_user_role(user_id_information=user_id_information)
_user_role = _get_user_role(user_obj=user_obj)
if not _is_user_proxy_admin(user_id_information): # if non-admin
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
if is_llm_api_route(route=route):
pass
elif is_llm_api_route(route=request["route"].name):
@ -1004,14 +991,9 @@ async def user_api_key_auth(
else:
user_role = "unknown"
user_id = "unknown"
if (
user_id_information is not None
and isinstance(user_id_information, list)
and len(user_id_information) > 0
):
_user = user_id_information[0]
user_role = _user.get("user_role", "unknown")
user_id = _user.get("user_id", "unknown")
if user_obj is not None:
user_role = user_obj.user_role or "unknown"
user_id = user_obj.user_id or "unknown"
raise Exception(
f"Only proxy admin can be used to generate, delete, update info for new keys/users/teams. Route={route}. Your role={user_role}. Your user_id={user_id}"
)
@ -1057,9 +1039,7 @@ async def user_api_key_auth(
# Do something if the current route starts with any of the allowed routes
pass
else:
if user_id_information is not None and _is_user_proxy_admin(
user_id_information
):
if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
return UserAPIKeyAuth(
api_key=api_key,
user_role=LitellmUserRoles.PROXY_ADMIN,
@ -1085,7 +1065,7 @@ async def user_api_key_auth(
raise Exception("Invalid proxy server token passed")
if valid_token_dict is not None:
return _return_user_api_key_auth_obj(
user_id_information=user_id_information,
user_obj=user_obj,
api_key=api_key,
parent_otel_span=parent_otel_span,
valid_token_dict=valid_token_dict,
@ -1132,17 +1112,16 @@ async def user_api_key_auth(
def _return_user_api_key_auth_obj(
user_id_information: Optional[list],
user_obj: Optional[LiteLLM_UserTable],
api_key: str,
parent_otel_span: Optional[Span],
valid_token_dict: dict,
route: str,
) -> UserAPIKeyAuth:
retrieved_user_role = (
_get_user_role(user_id_information=user_id_information)
or LitellmUserRoles.INTERNAL_USER
_get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
)
if user_id_information is not None and _is_user_proxy_admin(user_id_information):
if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
return UserAPIKeyAuth(
api_key=api_key,
user_role=LitellmUserRoles.PROXY_ADMIN,
@ -1183,30 +1162,19 @@ def _has_user_setup_sso():
return sso_setup
def _is_user_proxy_admin(user_id_information: Optional[list]):
if user_id_information is None:
def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]):
if user_obj is None:
return False
if len(user_id_information) == 0 or user_id_information[0] is None:
return False
_user = user_id_information[0]
if (
_user.get("user_role", None) is not None
and _user.get("user_role") == LitellmUserRoles.PROXY_ADMIN.value
user_obj.user_role is not None
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
):
return True
# if user_id_information contains litellm-proxy-budget
# get first user_id that is not litellm-proxy-budget
for user in user_id_information:
if user.get("user_id") != "litellm-proxy-budget":
_user = user
break
if (
_user.get("user_role", None) is not None
and _user.get("user_role") == LitellmUserRoles.PROXY_ADMIN.value
user_obj.user_role is not None
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
):
return True
@ -1214,29 +1182,20 @@ def _is_user_proxy_admin(user_id_information: Optional[list]):
def _get_user_role(
user_id_information: Optional[list],
) -> Optional[
Literal[
LitellmUserRoles.PROXY_ADMIN,
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
LitellmUserRoles.INTERNAL_USER,
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
LitellmUserRoles.TEAM,
LitellmUserRoles.CUSTOMER,
]
]:
if user_id_information is None:
user_obj: Optional[LiteLLM_UserTable],
) -> Optional[LitellmUserRoles]:
if user_obj is None:
return None
if len(user_id_information) == 0 or user_id_information[0] is None:
return None
_user = user_obj
_user = user_id_information[0]
_user_role = _user.user_role
try:
role = LitellmUserRoles(_user_role)
except ValueError:
return LitellmUserRoles.INTERNAL_USER
_user_role = _user.get("user_role")
if _user_role in list(LitellmUserRoles.__annotations__.keys()):
return _user_role
return LitellmUserRoles.INTERNAL_USER
return role
def _check_valid_ip(allowed_ips: Optional[List[str]], request: Request) -> bool: