From 1b4d3db170e516fab6b66f57e47ac0455a971ad3 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 15 Feb 2025 16:25:57 -0800 Subject: [PATCH] (Patch/bug fix) - UI, filter out litellm ui session tokens on Virtual Keys Page (#8568) * fix key list endpoint * _get_condition_to_filter_out_ui_session_tokens * duration_in_seconds * test_list_key_helper_team_filtering --- .../key_management_endpoints.py | 19 ++- .../test_key_management.py | 115 ++++++++++++++++++ 2 files changed, 133 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 9e3403d437..d7b939e066 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -24,6 +24,7 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, s import litellm from litellm._logging import verbose_proxy_logger from litellm.caching import DualCache +from litellm.constants import UI_SESSION_TOKEN_TEAM_ID from litellm.litellm_core_utils.duration_parser import duration_in_seconds from litellm.proxy._types import * from litellm.proxy.auth.auth_checks import ( @@ -1885,7 +1886,9 @@ async def _list_key_helper( """ # Prepare filter conditions - where: Dict[str, Union[str, Dict[str, str]]] = {} + where: Dict[str, Union[str, Dict[str, Any]]] = {} + where.update(_get_condition_to_filter_out_ui_session_tokens()) + if user_id and isinstance(user_id, str): where["user_id"] = user_id if team_id and isinstance(team_id, str): @@ -1938,6 +1941,20 @@ async def _list_key_helper( ) +def _get_condition_to_filter_out_ui_session_tokens() -> Dict[str, Any]: + """ + Condition to filter out UI session tokens + """ + return { + "OR": [ + {"team_id": None}, # Include records where team_id is null + { + "team_id": {"not": UI_SESSION_TOKEN_TEAM_ID} + }, # Include records where team_id != UI_SESSION_TOKEN_TEAM_ID + ] + } + + @router.post( "/key/block", tags=["key management"], dependencies=[Depends(user_api_key_auth)] ) diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py index 4fb94a5462..55738923a3 100644 --- a/tests/proxy_admin_ui_tests/test_key_management.py +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -999,6 +999,121 @@ async def test_list_key_helper(prisma_client): ) +@pytest.mark.asyncio +async def test_list_key_helper_team_filtering(prisma_client): + """ + Test _list_key_helper function's team filtering behavior: + 1. Create keys with different team_ids (None, litellm-dashboard, other) + 2. Verify filtering excludes litellm-dashboard keys + 3. Verify keys with team_id=None are included + 4. Test with pagination to ensure behavior is consistent across pages + """ + from litellm.proxy.management_endpoints.key_management_endpoints import ( + _list_key_helper, + ) + + # Setup + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + + # Create test data with different team_ids + test_keys = [] + + # Create 3 keys with team_id=None + for i in range(3): + key = await generate_key_fn( + data=GenerateKeyRequest( + key_alias=f"no_team_key_{i}", + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="admin", + ), + ) + test_keys.append(key) + + # Create 2 keys with team_id=litellm-dashboard + for i in range(2): + key = await generate_key_fn( + data=GenerateKeyRequest( + team_id="litellm-dashboard", + key_alias=f"dashboard_key_{i}", + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="admin", + ), + ) + test_keys.append(key) + + # Create 2 keys with a different team_id + other_team_id = f"other_team_{uuid.uuid4()}" + for i in range(2): + key = await generate_key_fn( + data=GenerateKeyRequest( + team_id=other_team_id, + key_alias=f"other_team_key_{i}", + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="admin", + ), + ) + test_keys.append(key) + + try: + # Test 1: Get all keys with pagination (exclude litellm-dashboard) + all_keys = [] + page = 1 + while True: + result = await _list_key_helper( + prisma_client=prisma_client, + size=100, + page=page, + user_id=None, + team_id=None, + key_alias=None, + return_full_object=True, + ) + + all_keys.extend(result["keys"]) + + if page >= result["total_pages"]: + break + page += 1 + + # Verify results + print(f"Total keys found: {len(all_keys)}") + for key in all_keys: + print(f"Key team_id: {key.team_id}, alias: {key.key_alias}") + + # Verify no litellm-dashboard keys are present + dashboard_keys = [k for k in all_keys if k.team_id == "litellm-dashboard"] + assert len(dashboard_keys) == 0, "Should not include litellm-dashboard keys" + + # Verify keys with team_id=None are included + no_team_keys = [k for k in all_keys if k.team_id is None] + assert ( + len(no_team_keys) > 0 + ), f"Expected more than 0 keys with no team, got {len(no_team_keys)}" + + finally: + # Clean up test keys + for key in test_keys: + await delete_key_fn( + data=KeyRequest(keys=[key.key]), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="admin", + ), + ) + + @pytest.mark.asyncio @patch("litellm.proxy.management_endpoints.key_management_endpoints.get_team_object") async def test_key_generate_always_db_team(mock_get_team_object):