fix(key_management_endpoints.py): fix filtering litellm-dashboard keys for internal users

This commit is contained in:
Krrish Dholakia 2025-04-12 16:33:41 -07:00
parent 93e147940a
commit 7feb1fb65d
3 changed files with 68 additions and 22 deletions

File diff suppressed because one or more lines are too long

View file

@ -577,9 +577,9 @@ async def generate_key_fn( # noqa: PLR0915
request_type="key", **data_json, table_name="key" request_type="key", **data_json, table_name="key"
) )
response["soft_budget"] = ( response[
data.soft_budget "soft_budget"
) # include the user-input soft budget in the response ] = data.soft_budget # include the user-input soft budget in the response
response = GenerateKeyResponse(**response) response = GenerateKeyResponse(**response)
@ -1467,11 +1467,11 @@ async def delete_verification_tokens(
try: try:
if prisma_client: if prisma_client:
tokens = [_hash_token_if_needed(token=key) for key in tokens] tokens = [_hash_token_if_needed(token=key) for key in tokens]
_keys_being_deleted: List[LiteLLM_VerificationToken] = ( _keys_being_deleted: List[
await prisma_client.db.litellm_verificationtoken.find_many( LiteLLM_VerificationToken
] = await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}} where={"token": {"in": tokens}}
) )
)
# Assuming 'db' is your Prisma Client instance # Assuming 'db' is your Prisma Client instance
# check if admin making request - don't filter by user-id # check if admin making request - don't filter by user-id
@ -1572,9 +1572,9 @@ async def _rotate_master_key(
from litellm.proxy.proxy_server import proxy_config from litellm.proxy.proxy_server import proxy_config
try: try:
models: Optional[List] = ( models: Optional[
await prisma_client.db.litellm_proxymodeltable.find_many() List
) ] = await prisma_client.db.litellm_proxymodeltable.find_many()
except Exception: except Exception:
models = None models = None
# 2. process model table # 2. process model table
@ -1861,12 +1861,12 @@ async def validate_key_list_check(
param="user_id", param="user_id",
code=status.HTTP_403_FORBIDDEN, code=status.HTTP_403_FORBIDDEN,
) )
complete_user_info_db_obj: Optional[BaseModel] = ( complete_user_info_db_obj: Optional[
await prisma_client.db.litellm_usertable.find_unique( BaseModel
] = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id}, where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True}, include={"organization_memberships": True},
) )
)
if complete_user_info_db_obj is None: if complete_user_info_db_obj is None:
raise ProxyException( raise ProxyException(
@ -1926,11 +1926,11 @@ async def get_admin_team_ids(
if complete_user_info is None: if complete_user_info is None:
return [] return []
# Get all teams that user is an admin of # Get all teams that user is an admin of
teams: Optional[List[BaseModel]] = ( teams: Optional[
await prisma_client.db.litellm_teamtable.find_many( List[BaseModel]
] = await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}} where={"team_id": {"in": complete_user_info.teams}}
) )
)
if teams is None: if teams is None:
return [] return []
@ -2080,7 +2080,6 @@ async def _list_key_helper(
"total_pages": int, "total_pages": int,
} }
""" """
# Prepare filter conditions # Prepare filter conditions
where: Dict[str, Union[str, Dict[str, Any], List[Dict[str, Any]]]] = {} where: Dict[str, Union[str, Dict[str, Any], List[Dict[str, Any]]]] = {}
where.update(_get_condition_to_filter_out_ui_session_tokens()) where.update(_get_condition_to_filter_out_ui_session_tokens())
@ -2110,7 +2109,7 @@ async def _list_key_helper(
# Combine conditions with OR if we have multiple conditions # Combine conditions with OR if we have multiple conditions
if len(or_conditions) > 1: if len(or_conditions) > 1:
where["OR"] = or_conditions where = {"AND": [where, {"OR": or_conditions}]}
elif len(or_conditions) == 1: elif len(or_conditions) == 1:
where.update(or_conditions[0]) where.update(or_conditions[0])

View file

@ -0,0 +1,48 @@
import json
import os
import sys
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../../..")
) # Adds the parent directory to the system path
from unittest.mock import AsyncMock, MagicMock
from litellm.proxy.management_endpoints.key_management_endpoints import _list_key_helper
from litellm.proxy.proxy_server import app
client = TestClient(app)
@pytest.mark.asyncio
async def test_list_keys():
mock_prisma_client = AsyncMock()
mock_find_many = AsyncMock(return_value=[])
mock_prisma_client.db.litellm_verificationtoken.find_many = mock_find_many
args = {
"prisma_client": mock_prisma_client,
"page": 1,
"size": 50,
"user_id": "cda88cb4-cc2c-4e8c-b871-dc71ca111b00",
"team_id": None,
"organization_id": None,
"key_alias": None,
"exclude_team_id": None,
"return_full_object": True,
"admin_team_ids": ["28bd3181-02c5-48f2-b408-ce790fb3d5ba"],
}
try:
result = await _list_key_helper(**args)
except Exception as e:
print(f"error: {e}")
mock_find_many.assert_called_once()
where_condition = mock_find_many.call_args.kwargs["where"]
print(f"where_condition: {where_condition}")
assert json.dumps({"team_id": {"not": "litellm-dashboard"}}) in json.dumps(
where_condition
)