From a4b6a959d807ee50dfa1cf7a05e41a69e438d8de Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 30 May 2024 14:28:53 -0700 Subject: [PATCH] fix literal usage --- litellm/proxy/_types.py | 34 ++++++++--- litellm/proxy/auth/auth_checks.py | 11 +++- litellm/proxy/proxy_server.py | 72 +++++++++++------------ litellm/tests/test_key_generate_prisma.py | 30 +++++----- 4 files changed, 82 insertions(+), 65 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 1806fabc77..984e65d694 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -522,7 +522,16 @@ class LiteLLM_ModelTable(LiteLLMBase): class NewUserRequest(GenerateKeyRequest): max_budget: Optional[float] = None user_email: Optional[str] = None - user_role: Optional[str] = None + user_role: Optional[ + Literal[ + LitellmUserRoles.PROXY_ADMIN, + LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, + LitellmUserRoles.INTERNAL_USER, + LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, + LitellmUserRoles.TEAM, + LitellmUserRoles.CUSTOMER, + ] + ] = None teams: Optional[list] = None organization_id: Optional[str] = None auto_create_key: bool = ( @@ -541,7 +550,16 @@ class UpdateUserRequest(GenerateRequestBase): user_email: Optional[str] = None spend: Optional[float] = None metadata: Optional[dict] = None - user_role: Optional[str] = None + user_role: Optional[ + Literal[ + LitellmUserRoles.PROXY_ADMIN, + LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, + LitellmUserRoles.INTERNAL_USER, + LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, + LitellmUserRoles.TEAM, + LitellmUserRoles.CUSTOMER, + ] + ] = None max_budget: Optional[float] = None @root_validator(pre=True) @@ -1088,12 +1106,12 @@ class UserAPIKeyAuth( api_key: Optional[str] = None user_role: Optional[ Literal[ - "proxy_admin", - "proxy_admin_view_only", - "internal_user", - "internal_user_view_only", - "team", - "customer", + LitellmUserRoles.PROXY_ADMIN, + LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, + LitellmUserRoles.INTERNAL_USER, + LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, + LitellmUserRoles.TEAM, + LitellmUserRoles.CUSTOMER, ] ] = None allowed_model_region: Optional[Literal["eu"]] = None diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index a6e97960e5..e4b8e6c8a8 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -15,6 +15,7 @@ from litellm.proxy._types import ( LiteLLM_TeamTable, LiteLLMRoutes, LiteLLM_OrganizationTable, + LitellmUserRoles, ) from typing import Optional, Literal, Union from litellm.proxy.utils import PrismaClient @@ -133,7 +134,11 @@ def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool: def allowed_routes_check( - user_role: Literal["proxy_admin", "team", "user"], + user_role: Literal[ + LitellmUserRoles.PROXY_ADMIN, + LitellmUserRoles.TEAM, + LitellmUserRoles.INTERNAL_USER, + ], user_route: str, litellm_proxy_roles: LiteLLM_JWTAuth, ) -> bool: @@ -141,14 +146,14 @@ def allowed_routes_check( Check if user -> not admin - allowed to access these routes """ - if user_role == "proxy_admin": + if user_role == LitellmUserRoles.PROXY_ADMIN: is_allowed = _allowed_routes_check( user_route=user_route, allowed_routes=litellm_proxy_roles.admin_allowed_routes, ) return is_allowed - elif user_role == "team": + elif user_role == LitellmUserRoles.TEAM: if litellm_proxy_roles.team_allowed_routes is None: """ By default allow a team to call openai + info routes diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 629ab3dd3b..f3a9b56924 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -507,9 +507,7 @@ async def user_api_key_auth( if route in LiteLLMRoutes.public_routes.value: # check if public endpoint - return UserAPIKeyAuth( - user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value - ) + return UserAPIKeyAuth(user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY) if general_settings.get("enable_jwt_auth", False) == True: is_jwt = jwt_handler.is_jwt(token=api_key) @@ -526,14 +524,12 @@ async def user_api_key_auth( if is_admin: # check allowed admin routes is_allowed = allowed_routes_check( - user_role=LitellmUserRoles.PROXY_ADMIN.value, + user_role=LitellmUserRoles.PROXY_ADMIN, user_route=route, litellm_proxy_roles=jwt_handler.litellm_jwtauth, ) if is_allowed: - return UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN.value - ) + return UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN) else: allowed_routes = ( jwt_handler.litellm_jwtauth.admin_allowed_routes @@ -556,7 +552,7 @@ async def user_api_key_auth( if team_id is not None: # check allowed team routes is_allowed = allowed_routes_check( - user_role="team", + user_role=LitellmUserRoles.TEAM, user_route=route, litellm_proxy_roles=jwt_handler.litellm_jwtauth, ) @@ -668,7 +664,7 @@ async def user_api_key_auth( team_object.rpm_limit if team_object is not None else None ), team_models=team_object.models if team_object is not None else [], - user_role=LitellmUserRoles.INTERNAL_USER.value, + user_role=LitellmUserRoles.INTERNAL_USER, user_id=user_id, org_id=org_id, ) @@ -676,10 +672,10 @@ async def user_api_key_auth( if master_key is None: if isinstance(api_key, str): return UserAPIKeyAuth( - api_key=api_key, user_role=LitellmUserRoles.PROXY_ADMIN.value + api_key=api_key, user_role=LitellmUserRoles.PROXY_ADMIN ) else: - return UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value) + return UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN) elif api_key is None: # only require api key if master key is set raise Exception("No api key passed in.") elif api_key == "": @@ -746,7 +742,7 @@ async def user_api_key_auth( if ( valid_token is not None and isinstance(valid_token, UserAPIKeyAuth) - and valid_token.user_role == LitellmUserRoles.PROXY_ADMIN.value + and valid_token.user_role == LitellmUserRoles.PROXY_ADMIN ): # update end-user params on valid token valid_token.end_user_id = end_user_params.get("end_user_id") @@ -779,7 +775,7 @@ async def user_api_key_auth( if is_master_key_valid: _user_api_key_obj = UserAPIKeyAuth( api_key=master_key, - user_role=LitellmUserRoles.PROXY_ADMIN.value, + user_role=LitellmUserRoles.PROXY_ADMIN, user_id=litellm_proxy_admin_name, **end_user_params, ) @@ -1384,7 +1380,7 @@ async def user_api_key_auth( ): return UserAPIKeyAuth( api_key=api_key, - user_role=LitellmUserRoles.PROXY_ADMIN.value, + user_role=LitellmUserRoles.PROXY_ADMIN, **valid_token_dict, ) elif ( @@ -1407,19 +1403,19 @@ async def user_api_key_auth( ): return UserAPIKeyAuth( api_key=api_key, - user_role=LitellmUserRoles.PROXY_ADMIN.value, + user_role=LitellmUserRoles.PROXY_ADMIN, **valid_token_dict, ) elif _has_user_setup_sso() and route in LiteLLMRoutes.sso_only_routes.value: return UserAPIKeyAuth( api_key=api_key, - user_role=LitellmUserRoles.INTERNAL_USER.value, + user_role=LitellmUserRoles.INTERNAL_USER, **valid_token_dict, ) else: return UserAPIKeyAuth( api_key=api_key, - user_role=LitellmUserRoles.INTERNAL_USER.value, + user_role=LitellmUserRoles.INTERNAL_USER, **valid_token_dict, ) else: @@ -3752,9 +3748,9 @@ async def startup_event(): spend=0, token=master_key, user_id=litellm_proxy_admin_name, - user_role=LitellmUserRoles.PROXY_ADMIN.value, + user_role=LitellmUserRoles.PROXY_ADMIN, query_type="update_data", - update_key_values={"user_role": LitellmUserRoles.PROXY_ADMIN.value}, + update_key_values={"user_role": LitellmUserRoles.PROXY_ADMIN}, ) ) @@ -6105,7 +6101,7 @@ async def delete_key_fn( ) if ( user_api_key_dict.user_role is not None - and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value + and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN ): user_id = None # unless they're admin @@ -7902,7 +7898,7 @@ async def user_info( # *NEW* get all teams in user 'teams' field if ( getattr(caller_user_info, "user_role", None) - == LitellmUserRoles.PROXY_ADMIN.value + == LitellmUserRoles.PROXY_ADMIN ): teams_2 = await prisma_client.get_data( table_name="team", @@ -8731,7 +8727,7 @@ async def new_team( if ( user_api_key_dict.user_role is None - or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value + or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN ): # don't restrict proxy admin if ( data.tpm_limit is not None @@ -9337,7 +9333,7 @@ async def list_team( """ global prisma_client - if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value: + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: raise HTTPException( status_code=401, detail={ @@ -9431,7 +9427,7 @@ async def new_organization( if ( user_api_key_dict.user_role is None - or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value + or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN ): raise HTTPException( status_code=401, @@ -9634,7 +9630,7 @@ async def budget_settings( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value: + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: raise HTTPException( status_code=400, detail={ @@ -9699,7 +9695,7 @@ async def list_budget( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value: + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: raise HTTPException( status_code=400, detail={ @@ -9733,7 +9729,7 @@ async def delete_budget( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value: + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: raise HTTPException( status_code=400, detail={ @@ -10711,7 +10707,7 @@ async def alerting_settings( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value: + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: raise HTTPException( status_code=400, detail={ @@ -10792,7 +10788,7 @@ async def alerting_settings( # detail={"error": CommonProxyErrors.db_not_connected_error.value}, # ) -# if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value: +# if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: # raise HTTPException( # status_code=400, # detail={"error": CommonProxyErrors.not_allowed_access.value}, @@ -11250,12 +11246,12 @@ async def login(request: Request): await user_update( data=UpdateUserRequest( user_id=key_user_id, - user_role=LitellmUserRoles.PROXY_ADMIN.value, + user_role=LitellmUserRoles.PROXY_ADMIN, ) ) if os.getenv("DATABASE_URL") is not None: response = await generate_key_helper_fn( - **{"user_role": LitellmUserRoles.PROXY_ADMIN.value, "duration": "2hr", "key_max_budget": 5, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"} # type: ignore + **{"user_role": LitellmUserRoles.PROXY_ADMIN, "duration": "2hr", "key_max_budget": 5, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"} # type: ignore ) else: raise ProxyException( @@ -11650,7 +11646,7 @@ async def new_invitation( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value: + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: raise HTTPException( status_code=400, detail={ @@ -11714,7 +11710,7 @@ async def invitation_info( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value: + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: raise HTTPException( status_code=400, detail={ @@ -11826,7 +11822,7 @@ async def invitation_delete( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value: + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: raise HTTPException( status_code=400, detail={ @@ -12021,7 +12017,7 @@ async def update_config_general_settings( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value: + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: raise HTTPException( status_code=400, detail={"error": CommonProxyErrors.not_allowed_access.value}, @@ -12095,7 +12091,7 @@ async def get_config_general_settings( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value: + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: raise HTTPException( status_code=400, detail={"error": CommonProxyErrors.not_allowed_access.value}, @@ -12158,7 +12154,7 @@ async def get_config_list( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value: + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: raise HTTPException( status_code=400, detail={ @@ -12233,7 +12229,7 @@ async def delete_config_general_settings( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value: + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: raise HTTPException( status_code=400, detail={ diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 7ceab3663d..6be720ffd1 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -138,7 +138,7 @@ async def test_new_user_response(prisma_client): team_id=_team_id, ), user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN.value, + user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234", ), @@ -367,9 +367,7 @@ async def test_call_with_valid_model_using_all_models(prisma_client): new_team_response = await new_team( data=team_request, - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN.value - ), + user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), ) print("new_team_response", new_team_response) created_team_id = new_team_response["team_id"] @@ -928,7 +926,7 @@ def test_delete_key(prisma_client): # use generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print(f"result: {result}") - result.user_role = LitellmUserRoles.PROXY_ADMIN.value + result.user_role = LitellmUserRoles.PROXY_ADMIN # delete the key result_delete_key = await delete_key_fn( data=delete_key_request, user_api_key_dict=result @@ -978,7 +976,7 @@ def test_delete_key_auth(prisma_client): # use generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print(f"result: {result}") - result.user_role = LitellmUserRoles.PROXY_ADMIN.value + result.user_role = LitellmUserRoles.PROXY_ADMIN result_delete_key = await delete_key_fn( data=delete_key_request, user_api_key_dict=result @@ -1050,7 +1048,7 @@ def test_generate_and_call_key_info(prisma_client): # use generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print(f"result: {result}") - result.user_role = LitellmUserRoles.PROXY_ADMIN.value + result.user_role = LitellmUserRoles.PROXY_ADMIN result_delete_key = await delete_key_fn( data=delete_key_request, user_api_key_dict=result @@ -1084,7 +1082,7 @@ def test_generate_and_update_key(prisma_client): team_id=_team_1, ), user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN.value, + user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234", ), @@ -1096,7 +1094,7 @@ def test_generate_and_update_key(prisma_client): team_id=_team_2, ), user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN.value, + user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234", ), @@ -1168,7 +1166,7 @@ def test_generate_and_update_key(prisma_client): # use generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print(f"result: {result}") - result.user_role = LitellmUserRoles.PROXY_ADMIN.value + result.user_role = LitellmUserRoles.PROXY_ADMIN result_delete_key = await delete_key_fn( data=delete_key_request, user_api_key_dict=result @@ -2048,7 +2046,7 @@ async def test_master_key_hashing(prisma_client): await new_team( NewTeamRequest(team_id=_team_id), user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN.value, + user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234", ), @@ -2088,7 +2086,7 @@ async def test_reset_spend_authentication(prisma_client): """ 1. Test master key can access this route -> ONLY MASTER KEY SHOULD BE ABLE TO RESET SPEND 2. Test that non-master key gets rejected - 3. Test that non-master key with role == LitellmUserRoles.PROXY_ADMIN.value or admin gets rejected + 3. Test that non-master key with role == LitellmUserRoles.PROXY_ADMIN or admin gets rejected """ print("prisma client=", prisma_client) @@ -2133,10 +2131,10 @@ async def test_reset_spend_authentication(prisma_client): in e.message ) - # Test 3 - Non-Master Key with role == LitellmUserRoles.PROXY_ADMIN.value or admin + # Test 3 - Non-Master Key with role == LitellmUserRoles.PROXY_ADMIN or admin _response = await new_user( data=NewUserRequest( - user_role=LitellmUserRoles.PROXY_ADMIN.value, + user_role=LitellmUserRoles.PROXY_ADMIN, tpm_limit=20, ) ) @@ -2186,7 +2184,7 @@ async def test_create_update_team(prisma_client): rpm_limit=20, ), user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN.value, + user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234", ), @@ -2214,7 +2212,7 @@ async def test_create_update_team(prisma_client): rpm_limit=30, ), user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN.value, + user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234", ),