mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
refactor(user_api_key_auth.py): refactor to replace user_id_information list with pydantic user_obj
Allows using the 'get_user_object' function in user_api_key_auth, keeping it consistent across jwt-auth and key-auth
This commit is contained in:
parent
228da08b81
commit
86a3dba1bf
4 changed files with 42 additions and 79 deletions
|
@ -1338,6 +1338,7 @@ class LiteLLM_UserTable(LiteLLMBase):
|
||||||
models: list = []
|
models: list = []
|
||||||
tpm_limit: Optional[int] = None
|
tpm_limit: Optional[int] = None
|
||||||
rpm_limit: Optional[int] = None
|
rpm_limit: Optional[int] = None
|
||||||
|
user_role: Optional[str] = None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -581,7 +581,6 @@ async def user_api_key_auth(
|
||||||
"allowed_model_region"
|
"allowed_model_region"
|
||||||
)
|
)
|
||||||
|
|
||||||
user_id_information: Optional[List] = None
|
|
||||||
if valid_token is not None:
|
if valid_token is not None:
|
||||||
user_obj: Optional[LiteLLM_UserTable] = None
|
user_obj: Optional[LiteLLM_UserTable] = None
|
||||||
# Got Valid Token from Cache, DB
|
# Got Valid Token from Cache, DB
|
||||||
|
@ -661,15 +660,6 @@ async def user_api_key_auth(
|
||||||
parent_otel_span=parent_otel_span,
|
parent_otel_span=parent_otel_span,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
)
|
)
|
||||||
# value = user_api_key_cache.get_cache(key=id)
|
|
||||||
if user_obj is not None:
|
|
||||||
if user_id_information is None:
|
|
||||||
user_id_information = []
|
|
||||||
user_id_information.append(user_obj.model_dump())
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"user_id_information: {user_id_information}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check 3. Check if user is in their team budget
|
# Check 3. Check if user is in their team budget
|
||||||
if valid_token.team_member_spend is not None:
|
if valid_token.team_member_spend is not None:
|
||||||
|
@ -742,11 +732,8 @@ async def user_api_key_auth(
|
||||||
|
|
||||||
user_email: Optional[str] = None
|
user_email: Optional[str] = None
|
||||||
# Check if the token has any user id information
|
# Check if the token has any user id information
|
||||||
if user_id_information is not None and len(user_id_information) > 0:
|
if user_obj is not None:
|
||||||
specific_user_id_information = user_id_information[0]
|
user_email = user_obj.user_email
|
||||||
_user_email = specific_user_id_information.get("user_email", None)
|
|
||||||
if _user_email is not None:
|
|
||||||
user_email = str(_user_email)
|
|
||||||
|
|
||||||
call_info = CallInfo(
|
call_info = CallInfo(
|
||||||
token=valid_token.token,
|
token=valid_token.token,
|
||||||
|
@ -920,9 +907,9 @@ async def user_api_key_auth(
|
||||||
if _end_user_object is not None:
|
if _end_user_object is not None:
|
||||||
valid_token_dict.update(end_user_params)
|
valid_token_dict.update(end_user_params)
|
||||||
|
|
||||||
_user_role = _get_user_role(user_id_information=user_id_information)
|
_user_role = _get_user_role(user_obj=user_obj)
|
||||||
|
|
||||||
if not _is_user_proxy_admin(user_id_information): # if non-admin
|
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
|
||||||
if is_llm_api_route(route=route):
|
if is_llm_api_route(route=route):
|
||||||
pass
|
pass
|
||||||
elif is_llm_api_route(route=request["route"].name):
|
elif is_llm_api_route(route=request["route"].name):
|
||||||
|
@ -1004,14 +991,9 @@ async def user_api_key_auth(
|
||||||
else:
|
else:
|
||||||
user_role = "unknown"
|
user_role = "unknown"
|
||||||
user_id = "unknown"
|
user_id = "unknown"
|
||||||
if (
|
if user_obj is not None:
|
||||||
user_id_information is not None
|
user_role = user_obj.user_role or "unknown"
|
||||||
and isinstance(user_id_information, list)
|
user_id = user_obj.user_id or "unknown"
|
||||||
and len(user_id_information) > 0
|
|
||||||
):
|
|
||||||
_user = user_id_information[0]
|
|
||||||
user_role = _user.get("user_role", "unknown")
|
|
||||||
user_id = _user.get("user_id", "unknown")
|
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Only proxy admin can be used to generate, delete, update info for new keys/users/teams. Route={route}. Your role={user_role}. Your user_id={user_id}"
|
f"Only proxy admin can be used to generate, delete, update info for new keys/users/teams. Route={route}. Your role={user_role}. Your user_id={user_id}"
|
||||||
)
|
)
|
||||||
|
@ -1057,9 +1039,7 @@ async def user_api_key_auth(
|
||||||
# Do something if the current route starts with any of the allowed routes
|
# Do something if the current route starts with any of the allowed routes
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if user_id_information is not None and _is_user_proxy_admin(
|
if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
|
||||||
user_id_information
|
|
||||||
):
|
|
||||||
return UserAPIKeyAuth(
|
return UserAPIKeyAuth(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
@ -1085,7 +1065,7 @@ async def user_api_key_auth(
|
||||||
raise Exception("Invalid proxy server token passed")
|
raise Exception("Invalid proxy server token passed")
|
||||||
if valid_token_dict is not None:
|
if valid_token_dict is not None:
|
||||||
return _return_user_api_key_auth_obj(
|
return _return_user_api_key_auth_obj(
|
||||||
user_id_information=user_id_information,
|
user_obj=user_obj,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
parent_otel_span=parent_otel_span,
|
parent_otel_span=parent_otel_span,
|
||||||
valid_token_dict=valid_token_dict,
|
valid_token_dict=valid_token_dict,
|
||||||
|
@ -1132,17 +1112,16 @@ async def user_api_key_auth(
|
||||||
|
|
||||||
|
|
||||||
def _return_user_api_key_auth_obj(
|
def _return_user_api_key_auth_obj(
|
||||||
user_id_information: Optional[list],
|
user_obj: Optional[LiteLLM_UserTable],
|
||||||
api_key: str,
|
api_key: str,
|
||||||
parent_otel_span: Optional[Span],
|
parent_otel_span: Optional[Span],
|
||||||
valid_token_dict: dict,
|
valid_token_dict: dict,
|
||||||
route: str,
|
route: str,
|
||||||
) -> UserAPIKeyAuth:
|
) -> UserAPIKeyAuth:
|
||||||
retrieved_user_role = (
|
retrieved_user_role = (
|
||||||
_get_user_role(user_id_information=user_id_information)
|
_get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
|
||||||
or LitellmUserRoles.INTERNAL_USER
|
|
||||||
)
|
)
|
||||||
if user_id_information is not None and _is_user_proxy_admin(user_id_information):
|
if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
|
||||||
return UserAPIKeyAuth(
|
return UserAPIKeyAuth(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
@ -1183,30 +1162,19 @@ def _has_user_setup_sso():
|
||||||
return sso_setup
|
return sso_setup
|
||||||
|
|
||||||
|
|
||||||
def _is_user_proxy_admin(user_id_information: Optional[list]):
|
def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]):
|
||||||
if user_id_information is None:
|
if user_obj is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if len(user_id_information) == 0 or user_id_information[0] is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
_user = user_id_information[0]
|
|
||||||
if (
|
if (
|
||||||
_user.get("user_role", None) is not None
|
user_obj.user_role is not None
|
||||||
and _user.get("user_role") == LitellmUserRoles.PROXY_ADMIN.value
|
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# if user_id_information contains litellm-proxy-budget
|
|
||||||
# get first user_id that is not litellm-proxy-budget
|
|
||||||
for user in user_id_information:
|
|
||||||
if user.get("user_id") != "litellm-proxy-budget":
|
|
||||||
_user = user
|
|
||||||
break
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
_user.get("user_role", None) is not None
|
user_obj.user_role is not None
|
||||||
and _user.get("user_role") == LitellmUserRoles.PROXY_ADMIN.value
|
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -1214,29 +1182,20 @@ def _is_user_proxy_admin(user_id_information: Optional[list]):
|
||||||
|
|
||||||
|
|
||||||
def _get_user_role(
|
def _get_user_role(
|
||||||
user_id_information: Optional[list],
|
user_obj: Optional[LiteLLM_UserTable],
|
||||||
) -> Optional[
|
) -> Optional[LitellmUserRoles]:
|
||||||
Literal[
|
if user_obj is None:
|
||||||
LitellmUserRoles.PROXY_ADMIN,
|
|
||||||
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
|
|
||||||
LitellmUserRoles.INTERNAL_USER,
|
|
||||||
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
|
|
||||||
LitellmUserRoles.TEAM,
|
|
||||||
LitellmUserRoles.CUSTOMER,
|
|
||||||
]
|
|
||||||
]:
|
|
||||||
if user_id_information is None:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if len(user_id_information) == 0 or user_id_information[0] is None:
|
_user = user_obj
|
||||||
return None
|
|
||||||
|
|
||||||
_user = user_id_information[0]
|
_user_role = _user.user_role
|
||||||
|
try:
|
||||||
|
role = LitellmUserRoles(_user_role)
|
||||||
|
except ValueError:
|
||||||
|
return LitellmUserRoles.INTERNAL_USER
|
||||||
|
|
||||||
_user_role = _user.get("user_role")
|
return role
|
||||||
if _user_role in list(LitellmUserRoles.__annotations__.keys()):
|
|
||||||
return _user_role
|
|
||||||
return LitellmUserRoles.INTERNAL_USER
|
|
||||||
|
|
||||||
|
|
||||||
def _check_valid_ip(allowed_ips: Optional[List[str]], request: Request) -> bool:
|
def _check_valid_ip(allowed_ips: Optional[List[str]], request: Request) -> bool:
|
||||||
|
|
|
@ -2555,6 +2555,7 @@ async def test_update_user_role(prisma_client):
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
|
print("\n\nMAKING NEW REQUEST WITH UPDATED USER ROLE\n\n")
|
||||||
result = await user_api_key_auth(request=request, api_key=api_key)
|
result = await user_api_key_auth(request=request, api_key=api_key)
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
|
||||||
|
|
|
@ -96,26 +96,28 @@ async def test_check_blocked_team():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"user_role", ["app_user", "internal_user", "proxy_admin_viewer"]
|
"user_role, expected_role",
|
||||||
|
[
|
||||||
|
("app_user", "internal_user"),
|
||||||
|
("internal_user", "internal_user"),
|
||||||
|
("proxy_admin_viewer", "proxy_admin_viewer"),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
def test_returned_user_api_key_auth(user_role):
|
def test_returned_user_api_key_auth(user_role, expected_role):
|
||||||
from litellm.proxy._types import LitellmUserRoles
|
from litellm.proxy._types import LiteLLM_UserTable, LitellmUserRoles
|
||||||
from litellm.proxy.auth.user_api_key_auth import _return_user_api_key_auth_obj
|
from litellm.proxy.auth.user_api_key_auth import _return_user_api_key_auth_obj
|
||||||
|
|
||||||
user_id_information = [{"user_role": user_role}]
|
|
||||||
|
|
||||||
new_obj = _return_user_api_key_auth_obj(
|
new_obj = _return_user_api_key_auth_obj(
|
||||||
user_id_information,
|
user_obj=LiteLLM_UserTable(
|
||||||
|
user_role=user_role, user_id="", max_budget=None, user_email=""
|
||||||
|
),
|
||||||
api_key="hello-world",
|
api_key="hello-world",
|
||||||
parent_otel_span=None,
|
parent_otel_span=None,
|
||||||
valid_token_dict={},
|
valid_token_dict={},
|
||||||
route="/chat/completion",
|
route="/chat/completion",
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_role in list(LitellmUserRoles.__annotations__.keys()):
|
assert new_obj.user_role == expected_role
|
||||||
assert new_obj.user_role == user_role
|
|
||||||
else:
|
|
||||||
assert new_obj.user_role == "internal_user"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("key_ownership", ["user_key", "team_key"])
|
@pytest.mark.parametrize("key_ownership", ["user_key", "team_key"])
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue