diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 15edab8909..068e867f3c 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -790,6 +790,19 @@ async def update_key_fn( data=data, existing_key_row=existing_key_row ) + is_admin = ( + user_api_key_dict.user_role is not None + and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value + ) + + if not is_admin: + # Ensure user_id is preserved when not specified by admin + if non_default_values.get("user_id", None) is None: + non_default_values["user_id"] = existing_key_row.user_id + elif "user_id" not in non_default_values: + # preserve user_id from existing key only when admin does not specify to unset it + non_default_values["user_id"] = existing_key_row.user_id + await _enforce_unique_key_alias( key_alias=non_default_values.get("key_alias", None), prisma_client=prisma_client, diff --git a/tests/litellm/proxy/test_key_management_endpoints.py b/tests/litellm/proxy/test_key_management_endpoints.py new file mode 100644 index 0000000000..f89d842c76 --- /dev/null +++ b/tests/litellm/proxy/test_key_management_endpoints.py @@ -0,0 +1,69 @@ +import json +import os +import sys +import pytest +from fastapi.testclient import TestClient +from litellm.proxy._types import LiteLLM_VerificationToken, LitellmUserRoles, UserAPIKeyAuth +from litellm.proxy.proxy_server import app + +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path + +class MockPrismaClient: + def __init__(self): + self.db = self + self.litellm_verificationtoken = self + + async def find_unique(self, where): + return LiteLLM_VerificationToken( + token="sk-existing", + user_id="user-123", + team_id=None, + key_name="test-key" + ) + + async def find_first(self, where): + return None + + async def get_data(self, token, table_name, query_type="find_unique"): + return await self.find_unique({"token": token}) + + async def update_data(self, token, data): + self.last_update_data = data # Store the update data for test verification + return {"data": data} + +@pytest.fixture +def test_client(): + return TestClient(app) + +@pytest.fixture +def mock_prisma(): + return MockPrismaClient() + +@pytest.fixture(autouse=True) +def mock_user_auth(mocker): + return mocker.patch( + "litellm.proxy.auth.user_api_key_auth", + return_value=UserAPIKeyAuth( + api_key="sk-auth", + user_id="user-123", + team_id=None, + user_role=LitellmUserRoles.INTERNAL_USER.value + ) + ) + +def test_user_id_not_reset_on_key_update(test_client, mock_prisma, mocker): + mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma) + + response = test_client.post( + "/key/update", + headers={"Authorization": "Bearer sk-auth"}, + json={ + "key": "sk-existing", + "key_alias": "new-alias" + } + ) + + assert response.status_code == 200 + assert mock_prisma.last_update_data["user_id"] == "user-123" diff --git a/ui/litellm-dashboard/src/components/key_info_view.tsx b/ui/litellm-dashboard/src/components/key_info_view.tsx index b7ebdc651a..c0b31cb942 100644 --- a/ui/litellm-dashboard/src/components/key_info_view.tsx +++ b/ui/litellm-dashboard/src/components/key_info_view.tsx @@ -64,6 +64,11 @@ export default function KeyInfoView({ keyId, onClose, keyData, accessToken, user const currentKey = formValues.token; formValues.key = currentKey; + // Explicitly set user_id to null if not present + // if (!('user_id' in formValues)) { + // formValues.user_id = null; + // } + // Convert metadata back to an object if it exists and is a string if (formValues.metadata && typeof formValues.metadata === "string") { try {