diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 4accfbc09..01bb5e090 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -1134,6 +1134,7 @@ async def regenerate_key_fn( ) @management_endpoint_wrapper async def list_keys( + request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), page: int = Query(1, description="Page number", ge=1), size: int = Query(10, description="Page size", ge=1, le=100), @@ -1152,13 +1153,24 @@ async def list_keys( logging.error("Database not connected") raise Exception("Database not connected") + # Check for unsupported parameters + supported_params = {"page", "size", "user_id", "team_id", "key_alias"} + unsupported_params = set(request.query_params.keys()) - supported_params + if unsupported_params: + raise ProxyException( + message=f"Unsupported parameter(s): {', '.join(unsupported_params)}. Supported parameters: {', '.join(supported_params)}", + type=ProxyErrorTypes.bad_request_error, + param=", ".join(unsupported_params), + code=status.HTTP_400_BAD_REQUEST, + ) + # Prepare filter conditions where = {} - if user_id: + if user_id and isinstance(user_id, str): where["user_id"] = user_id - if team_id: + if team_id and isinstance(team_id, str): where["team_id"] = team_id - if key_alias: + if key_alias and isinstance(key_alias, str): where["key_alias"] = key_alias logging.debug(f"Filter conditions: {where}") @@ -1206,9 +1218,18 @@ async def list_keys( return response except Exception as e: + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"error({str(e)})"), + type=ProxyErrorTypes.internal_server_error, + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), + ) + elif isinstance(e, ProxyException): + raise e raise ProxyException( - message=f"Error listing keys: {str(e)}", - type=ProxyErrorTypes.internal_server_error, # Use the enum value - param=None, + message="Authentication Error, " + str(e), + type=ProxyErrorTypes.internal_server_error, + param=getattr(e, "param", "None"), code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 06b087ee7..0a9264c9e 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -56,6 +56,7 @@ from litellm.proxy.management_endpoints.key_management_endpoints import ( generate_key_fn, generate_key_helper_fn, info_key_fn, + list_keys, regenerate_key_fn, update_key_fn, ) @@ -3123,3 +3124,100 @@ async def test_admin_only_routes(prisma_client): pass setattr(proxy_server, "general_settings", initial_general_settings) + + +@pytest.mark.asyncio +async def test_list_keys(prisma_client): + """ + Test the list_keys function: + - Test basic key + - Test pagination + - Test filtering by user_id, and key_alias + """ + from fastapi import Query + + from litellm.proxy.proxy_server import hash_token + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + + # Test basic listing + request = Request(scope={"type": "http", "query_string": b""}) + response = await list_keys( + request, + UserAPIKeyAuth(), + page=1, + size=10, + ) + print("response=", response) + assert "keys" in response + assert len(response["keys"]) > 0 + assert "total_count" in response + assert "current_page" in response + assert "total_pages" in response + + # Test pagination + response = await list_keys(request, UserAPIKeyAuth(), page=1, size=2) + print("pagination response=", response) + assert len(response["keys"]) == 2 + assert response["current_page"] == 1 + + # Test filtering by user_id + + unique_id = str(uuid.uuid4()) + team_id = f"key-list-team-{unique_id}" + key_alias = f"key-list-alias-{unique_id}" + user_id = f"key-list-user-{unique_id}" + response = await new_user( + data=NewUserRequest( + user_id=f"key-list-user-{unique_id}", + user_role=LitellmUserRoles.INTERNAL_USER, + key_alias=f"key-list-alias-{unique_id}", + ), + user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), + ) + + _key = hash_token(response.key) + + await asyncio.sleep(2) + + # Test filtering by user_id + response = await list_keys( + request, UserAPIKeyAuth(), user_id=user_id, page=1, size=10 + ) + print("filtered user_id response=", response) + assert len(response["keys"]) == 1 + assert _key in response["keys"] + + # Test filtering by key_alias + response = await list_keys( + request, UserAPIKeyAuth(), key_alias=key_alias, page=1, size=10 + ) + assert len(response["keys"]) == 1 + assert _key in response["keys"] + + +@pytest.mark.asyncio +async def test_key_list_unsupported_params(prisma_client): + """ + Test the list_keys function: + - Test unsupported params + """ + + from litellm.proxy.proxy_server import hash_token + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + + request = Request(scope={"type": "http", "query_string": b"alias=foo"}) + + try: + await list_keys(request, UserAPIKeyAuth(), page=1, size=10) + pytest.fail("Expected this call to fail") + except Exception as e: + print("error str=", str(e.message)) + error_str = str(e.message) + assert "Unsupported parameter" in error_str + pass