From 7feb1fb65d415b7c950b24800c01ea982b10b35d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 12 Apr 2025 16:33:41 -0700 Subject: [PATCH] fix(key_management_endpoints.py): fix filtering litellm-dashboard keys for internal users --- .../proxy/_experimental/out/onboarding.html | 1 - .../key_management_endpoints.py | 41 ++++++++-------- .../test_key_management_endpoints.py | 48 +++++++++++++++++++ 3 files changed, 68 insertions(+), 22 deletions(-) delete mode 100644 litellm/proxy/_experimental/out/onboarding.html create mode 100644 tests/litellm/proxy/management_endpoints/test_key_management_endpoints.py diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html deleted file mode 100644 index 1b1ad5c2cc..0000000000 --- a/litellm/proxy/_experimental/out/onboarding.html +++ /dev/null @@ -1 +0,0 @@ -LiteLLM Dashboard \ No newline at end of file diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 11f6e5e603..d37163d2ef 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -577,9 +577,9 @@ async def generate_key_fn( # noqa: PLR0915 request_type="key", **data_json, table_name="key" ) - response["soft_budget"] = ( - data.soft_budget - ) # include the user-input soft budget in the response + response[ + "soft_budget" + ] = data.soft_budget # include the user-input soft budget in the response response = GenerateKeyResponse(**response) @@ -1467,10 +1467,10 @@ async def delete_verification_tokens( try: if prisma_client: tokens = [_hash_token_if_needed(token=key) for key in tokens] - _keys_being_deleted: List[LiteLLM_VerificationToken] = ( - await prisma_client.db.litellm_verificationtoken.find_many( - where={"token": {"in": tokens}} - ) + _keys_being_deleted: List[ + LiteLLM_VerificationToken + ] = await prisma_client.db.litellm_verificationtoken.find_many( + where={"token": {"in": tokens}} ) # Assuming 'db' is your Prisma Client instance @@ -1572,9 +1572,9 @@ async def _rotate_master_key( from litellm.proxy.proxy_server import proxy_config try: - models: Optional[List] = ( - await prisma_client.db.litellm_proxymodeltable.find_many() - ) + models: Optional[ + List + ] = await prisma_client.db.litellm_proxymodeltable.find_many() except Exception: models = None # 2. process model table @@ -1861,11 +1861,11 @@ async def validate_key_list_check( param="user_id", code=status.HTTP_403_FORBIDDEN, ) - complete_user_info_db_obj: Optional[BaseModel] = ( - await prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_api_key_dict.user_id}, - include={"organization_memberships": True}, - ) + complete_user_info_db_obj: Optional[ + BaseModel + ] = await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": user_api_key_dict.user_id}, + include={"organization_memberships": True}, ) if complete_user_info_db_obj is None: @@ -1926,10 +1926,10 @@ async def get_admin_team_ids( if complete_user_info is None: return [] # Get all teams that user is an admin of - teams: Optional[List[BaseModel]] = ( - await prisma_client.db.litellm_teamtable.find_many( - where={"team_id": {"in": complete_user_info.teams}} - ) + teams: Optional[ + List[BaseModel] + ] = await prisma_client.db.litellm_teamtable.find_many( + where={"team_id": {"in": complete_user_info.teams}} ) if teams is None: return [] @@ -2080,7 +2080,6 @@ async def _list_key_helper( "total_pages": int, } """ - # Prepare filter conditions where: Dict[str, Union[str, Dict[str, Any], List[Dict[str, Any]]]] = {} where.update(_get_condition_to_filter_out_ui_session_tokens()) @@ -2110,7 +2109,7 @@ async def _list_key_helper( # Combine conditions with OR if we have multiple conditions if len(or_conditions) > 1: - where["OR"] = or_conditions + where = {"AND": [where, {"OR": or_conditions}]} elif len(or_conditions) == 1: where.update(or_conditions[0]) diff --git a/tests/litellm/proxy/management_endpoints/test_key_management_endpoints.py b/tests/litellm/proxy/management_endpoints/test_key_management_endpoints.py new file mode 100644 index 0000000000..51bbbb49c4 --- /dev/null +++ b/tests/litellm/proxy/management_endpoints/test_key_management_endpoints.py @@ -0,0 +1,48 @@ +import json +import os +import sys + +import pytest +from fastapi.testclient import TestClient + +sys.path.insert( + 0, os.path.abspath("../../../..") +) # Adds the parent directory to the system path + +from unittest.mock import AsyncMock, MagicMock + +from litellm.proxy.management_endpoints.key_management_endpoints import _list_key_helper +from litellm.proxy.proxy_server import app + +client = TestClient(app) + + +@pytest.mark.asyncio +async def test_list_keys(): + mock_prisma_client = AsyncMock() + mock_find_many = AsyncMock(return_value=[]) + mock_prisma_client.db.litellm_verificationtoken.find_many = mock_find_many + args = { + "prisma_client": mock_prisma_client, + "page": 1, + "size": 50, + "user_id": "cda88cb4-cc2c-4e8c-b871-dc71ca111b00", + "team_id": None, + "organization_id": None, + "key_alias": None, + "exclude_team_id": None, + "return_full_object": True, + "admin_team_ids": ["28bd3181-02c5-48f2-b408-ce790fb3d5ba"], + } + try: + result = await _list_key_helper(**args) + except Exception as e: + print(f"error: {e}") + + mock_find_many.assert_called_once() + + where_condition = mock_find_many.call_args.kwargs["where"] + print(f"where_condition: {where_condition}") + assert json.dumps({"team_id": {"not": "litellm-dashboard"}}) in json.dumps( + where_condition + )