fix(proxy_server.py): prevent user from deleting non-user owned keys when they use ui

This commit is contained in:
Krrish Dholakia 2024-03-11 12:13:30 -07:00
parent 40c9682de7
commit 4eb244c3ca
2 changed files with 44 additions and 12 deletions

View file

@ -2103,12 +2103,14 @@ async def generate_key_helper_fn(
return key_data return key_data
async def delete_verification_token(tokens: List): async def delete_verification_token(tokens: List, user_id: Optional[str] = None):
global prisma_client global prisma_client
try: try:
if prisma_client: if prisma_client:
# Assuming 'db' is your Prisma Client instance # Assuming 'db' is your Prisma Client instance
deleted_tokens = await prisma_client.delete_data(tokens=tokens) deleted_tokens = await prisma_client.delete_data(
tokens=tokens, user_id=user_id
)
else: else:
raise Exception raise Exception
except Exception as e: except Exception as e:
@ -3744,7 +3746,10 @@ async def update_key_fn(request: Request, data: UpdateKeyRequest):
@router.post( @router.post(
"/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)] "/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
) )
async def delete_key_fn(data: KeyRequest): async def delete_key_fn(
data: KeyRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
""" """
Delete a key from the key management system. Delete a key from the key management system.
@ -3769,11 +3774,28 @@ async def delete_key_fn(data: KeyRequest):
code=status.HTTP_400_BAD_REQUEST, code=status.HTTP_400_BAD_REQUEST,
) )
result = await delete_verification_token(tokens=keys) ## only allow user to delete keys they own
verbose_proxy_logger.debug("/key/delete - deleted_keys=", result) user_id = user_api_key_dict.user_id
if (
user_api_key_dict.user_role is not None
and user_api_key_dict.user_role == "proxy_admin"
):
user_id = None # unless they're admin
number_deleted_keys = len(result["deleted_keys"]) number_deleted_keys = await delete_verification_token(
assert len(keys) == number_deleted_keys tokens=keys, user_id=user_id
)
verbose_proxy_logger.debug("/key/delete - deleted_keys=", number_deleted_keys)
try:
assert len(keys) == number_deleted_keys
except Exception as e:
raise HTTPException(
status_code=400,
detail={
"error": "Not all keys passed in were deleted. This probably means you don't have access to delete all the keys passed in."
},
)
for key in keys: for key in keys:
user_api_key_cache.delete_cache(key) user_api_key_cache.delete_cache(key)
@ -6529,8 +6551,6 @@ async def login(request: Request):
algorithm="HS256", algorithm="HS256",
) )
litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token
# if a user has logged in they should be allowed to create keys - this ensures that it's set to True
general_settings["allow_user_auth"] = True
return RedirectResponse(url=litellm_dashboard_ui, status_code=303) return RedirectResponse(url=litellm_dashboard_ui, status_code=303)
else: else:
raise ProxyException( raise ProxyException(

View file

@ -1356,9 +1356,12 @@ class PrismaClient:
tokens: Optional[List] = None, tokens: Optional[List] = None,
team_id_list: Optional[List] = None, team_id_list: Optional[List] = None,
table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None, table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None,
user_id: Optional[str] = None,
): ):
""" """
Allow user to delete a key(s) Allow user to delete a key(s)
Ensure user owns that key, unless admin.
""" """
try: try:
if tokens is not None and isinstance(tokens, List): if tokens is not None and isinstance(tokens, List):
@ -1369,15 +1372,23 @@ class PrismaClient:
else: else:
hashed_token = token hashed_token = token
hashed_tokens.append(hashed_token) hashed_tokens.append(hashed_token)
await self.db.litellm_verificationtoken.delete_many( filter_query: dict = {}
where={"token": {"in": hashed_tokens}} if user_id is not None:
filter_query = {
"AND": [{"token": {"in": hashed_tokens}}, {"user_id": user_id}]
}
else:
filter_query = {"token": {"in": hashed_tokens}}
deleted_tokens = await self.db.litellm_verificationtoken.delete_many(
where=filter_query # type: ignore
) )
return {"deleted_keys": tokens} return {"deleted_keys": deleted_tokens}
elif ( elif (
table_name == "team" table_name == "team"
and team_id_list is not None and team_id_list is not None
and isinstance(team_id_list, List) and isinstance(team_id_list, List)
): ):
# admin only endpoint -> `/team/delete`
await self.db.litellm_teamtable.delete_many( await self.db.litellm_teamtable.delete_many(
where={"team_id": {"in": team_id_list}} where={"team_id": {"in": team_id_list}}
) )
@ -1387,6 +1398,7 @@ class PrismaClient:
and team_id_list is not None and team_id_list is not None
and isinstance(team_id_list, List) and isinstance(team_id_list, List)
): ):
# admin only endpoint -> `/team/delete`
await self.db.litellm_verificationtoken.delete_many( await self.db.litellm_verificationtoken.delete_many(
where={"team_id": {"in": team_id_list}} where={"team_id": {"in": team_id_list}}
) )