Support deleting keys by key_alias (#7552)

* feat(key_management_endpoints.py): allow deleting keys based on key alias

easier for proxy admin to delete known bad key

* fix(key_management_event_hooks.py): fix linting error

* docs(key_management_endpoints.py): document new key_aliases param

* fix(key_management_endpoints.py): return deleted keys to user

fixes return when passing key aliases
This commit is contained in:
Krish Dholakia 2025-01-04 19:41:48 -08:00 committed by GitHub
parent 8a5e74d519
commit 5cf223c66a
4 changed files with 72 additions and 26 deletions

View file

@ -15,6 +15,5 @@ model_list:
tpm: 1000000 tpm: 1000000
prompt_id: "jokes" prompt_id: "jokes"
# litellm_settings: # litellm_settings:
# callbacks: ["otel"] # callbacks: ["otel"]

View file

@ -690,7 +690,17 @@ class RegenerateKeyRequest(GenerateKeyRequest):
class KeyRequest(LiteLLMPydanticObjectBase): class KeyRequest(LiteLLMPydanticObjectBase):
keys: List[str] keys: Optional[List[str]] = None
key_aliases: Optional[List[str]] = None
@model_validator(mode="before")
@classmethod
def validate_at_least_one(cls, values):
if not values.get("keys") and not values.get("key_aliases"):
raise ValueError(
"At least one of 'keys' or 'key_aliases' must be provided."
)
return values
class LiteLLM_ModelTable(LiteLLMPydanticObjectBase): class LiteLLM_ModelTable(LiteLLMPydanticObjectBase):

View file

@ -142,7 +142,7 @@ class KeyManagementEventHooks:
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
# we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
if litellm.store_audit_logs is True: if litellm.store_audit_logs is True and data.keys is not None:
# make an audit log for each team deleted # make an audit log for each team deleted
for key in data.keys: for key in data.keys:
key_row = await prisma_client.get_data( # type: ignore key_row = await prisma_client.get_data( # type: ignore

View file

@ -23,6 +23,7 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, s
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.proxy._types import * from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import ( from litellm.proxy.auth.auth_checks import (
_cache_key_object, _cache_key_object,
@ -34,6 +35,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import ( from litellm.proxy.utils import (
PrismaClient,
_hash_token_if_needed, _hash_token_if_needed,
duration_in_seconds, duration_in_seconds,
handle_exception_on_proxy, handle_exception_on_proxy,
@ -772,6 +774,7 @@ async def delete_key_fn(
Parameters:: Parameters::
- keys (List[str]): A list of keys or hashed keys to delete. Example {"keys": ["sk-QWrxEynunsNpV1zT48HIrw", "837e17519f44683334df5291321d97b8bf1098cd490e49e215f6fea935aa28be"]} - keys (List[str]): A list of keys or hashed keys to delete. Example {"keys": ["sk-QWrxEynunsNpV1zT48HIrw", "837e17519f44683334df5291321d97b8bf1098cd490e49e215f6fea935aa28be"]}
- key_aliases (List[str]): A list of key aliases to delete. Can be passed instead of `keys`.Example {"key_aliases": ["alias1", "alias2"]}
Returns: Returns:
- deleted_keys (List[str]): A list of deleted keys. Example {"deleted_keys": ["sk-QWrxEynunsNpV1zT48HIrw", "837e17519f44683334df5291321d97b8bf1098cd490e49e215f6fea935aa28be"]} - deleted_keys (List[str]): A list of deleted keys. Example {"deleted_keys": ["sk-QWrxEynunsNpV1zT48HIrw", "837e17519f44683334df5291321d97b8bf1098cd490e49e215f6fea935aa28be"]}
@ -795,15 +798,6 @@ async def delete_key_fn(
if prisma_client is None: if prisma_client is None:
raise Exception("Not connected to DB!") raise Exception("Not connected to DB!")
keys = data.keys
if len(keys) == 0:
raise ProxyException(
message=f"No keys provided, passed in: keys={keys}",
type=ProxyErrorTypes.auth_error,
param="keys",
code=status.HTTP_400_BAD_REQUEST,
)
## only allow user to delete keys they own ## only allow user to delete keys they own
user_id = user_api_key_dict.user_id user_id = user_api_key_dict.user_id
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
@ -815,9 +809,28 @@ async def delete_key_fn(
): ):
user_id = None # unless they're admin user_id = None # unless they're admin
number_deleted_keys, _keys_being_deleted = await delete_verification_token( num_keys_to_be_deleted = 0
tokens=keys, user_id=user_id deleted_keys = []
) if data.keys:
number_deleted_keys, _keys_being_deleted = await delete_verification_token(
tokens=data.keys,
user_api_key_cache=user_api_key_cache,
user_id=user_id,
)
num_keys_to_be_deleted = len(data.keys)
deleted_keys = data.keys
elif data.key_aliases:
number_deleted_keys, _keys_being_deleted = await delete_key_aliases(
key_aliases=data.key_aliases,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
user_id=user_id,
)
num_keys_to_be_deleted = len(data.key_aliases)
deleted_keys = data.key_aliases
else:
raise ValueError("Invalid request type")
if number_deleted_keys is None: if number_deleted_keys is None:
raise ProxyException( raise ProxyException(
message="Failed to delete keys got None response from delete_verification_token", message="Failed to delete keys got None response from delete_verification_token",
@ -830,21 +843,15 @@ async def delete_key_fn(
) )
try: try:
assert len(keys) == number_deleted_keys["deleted_keys"] assert num_keys_to_be_deleted == number_deleted_keys["deleted_keys"]
except Exception: except Exception:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail={ detail={
"error": f"Not all keys passed in were deleted. This probably means you don't have access to delete all the keys passed in. Keys passed in={len(keys)}, Deleted keys ={number_deleted_keys['deleted_keys']}" "error": f"Not all keys passed in were deleted. This probably means you don't have access to delete all the keys passed in. Keys passed in={num_keys_to_be_deleted}, Deleted keys ={number_deleted_keys['deleted_keys']}"
}, },
) )
for key in keys:
user_api_key_cache.delete_cache(key)
# remove hash token from cache
hashed_token = hash_token(key)
user_api_key_cache.delete_cache(hashed_token)
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}" f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}"
) )
@ -859,8 +866,13 @@ async def delete_key_fn(
) )
) )
return {"deleted_keys": keys} return {"deleted_keys": deleted_keys}
except Exception as e: except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.delete_key_fn(): Exception occured - {}".format(
str(e)
)
)
raise handle_exception_on_proxy(e) raise handle_exception_on_proxy(e)
@ -1277,7 +1289,9 @@ async def generate_key_helper_fn( # noqa: PLR0915
async def delete_verification_token( async def delete_verification_token(
tokens: List, user_id: Optional[str] = None tokens: List,
user_api_key_cache: DualCache,
user_id: Optional[str] = None,
) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]: ) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
""" """
Helper that deletes the list of tokens from the database Helper that deletes the list of tokens from the database
@ -1327,16 +1341,39 @@ async def delete_verification_token(
else: else:
raise Exception("DB not connected. prisma_client is None") raise Exception("DB not connected. prisma_client is None")
except Exception as e: except Exception as e:
verbose_proxy_logger.error( verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.delete_verification_token(): Exception occured - {}".format( "litellm.proxy.proxy_server.delete_verification_token(): Exception occured - {}".format(
str(e) str(e)
) )
) )
verbose_proxy_logger.debug(traceback.format_exc()) verbose_proxy_logger.debug(traceback.format_exc())
raise e raise e
for key in tokens:
user_api_key_cache.delete_cache(key)
# remove hash token from cache
hashed_token = hash_token(key)
user_api_key_cache.delete_cache(hashed_token)
return deleted_tokens, _keys_being_deleted return deleted_tokens, _keys_being_deleted
async def delete_key_aliases(
key_aliases: List[str],
user_api_key_cache: DualCache,
prisma_client: PrismaClient,
user_id: Optional[str] = None,
) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
_keys_being_deleted = await prisma_client.db.litellm_verificationtoken.find_many(
where={"key_alias": {"in": key_aliases}}
)
tokens = [key.token for key in _keys_being_deleted]
return await delete_verification_token(
tokens=tokens, user_api_key_cache=user_api_key_cache, user_id=user_id
)
@router.post( @router.post(
"/key/{key:path}/regenerate", "/key/{key:path}/regenerate",
tags=["key management"], tags=["key management"],