diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index fff4c16b2d..62c74960eb 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -1676,15 +1676,41 @@ async def regenerate_key_fn( raise handle_exception_on_proxy(e) -def validate_key_list_check( - complete_user_info: LiteLLM_UserTable, +async def validate_key_list_check( + user_api_key_dict: UserAPIKeyAuth, user_id: Optional[str], team_id: Optional[str], organization_id: Optional[str], key_alias: Optional[str], + prisma_client: PrismaClient, ): - if complete_user_info.user_role == LitellmUserRoles.PROXY_ADMIN.value: - return # proxy admin can see all keys + + if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value: + return + + if user_api_key_dict.user_id is None: + raise ProxyException( + message="You are not authorized to access this endpoint. No 'user_id' is associated with your API key.", + type=ProxyErrorTypes.bad_request_error, + param="user_id", + code=status.HTTP_403_FORBIDDEN, + ) + complete_user_info_db_obj: Optional[BaseModel] = ( + await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": user_api_key_dict.user_id}, + include={"organization_memberships": True}, + ) + ) + + if complete_user_info_db_obj is None: + raise ProxyException( + message="You are not authorized to access this endpoint. No 'user_id' is associated with your API key.", + type=ProxyErrorTypes.bad_request_error, + param="user_id", + code=status.HTTP_403_FORBIDDEN, + ) + + complete_user_info = LiteLLM_UserTable(**complete_user_info_db_obj.model_dump()) # internal user can only see their own keys if user_id: @@ -1779,42 +1805,16 @@ async def list_keys( verbose_proxy_logger.error("Database not connected") raise Exception("Database not connected") - if not user_api_key_dict.user_id: - raise ProxyException( - message="You are not authorized to access this endpoint. No 'user_id' is associated with your API key.", - type=ProxyErrorTypes.bad_request_error, - param="user_id", - code=status.HTTP_403_FORBIDDEN, - ) - - complete_user_info: Optional[BaseModel] = ( - await prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_api_key_dict.user_id}, - include={"organization_memberships": True}, - ) - ) - - if complete_user_info is None: - raise ProxyException( - message="You are not authorized to access this endpoint. No 'user_id' is associated with your API key.", - type=ProxyErrorTypes.bad_request_error, - param="user_id", - code=status.HTTP_403_FORBIDDEN, - ) - - complete_user_info_pydantic_obj = LiteLLM_UserTable( - **complete_user_info.model_dump() - ) - - validate_key_list_check( - complete_user_info=complete_user_info_pydantic_obj, + await validate_key_list_check( + user_api_key_dict=user_api_key_dict, user_id=user_id, team_id=team_id, organization_id=organization_id, key_alias=key_alias, + prisma_client=prisma_client, ) - if user_id is None and complete_user_info_pydantic_obj.user_role != [ + if user_id is None and user_api_key_dict.user_role not in [ LitellmUserRoles.PROXY_ADMIN.value, LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value, ]: diff --git a/tests/proxy_unit_tests/test_key_generate_prisma.py b/tests/proxy_unit_tests/test_key_generate_prisma.py index ecd14afed7..9a8fcdeb15 100644 --- a/tests/proxy_unit_tests/test_key_generate_prisma.py +++ b/tests/proxy_unit_tests/test_key_generate_prisma.py @@ -3359,6 +3359,7 @@ async def test_list_keys(prisma_client): from fastapi import Query from litellm.proxy.proxy_server import hash_token + from litellm.proxy._types import LitellmUserRoles setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") @@ -3368,7 +3369,9 @@ async def test_list_keys(prisma_client): request = Request(scope={"type": "http", "query_string": b""}) response = await list_keys( request, - UserAPIKeyAuth(), + UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN.value, + ), page=1, size=10, ) @@ -3380,7 +3383,12 @@ async def test_list_keys(prisma_client): assert "total_pages" in response # Test pagination - response = await list_keys(request, UserAPIKeyAuth(), page=1, size=2) + response = await list_keys( + request, + UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value), + page=1, + size=2, + ) print("pagination response=", response) assert len(response["keys"]) == 2 assert response["current_page"] == 1 @@ -3406,7 +3414,11 @@ async def test_list_keys(prisma_client): # Test filtering by user_id response = await list_keys( - request, UserAPIKeyAuth(), user_id=user_id, page=1, size=10 + request, + UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value), + user_id=user_id, + page=1, + size=10, ) print("filtered user_id response=", response) assert len(response["keys"]) == 1 @@ -3414,7 +3426,11 @@ async def test_list_keys(prisma_client): # Test filtering by key_alias response = await list_keys( - request, UserAPIKeyAuth(), key_alias=key_alias, page=1, size=10 + request, + UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value), + key_alias=key_alias, + page=1, + size=10, ) assert len(response["keys"]) == 1 assert _key in response["keys"] @@ -3436,7 +3452,12 @@ async def test_key_list_unsupported_params(prisma_client): request = Request(scope={"type": "http", "query_string": b"alias=foo"}) try: - await list_keys(request, UserAPIKeyAuth(), page=1, size=10) + await list_keys( + request, + UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value), + page=1, + size=10, + ) pytest.fail("Expected this call to fail") except Exception as e: print("error str=", str(e.message))