fix literal usage

This commit is contained in:
Ishaan Jaff 2024-05-30 14:28:53 -07:00
parent 4861ff2fd4
commit a4b6a959d8
4 changed files with 82 additions and 65 deletions

View file

@ -522,7 +522,16 @@ class LiteLLM_ModelTable(LiteLLMBase):
class NewUserRequest(GenerateKeyRequest):
max_budget: Optional[float] = None
user_email: Optional[str] = None
user_role: Optional[str] = None
user_role: Optional[
Literal[
LitellmUserRoles.PROXY_ADMIN,
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
LitellmUserRoles.INTERNAL_USER,
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
LitellmUserRoles.TEAM,
LitellmUserRoles.CUSTOMER,
]
] = None
teams: Optional[list] = None
organization_id: Optional[str] = None
auto_create_key: bool = (
@ -541,7 +550,16 @@ class UpdateUserRequest(GenerateRequestBase):
user_email: Optional[str] = None
spend: Optional[float] = None
metadata: Optional[dict] = None
user_role: Optional[str] = None
user_role: Optional[
Literal[
LitellmUserRoles.PROXY_ADMIN,
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
LitellmUserRoles.INTERNAL_USER,
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
LitellmUserRoles.TEAM,
LitellmUserRoles.CUSTOMER,
]
] = None
max_budget: Optional[float] = None
@root_validator(pre=True)
@ -1088,12 +1106,12 @@ class UserAPIKeyAuth(
api_key: Optional[str] = None
user_role: Optional[
Literal[
"proxy_admin",
"proxy_admin_view_only",
"internal_user",
"internal_user_view_only",
"team",
"customer",
LitellmUserRoles.PROXY_ADMIN,
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
LitellmUserRoles.INTERNAL_USER,
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
LitellmUserRoles.TEAM,
LitellmUserRoles.CUSTOMER,
]
] = None
allowed_model_region: Optional[Literal["eu"]] = None

View file

@ -15,6 +15,7 @@ from litellm.proxy._types import (
LiteLLM_TeamTable,
LiteLLMRoutes,
LiteLLM_OrganizationTable,
LitellmUserRoles,
)
from typing import Optional, Literal, Union
from litellm.proxy.utils import PrismaClient
@ -133,7 +134,11 @@ def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
def allowed_routes_check(
user_role: Literal["proxy_admin", "team", "user"],
user_role: Literal[
LitellmUserRoles.PROXY_ADMIN,
LitellmUserRoles.TEAM,
LitellmUserRoles.INTERNAL_USER,
],
user_route: str,
litellm_proxy_roles: LiteLLM_JWTAuth,
) -> bool:
@ -141,14 +146,14 @@ def allowed_routes_check(
Check if user -> not admin - allowed to access these routes
"""
if user_role == "proxy_admin":
if user_role == LitellmUserRoles.PROXY_ADMIN:
is_allowed = _allowed_routes_check(
user_route=user_route,
allowed_routes=litellm_proxy_roles.admin_allowed_routes,
)
return is_allowed
elif user_role == "team":
elif user_role == LitellmUserRoles.TEAM:
if litellm_proxy_roles.team_allowed_routes is None:
"""
By default allow a team to call openai + info routes

View file

@ -507,9 +507,7 @@ async def user_api_key_auth(
if route in LiteLLMRoutes.public_routes.value:
# check if public endpoint
return UserAPIKeyAuth(
user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
)
return UserAPIKeyAuth(user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY)
if general_settings.get("enable_jwt_auth", False) == True:
is_jwt = jwt_handler.is_jwt(token=api_key)
@ -526,14 +524,12 @@ async def user_api_key_auth(
if is_admin:
# check allowed admin routes
is_allowed = allowed_routes_check(
user_role=LitellmUserRoles.PROXY_ADMIN.value,
user_role=LitellmUserRoles.PROXY_ADMIN,
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed:
return UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN.value
)
return UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN)
else:
allowed_routes = (
jwt_handler.litellm_jwtauth.admin_allowed_routes
@ -556,7 +552,7 @@ async def user_api_key_auth(
if team_id is not None:
# check allowed team routes
is_allowed = allowed_routes_check(
user_role="team",
user_role=LitellmUserRoles.TEAM,
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
@ -668,7 +664,7 @@ async def user_api_key_auth(
team_object.rpm_limit if team_object is not None else None
),
team_models=team_object.models if team_object is not None else [],
user_role=LitellmUserRoles.INTERNAL_USER.value,
user_role=LitellmUserRoles.INTERNAL_USER,
user_id=user_id,
org_id=org_id,
)
@ -676,10 +672,10 @@ async def user_api_key_auth(
if master_key is None:
if isinstance(api_key, str):
return UserAPIKeyAuth(
api_key=api_key, user_role=LitellmUserRoles.PROXY_ADMIN.value
api_key=api_key, user_role=LitellmUserRoles.PROXY_ADMIN
)
else:
return UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value)
return UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN)
elif api_key is None: # only require api key if master key is set
raise Exception("No api key passed in.")
elif api_key == "":
@ -746,7 +742,7 @@ async def user_api_key_auth(
if (
valid_token is not None
and isinstance(valid_token, UserAPIKeyAuth)
and valid_token.user_role == LitellmUserRoles.PROXY_ADMIN.value
and valid_token.user_role == LitellmUserRoles.PROXY_ADMIN
):
# update end-user params on valid token
valid_token.end_user_id = end_user_params.get("end_user_id")
@ -779,7 +775,7 @@ async def user_api_key_auth(
if is_master_key_valid:
_user_api_key_obj = UserAPIKeyAuth(
api_key=master_key,
user_role=LitellmUserRoles.PROXY_ADMIN.value,
user_role=LitellmUserRoles.PROXY_ADMIN,
user_id=litellm_proxy_admin_name,
**end_user_params,
)
@ -1384,7 +1380,7 @@ async def user_api_key_auth(
):
return UserAPIKeyAuth(
api_key=api_key,
user_role=LitellmUserRoles.PROXY_ADMIN.value,
user_role=LitellmUserRoles.PROXY_ADMIN,
**valid_token_dict,
)
elif (
@ -1407,19 +1403,19 @@ async def user_api_key_auth(
):
return UserAPIKeyAuth(
api_key=api_key,
user_role=LitellmUserRoles.PROXY_ADMIN.value,
user_role=LitellmUserRoles.PROXY_ADMIN,
**valid_token_dict,
)
elif _has_user_setup_sso() and route in LiteLLMRoutes.sso_only_routes.value:
return UserAPIKeyAuth(
api_key=api_key,
user_role=LitellmUserRoles.INTERNAL_USER.value,
user_role=LitellmUserRoles.INTERNAL_USER,
**valid_token_dict,
)
else:
return UserAPIKeyAuth(
api_key=api_key,
user_role=LitellmUserRoles.INTERNAL_USER.value,
user_role=LitellmUserRoles.INTERNAL_USER,
**valid_token_dict,
)
else:
@ -3752,9 +3748,9 @@ async def startup_event():
spend=0,
token=master_key,
user_id=litellm_proxy_admin_name,
user_role=LitellmUserRoles.PROXY_ADMIN.value,
user_role=LitellmUserRoles.PROXY_ADMIN,
query_type="update_data",
update_key_values={"user_role": LitellmUserRoles.PROXY_ADMIN.value},
update_key_values={"user_role": LitellmUserRoles.PROXY_ADMIN},
)
)
@ -6105,7 +6101,7 @@ async def delete_key_fn(
)
if (
user_api_key_dict.user_role is not None
and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
):
user_id = None # unless they're admin
@ -7902,7 +7898,7 @@ async def user_info(
# *NEW* get all teams in user 'teams' field
if (
getattr(caller_user_info, "user_role", None)
== LitellmUserRoles.PROXY_ADMIN.value
== LitellmUserRoles.PROXY_ADMIN
):
teams_2 = await prisma_client.get_data(
table_name="team",
@ -8731,7 +8727,7 @@ async def new_team(
if (
user_api_key_dict.user_role is None
or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN
): # don't restrict proxy admin
if (
data.tpm_limit is not None
@ -9337,7 +9333,7 @@ async def list_team(
"""
global prisma_client
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value:
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
raise HTTPException(
status_code=401,
detail={
@ -9431,7 +9427,7 @@ async def new_organization(
if (
user_api_key_dict.user_role is None
or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN
):
raise HTTPException(
status_code=401,
@ -9634,7 +9630,7 @@ async def budget_settings(
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value:
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
raise HTTPException(
status_code=400,
detail={
@ -9699,7 +9695,7 @@ async def list_budget(
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value:
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
raise HTTPException(
status_code=400,
detail={
@ -9733,7 +9729,7 @@ async def delete_budget(
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value:
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
raise HTTPException(
status_code=400,
detail={
@ -10711,7 +10707,7 @@ async def alerting_settings(
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value:
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
raise HTTPException(
status_code=400,
detail={
@ -10792,7 +10788,7 @@ async def alerting_settings(
# detail={"error": CommonProxyErrors.db_not_connected_error.value},
# )
# if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value:
# if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
# raise HTTPException(
# status_code=400,
# detail={"error": CommonProxyErrors.not_allowed_access.value},
@ -11250,12 +11246,12 @@ async def login(request: Request):
await user_update(
data=UpdateUserRequest(
user_id=key_user_id,
user_role=LitellmUserRoles.PROXY_ADMIN.value,
user_role=LitellmUserRoles.PROXY_ADMIN,
)
)
if os.getenv("DATABASE_URL") is not None:
response = await generate_key_helper_fn(
**{"user_role": LitellmUserRoles.PROXY_ADMIN.value, "duration": "2hr", "key_max_budget": 5, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"} # type: ignore
**{"user_role": LitellmUserRoles.PROXY_ADMIN, "duration": "2hr", "key_max_budget": 5, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"} # type: ignore
)
else:
raise ProxyException(
@ -11650,7 +11646,7 @@ async def new_invitation(
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value:
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
raise HTTPException(
status_code=400,
detail={
@ -11714,7 +11710,7 @@ async def invitation_info(
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value:
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
raise HTTPException(
status_code=400,
detail={
@ -11826,7 +11822,7 @@ async def invitation_delete(
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value:
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
raise HTTPException(
status_code=400,
detail={
@ -12021,7 +12017,7 @@ async def update_config_general_settings(
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value:
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
raise HTTPException(
status_code=400,
detail={"error": CommonProxyErrors.not_allowed_access.value},
@ -12095,7 +12091,7 @@ async def get_config_general_settings(
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value:
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
raise HTTPException(
status_code=400,
detail={"error": CommonProxyErrors.not_allowed_access.value},
@ -12158,7 +12154,7 @@ async def get_config_list(
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value:
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
raise HTTPException(
status_code=400,
detail={
@ -12233,7 +12229,7 @@ async def delete_config_general_settings(
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value:
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
raise HTTPException(
status_code=400,
detail={

View file

@ -138,7 +138,7 @@ async def test_new_user_response(prisma_client):
team_id=_team_id,
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN.value,
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
@ -367,9 +367,7 @@ async def test_call_with_valid_model_using_all_models(prisma_client):
new_team_response = await new_team(
data=team_request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN.value
),
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
)
print("new_team_response", new_team_response)
created_team_id = new_team_response["team_id"]
@ -928,7 +926,7 @@ def test_delete_key(prisma_client):
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print(f"result: {result}")
result.user_role = LitellmUserRoles.PROXY_ADMIN.value
result.user_role = LitellmUserRoles.PROXY_ADMIN
# delete the key
result_delete_key = await delete_key_fn(
data=delete_key_request, user_api_key_dict=result
@ -978,7 +976,7 @@ def test_delete_key_auth(prisma_client):
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print(f"result: {result}")
result.user_role = LitellmUserRoles.PROXY_ADMIN.value
result.user_role = LitellmUserRoles.PROXY_ADMIN
result_delete_key = await delete_key_fn(
data=delete_key_request, user_api_key_dict=result
@ -1050,7 +1048,7 @@ def test_generate_and_call_key_info(prisma_client):
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print(f"result: {result}")
result.user_role = LitellmUserRoles.PROXY_ADMIN.value
result.user_role = LitellmUserRoles.PROXY_ADMIN
result_delete_key = await delete_key_fn(
data=delete_key_request, user_api_key_dict=result
@ -1084,7 +1082,7 @@ def test_generate_and_update_key(prisma_client):
team_id=_team_1,
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN.value,
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
@ -1096,7 +1094,7 @@ def test_generate_and_update_key(prisma_client):
team_id=_team_2,
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN.value,
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
@ -1168,7 +1166,7 @@ def test_generate_and_update_key(prisma_client):
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print(f"result: {result}")
result.user_role = LitellmUserRoles.PROXY_ADMIN.value
result.user_role = LitellmUserRoles.PROXY_ADMIN
result_delete_key = await delete_key_fn(
data=delete_key_request, user_api_key_dict=result
@ -2048,7 +2046,7 @@ async def test_master_key_hashing(prisma_client):
await new_team(
NewTeamRequest(team_id=_team_id),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN.value,
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
@ -2088,7 +2086,7 @@ async def test_reset_spend_authentication(prisma_client):
"""
1. Test master key can access this route -> ONLY MASTER KEY SHOULD BE ABLE TO RESET SPEND
2. Test that non-master key gets rejected
3. Test that non-master key with role == LitellmUserRoles.PROXY_ADMIN.value or admin gets rejected
3. Test that non-master key with role == LitellmUserRoles.PROXY_ADMIN or admin gets rejected
"""
print("prisma client=", prisma_client)
@ -2133,10 +2131,10 @@ async def test_reset_spend_authentication(prisma_client):
in e.message
)
# Test 3 - Non-Master Key with role == LitellmUserRoles.PROXY_ADMIN.value or admin
# Test 3 - Non-Master Key with role == LitellmUserRoles.PROXY_ADMIN or admin
_response = await new_user(
data=NewUserRequest(
user_role=LitellmUserRoles.PROXY_ADMIN.value,
user_role=LitellmUserRoles.PROXY_ADMIN,
tpm_limit=20,
)
)
@ -2186,7 +2184,7 @@ async def test_create_update_team(prisma_client):
rpm_limit=20,
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN.value,
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
@ -2214,7 +2212,7 @@ async def test_create_update_team(prisma_client):
rpm_limit=30,
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN.value,
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),