/key/delete - allow team admin to delete team keys (#7846)

* fix(key_management_endpoints.py): fix key delete to allow team admins + other proxy admins to delete keys

Fixes https://github.com/BerriAI/litellm/issues/7760

* fix(key_management_endpoints.py): remove unused variables

* fix(key_management_endpoints.py): fix linting error
This commit is contained in:
Krish Dholakia 2025-01-17 20:16:12 -08:00 committed by GitHub
parent 2b58f16fda
commit 1ed7946280
2 changed files with 118 additions and 37 deletions

View file

@ -49,7 +49,7 @@ from litellm.types.utils import (
) )
def _is_team_key(data: GenerateKeyRequest): def _is_team_key(data: Union[GenerateKeyRequest, LiteLLM_VerificationToken]):
return data.team_id is not None return data.team_id is not None
@ -805,23 +805,17 @@ async def delete_key_fn(
raise Exception("Not connected to DB!") raise Exception("Not connected to DB!")
## only allow user to delete keys they own ## only allow user to delete keys they own
user_id = user_api_key_dict.user_id
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"user_api_key_dict.user_role: {user_api_key_dict.user_role}" f"user_api_key_dict.user_role: {user_api_key_dict.user_role}"
) )
if (
user_api_key_dict.user_role is not None
and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
):
user_id = None # unless they're admin
num_keys_to_be_deleted = 0 num_keys_to_be_deleted = 0
deleted_keys = [] deleted_keys = []
if data.keys: if data.keys:
number_deleted_keys, _keys_being_deleted = await delete_verification_token( number_deleted_keys, _keys_being_deleted = await delete_verification_tokens(
tokens=data.keys, tokens=data.keys,
user_api_key_cache=user_api_key_cache, user_api_key_cache=user_api_key_cache,
user_id=user_id, user_api_key_dict=user_api_key_dict,
) )
num_keys_to_be_deleted = len(data.keys) num_keys_to_be_deleted = len(data.keys)
deleted_keys = data.keys deleted_keys = data.keys
@ -830,7 +824,7 @@ async def delete_key_fn(
key_aliases=data.key_aliases, key_aliases=data.key_aliases,
prisma_client=prisma_client, prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache, user_api_key_cache=user_api_key_cache,
user_id=user_id, user_api_key_dict=user_api_key_dict,
) )
num_keys_to_be_deleted = len(data.key_aliases) num_keys_to_be_deleted = len(data.key_aliases)
deleted_keys = data.key_aliases deleted_keys = data.key_aliases
@ -844,17 +838,15 @@ async def delete_key_fn(
param="keys", param="keys",
code=status.HTTP_500_INTERNAL_SERVER_ERROR, code=status.HTTP_500_INTERNAL_SERVER_ERROR,
) )
verbose_proxy_logger.debug( verbose_proxy_logger.debug(f"/key/delete - deleted_keys={number_deleted_keys}")
f"/key/delete - deleted_keys={number_deleted_keys['deleted_keys']}"
)
try: try:
assert num_keys_to_be_deleted == number_deleted_keys["deleted_keys"] assert num_keys_to_be_deleted == len(deleted_keys)
except Exception: except Exception:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail={ detail={
"error": f"Not all keys passed in were deleted. This probably means you don't have access to delete all the keys passed in. Keys passed in={num_keys_to_be_deleted}, Deleted keys ={number_deleted_keys['deleted_keys']}" "error": f"Not all keys passed in were deleted. This probably means you don't have access to delete all the keys passed in. Keys passed in={num_keys_to_be_deleted}, Deleted keys ={number_deleted_keys}"
}, },
) )
@ -1294,14 +1286,88 @@ async def generate_key_helper_fn( # noqa: PLR0915
return key_data return key_data
async def delete_verification_token( async def _team_key_deletion_check(
user_api_key_dict: UserAPIKeyAuth,
key_info: LiteLLM_VerificationToken,
prisma_client: PrismaClient,
user_api_key_cache: DualCache,
):
is_team_key = _is_team_key(data=key_info)
if is_team_key and key_info.team_id is not None:
team_table = await get_team_object(
team_id=key_info.team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
check_db_only=True,
)
if (
litellm.key_generation_settings is not None
and "team_key_generation" in litellm.key_generation_settings
):
_team_key_generation = litellm.key_generation_settings[
"team_key_generation"
]
else:
_team_key_generation = TeamUIKeyGenerationConfig(
allowed_team_member_roles=["admin", "user"],
)
# check if user is team admin
if team_table is not None:
return _team_key_generation_team_member_check(
assigned_user_id=user_api_key_dict.user_id,
team_table=team_table,
user_api_key_dict=user_api_key_dict,
team_key_generation=_team_key_generation,
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"error": f"Team not found in db, and user not proxy admin. Team id = {key_info.team_id}"
},
)
return False
async def can_delete_verification_token(
key_info: LiteLLM_VerificationToken,
user_api_key_cache: DualCache,
user_api_key_dict: UserAPIKeyAuth,
prisma_client: PrismaClient,
) -> bool:
"""
- check if user is proxy admin
- check if user is team admin and key is a team key
- check if key is personal key
"""
is_team_key = _is_team_key(data=key_info)
if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
return True
elif is_team_key and key_info.team_id is not None:
return await _team_key_deletion_check(
user_api_key_dict=user_api_key_dict,
key_info=key_info,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
elif key_info.user_id is not None and key_info.user_id == user_api_key_dict.user_id:
return True
else:
return False
async def delete_verification_tokens(
tokens: List, tokens: List,
user_api_key_cache: DualCache, user_api_key_cache: DualCache,
user_id: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth,
) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]: ) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
""" """
Helper that deletes the list of tokens from the database Helper that deletes the list of tokens from the database
- check if user is proxy admin
- check if user is team admin and key is a team key
Args: Args:
tokens: List of tokens to delete tokens: List of tokens to delete
user_id: Optional user_id to filter by user_id: Optional user_id to filter by
@ -1314,12 +1380,12 @@ async def delete_verification_token(
- List of keys being deleted, this contains information about the key_alias, token, and user_id being deleted, - List of keys being deleted, this contains information about the key_alias, token, and user_id being deleted,
this is passed down to the KeyManagementEventHooks to delete the keys from the secret manager and handle audit logs this is passed down to the KeyManagementEventHooks to delete the keys from the secret manager and handle audit logs
""" """
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client from litellm.proxy.proxy_server import prisma_client
try: try:
if prisma_client: if prisma_client:
tokens = [_hash_token_if_needed(token=key) for key in tokens] tokens = [_hash_token_if_needed(token=key) for key in tokens]
_keys_being_deleted = ( _keys_being_deleted: List[LiteLLM_VerificationToken] = (
await prisma_client.db.litellm_verificationtoken.find_many( await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}} where={"token": {"in": tokens}}
) )
@ -1327,28 +1393,41 @@ async def delete_verification_token(
# Assuming 'db' is your Prisma Client instance # Assuming 'db' is your Prisma Client instance
# check if admin making request - don't filter by user-id # check if admin making request - don't filter by user-id
if user_id == litellm_proxy_admin_name: if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
deleted_tokens = await prisma_client.delete_data(tokens=tokens) deleted_tokens = await prisma_client.delete_data(tokens=tokens)
# else # else
else: else:
deleted_tokens = await prisma_client.delete_data( tasks = []
tokens=tokens, user_id=user_id deleted_tokens = []
) for key in _keys_being_deleted:
if deleted_tokens is None:
raise Exception( async def _delete_key(key: LiteLLM_VerificationToken):
"Failed to delete tokens got response None when deleting tokens" if await can_delete_verification_token(
) key_info=key,
_num_deleted_tokens = deleted_tokens.get("deleted_keys", 0) user_api_key_cache=user_api_key_cache,
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
):
await prisma_client.delete_data(tokens=[key.token])
deleted_tokens.append(key.token)
tasks.append(_delete_key(key))
await asyncio.gather(*tasks)
_num_deleted_tokens = len(deleted_tokens)
if _num_deleted_tokens != len(tokens): if _num_deleted_tokens != len(tokens):
failed_tokens = [
token for token in tokens if token not in deleted_tokens
]
raise Exception( raise Exception(
"Failed to delete all tokens. Tried to delete tokens that don't belong to user: " "Failed to delete all tokens. Failed to delete tokens: "
+ str(user_id) + str(failed_tokens)
) )
else: else:
raise Exception("DB not connected. prisma_client is None") raise Exception("DB not connected. prisma_client is None")
except Exception as e: except Exception as e:
verbose_proxy_logger.exception( verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.delete_verification_token(): Exception occured - {}".format( "litellm.proxy.proxy_server.delete_verification_tokens(): Exception occured - {}".format(
str(e) str(e)
) )
) )
@ -1358,25 +1437,27 @@ async def delete_verification_token(
for key in tokens: for key in tokens:
user_api_key_cache.delete_cache(key) user_api_key_cache.delete_cache(key)
# remove hash token from cache # remove hash token from cache
hashed_token = hash_token(key) hashed_token = hash_token(cast(str, key))
user_api_key_cache.delete_cache(hashed_token) user_api_key_cache.delete_cache(hashed_token)
return deleted_tokens, _keys_being_deleted return {"deleted_keys": deleted_tokens}, _keys_being_deleted
async def delete_key_aliases( async def delete_key_aliases(
key_aliases: List[str], key_aliases: List[str],
user_api_key_cache: DualCache, user_api_key_cache: DualCache,
prisma_client: PrismaClient, prisma_client: PrismaClient,
user_id: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth,
) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]: ) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
_keys_being_deleted = await prisma_client.db.litellm_verificationtoken.find_many( _keys_being_deleted = await prisma_client.db.litellm_verificationtoken.find_many(
where={"key_alias": {"in": key_aliases}} where={"key_alias": {"in": key_aliases}}
) )
tokens = [key.token for key in _keys_being_deleted] tokens = [key.token for key in _keys_being_deleted]
return await delete_verification_token( return await delete_verification_tokens(
tokens=tokens, user_api_key_cache=user_api_key_cache, user_id=user_id tokens=tokens,
user_api_key_cache=user_api_key_cache,
user_api_key_dict=user_api_key_dict,
) )

View file

@ -189,7 +189,7 @@ from litellm.proxy.management_endpoints.internal_user_endpoints import (
) )
from litellm.proxy.management_endpoints.internal_user_endpoints import user_update from litellm.proxy.management_endpoints.internal_user_endpoints import user_update
from litellm.proxy.management_endpoints.key_management_endpoints import ( from litellm.proxy.management_endpoints.key_management_endpoints import (
delete_verification_token, delete_verification_tokens,
duration_in_seconds, duration_in_seconds,
generate_key_helper_fn, generate_key_helper_fn,
) )