From 1ed79462807a52ae4f6a8e44b4d33cfe2e7ddabd Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Fri, 17 Jan 2025 20:16:12 -0800 Subject: [PATCH] `/key/delete` - allow team admin to delete team keys (#7846) * fix(key_management_endpoints.py): fix key delete to allow team admins + other proxy admins to delete keys Fixes https://github.com/BerriAI/litellm/issues/7760 * fix(key_management_endpoints.py): remove unused variables * fix(key_management_endpoints.py): fix linting error --- .../key_management_endpoints.py | 153 +++++++++++++----- litellm/proxy/proxy_server.py | 2 +- 2 files changed, 118 insertions(+), 37 deletions(-) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 1e9edf0e0f..02d8f49a34 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -49,7 +49,7 @@ from litellm.types.utils import ( ) -def _is_team_key(data: GenerateKeyRequest): +def _is_team_key(data: Union[GenerateKeyRequest, LiteLLM_VerificationToken]): return data.team_id is not None @@ -805,23 +805,17 @@ async def delete_key_fn( raise Exception("Not connected to DB!") ## only allow user to delete keys they own - user_id = user_api_key_dict.user_id verbose_proxy_logger.debug( f"user_api_key_dict.user_role: {user_api_key_dict.user_role}" ) - if ( - user_api_key_dict.user_role is not None - and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN - ): - user_id = None # unless they're admin num_keys_to_be_deleted = 0 deleted_keys = [] if data.keys: - number_deleted_keys, _keys_being_deleted = await delete_verification_token( + number_deleted_keys, _keys_being_deleted = await delete_verification_tokens( tokens=data.keys, user_api_key_cache=user_api_key_cache, - user_id=user_id, + user_api_key_dict=user_api_key_dict, ) num_keys_to_be_deleted = len(data.keys) deleted_keys = data.keys @@ -830,7 +824,7 @@ async def delete_key_fn( key_aliases=data.key_aliases, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, - user_id=user_id, + user_api_key_dict=user_api_key_dict, ) num_keys_to_be_deleted = len(data.key_aliases) deleted_keys = data.key_aliases @@ -844,17 +838,15 @@ async def delete_key_fn( param="keys", code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) - verbose_proxy_logger.debug( - f"/key/delete - deleted_keys={number_deleted_keys['deleted_keys']}" - ) + verbose_proxy_logger.debug(f"/key/delete - deleted_keys={number_deleted_keys}") try: - assert num_keys_to_be_deleted == number_deleted_keys["deleted_keys"] + assert num_keys_to_be_deleted == len(deleted_keys) except Exception: raise HTTPException( status_code=400, detail={ - "error": f"Not all keys passed in were deleted. This probably means you don't have access to delete all the keys passed in. Keys passed in={num_keys_to_be_deleted}, Deleted keys ={number_deleted_keys['deleted_keys']}" + "error": f"Not all keys passed in were deleted. This probably means you don't have access to delete all the keys passed in. Keys passed in={num_keys_to_be_deleted}, Deleted keys ={number_deleted_keys}" }, ) @@ -1294,14 +1286,88 @@ async def generate_key_helper_fn( # noqa: PLR0915 return key_data -async def delete_verification_token( +async def _team_key_deletion_check( + user_api_key_dict: UserAPIKeyAuth, + key_info: LiteLLM_VerificationToken, + prisma_client: PrismaClient, + user_api_key_cache: DualCache, +): + is_team_key = _is_team_key(data=key_info) + + if is_team_key and key_info.team_id is not None: + team_table = await get_team_object( + team_id=key_info.team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + check_db_only=True, + ) + if ( + litellm.key_generation_settings is not None + and "team_key_generation" in litellm.key_generation_settings + ): + _team_key_generation = litellm.key_generation_settings[ + "team_key_generation" + ] + else: + _team_key_generation = TeamUIKeyGenerationConfig( + allowed_team_member_roles=["admin", "user"], + ) + # check if user is team admin + if team_table is not None: + return _team_key_generation_team_member_check( + assigned_user_id=user_api_key_dict.user_id, + team_table=team_table, + user_api_key_dict=user_api_key_dict, + team_key_generation=_team_key_generation, + ) + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "error": f"Team not found in db, and user not proxy admin. Team id = {key_info.team_id}" + }, + ) + return False + + +async def can_delete_verification_token( + key_info: LiteLLM_VerificationToken, + user_api_key_cache: DualCache, + user_api_key_dict: UserAPIKeyAuth, + prisma_client: PrismaClient, +) -> bool: + """ + - check if user is proxy admin + - check if user is team admin and key is a team key + - check if key is personal key + """ + is_team_key = _is_team_key(data=key_info) + if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value: + return True + elif is_team_key and key_info.team_id is not None: + return await _team_key_deletion_check( + user_api_key_dict=user_api_key_dict, + key_info=key_info, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + ) + elif key_info.user_id is not None and key_info.user_id == user_api_key_dict.user_id: + return True + else: + return False + + +async def delete_verification_tokens( tokens: List, user_api_key_cache: DualCache, - user_id: Optional[str] = None, + user_api_key_dict: UserAPIKeyAuth, ) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]: """ Helper that deletes the list of tokens from the database + - check if user is proxy admin + - check if user is team admin and key is a team key + Args: tokens: List of tokens to delete user_id: Optional user_id to filter by @@ -1314,12 +1380,12 @@ async def delete_verification_token( - List of keys being deleted, this contains information about the key_alias, token, and user_id being deleted, this is passed down to the KeyManagementEventHooks to delete the keys from the secret manager and handle audit logs """ - from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client + from litellm.proxy.proxy_server import prisma_client try: if prisma_client: tokens = [_hash_token_if_needed(token=key) for key in tokens] - _keys_being_deleted = ( + _keys_being_deleted: List[LiteLLM_VerificationToken] = ( await prisma_client.db.litellm_verificationtoken.find_many( where={"token": {"in": tokens}} ) @@ -1327,28 +1393,41 @@ async def delete_verification_token( # Assuming 'db' is your Prisma Client instance # check if admin making request - don't filter by user-id - if user_id == litellm_proxy_admin_name: + if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value: deleted_tokens = await prisma_client.delete_data(tokens=tokens) # else else: - deleted_tokens = await prisma_client.delete_data( - tokens=tokens, user_id=user_id - ) - if deleted_tokens is None: - raise Exception( - "Failed to delete tokens got response None when deleting tokens" - ) - _num_deleted_tokens = deleted_tokens.get("deleted_keys", 0) + tasks = [] + deleted_tokens = [] + for key in _keys_being_deleted: + + async def _delete_key(key: LiteLLM_VerificationToken): + if await can_delete_verification_token( + key_info=key, + user_api_key_cache=user_api_key_cache, + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, + ): + await prisma_client.delete_data(tokens=[key.token]) + deleted_tokens.append(key.token) + + tasks.append(_delete_key(key)) + await asyncio.gather(*tasks) + + _num_deleted_tokens = len(deleted_tokens) if _num_deleted_tokens != len(tokens): + failed_tokens = [ + token for token in tokens if token not in deleted_tokens + ] raise Exception( - "Failed to delete all tokens. Tried to delete tokens that don't belong to user: " - + str(user_id) + "Failed to delete all tokens. Failed to delete tokens: " + + str(failed_tokens) ) else: raise Exception("DB not connected. prisma_client is None") except Exception as e: verbose_proxy_logger.exception( - "litellm.proxy.proxy_server.delete_verification_token(): Exception occured - {}".format( + "litellm.proxy.proxy_server.delete_verification_tokens(): Exception occured - {}".format( str(e) ) ) @@ -1358,25 +1437,27 @@ async def delete_verification_token( for key in tokens: user_api_key_cache.delete_cache(key) # remove hash token from cache - hashed_token = hash_token(key) + hashed_token = hash_token(cast(str, key)) user_api_key_cache.delete_cache(hashed_token) - return deleted_tokens, _keys_being_deleted + return {"deleted_keys": deleted_tokens}, _keys_being_deleted async def delete_key_aliases( key_aliases: List[str], user_api_key_cache: DualCache, prisma_client: PrismaClient, - user_id: Optional[str] = None, + user_api_key_dict: UserAPIKeyAuth, ) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]: _keys_being_deleted = await prisma_client.db.litellm_verificationtoken.find_many( where={"key_alias": {"in": key_aliases}} ) tokens = [key.token for key in _keys_being_deleted] - return await delete_verification_token( - tokens=tokens, user_api_key_cache=user_api_key_cache, user_id=user_id + return await delete_verification_tokens( + tokens=tokens, + user_api_key_cache=user_api_key_cache, + user_api_key_dict=user_api_key_dict, ) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 082e5251f0..b036945674 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -189,7 +189,7 @@ from litellm.proxy.management_endpoints.internal_user_endpoints import ( ) from litellm.proxy.management_endpoints.internal_user_endpoints import user_update from litellm.proxy.management_endpoints.key_management_endpoints import ( - delete_verification_token, + delete_verification_tokens, duration_in_seconds, generate_key_helper_fn, )