From 6d03d71773f96b29228c6784f5c19d9943f05ccc Mon Sep 17 00:00:00 2001 From: NisanthChsr Date: Thu, 13 Mar 2025 00:20:53 -0400 Subject: [PATCH 1/4] fix: security checks for team and user ids --- .../key_management_endpoints.py | 35 +++++++++++-- .../test_team_models.py | 51 +++++++++++++++++++ 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 9141d9d14a..986e4eb8ed 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -626,7 +626,9 @@ def prepare_metadata_fields( def prepare_key_update_data( - data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row + data: Union[UpdateKeyRequest, RegenerateKeyRequest], + existing_key_row, + user_api_key_dict: Optional[UserAPIKeyAuth] = None ): data_json: dict = data.model_dump(exclude_unset=True) data_json.pop("key", None) @@ -636,6 +638,31 @@ def prepare_key_update_data( continue non_default_values[k] = v + # Check if user is trying to modify user_id or team_id + if user_api_key_dict is not None and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value: + # For internal users, prevent changing user_id to null or to another user's id + if "user_id" in non_default_values: + new_user_id = non_default_values["user_id"] + if new_user_id is None or (user_api_key_dict.user_id is not None and new_user_id != user_api_key_dict.user_id): + raise HTTPException( + status_code=403, + detail={"error": "You are not authorized to change the user_id of this key"} + ) + + # For internal users, prevent changing team_id to null or to a team they don't belong to + if "team_id" in non_default_values: + new_team_id = non_default_values["team_id"] + user_teams = user_api_key_dict.teams or [] + if new_team_id is None or (len(user_teams) > 0 and new_team_id not in user_teams): + raise HTTPException( + status_code=403, + detail={"error": "You are not authorized to change the team_id to a team you don't belong to"} + ) + + # Preserve user_id if not explicitly provided + if "user_id" not in non_default_values and existing_key_row.user_id: + non_default_values["user_id"] = existing_key_row.user_id + if "duration" in non_default_values: duration = non_default_values.pop("duration") if duration and (isinstance(duration, str)) and len(duration) > 0: @@ -764,7 +791,9 @@ async def update_key_fn( ) non_default_values = prepare_key_update_data( - data=data, existing_key_row=existing_key_row + data=data, + existing_key_row=existing_key_row, + user_api_key_dict=user_api_key_dict ) await _enforce_unique_key_alias( @@ -1761,7 +1790,7 @@ async def regenerate_key_fn( if data is not None: # Update with any provided parameters from GenerateKeyRequest non_default_values = prepare_key_update_data( - data=data, existing_key_row=_key_in_db + data=data, existing_key_row=_key_in_db, user_api_key_dict=user_api_key_dict ) verbose_proxy_logger.debug("non_default_values: %s", non_default_values) diff --git a/tests/store_model_in_db_tests/test_team_models.py b/tests/store_model_in_db_tests/test_team_models.py index 0faa01c8ee..e27b971aa9 100644 --- a/tests/store_model_in_db_tests/test_team_models.py +++ b/tests/store_model_in_db_tests/test_team_models.py @@ -310,3 +310,54 @@ async def test_team_model_visibility_in_model_info_endpoint(): # Cleanup model_id = model_response.json()["model_info"]["id"] await client.post("/model/delete", json={"id": model_id}, headers=headers) + +@pytest.mark.asyncio +async def test_key_management_with_team_models(): + """ + Test key management with team models: + 1. Create a team (mock) + 2. Verify team model permissions + 3. Test key generation with team model access + 4. Verify key info shows correct permissions + """ + client = AsyncClient(base_url=PROXY_BASE_URL) + headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"} + + # Use a fixed team ID and model name for testing + team_id = "test_team_123" + model_name = "gpt-4-team-test" + + # Mock team creation response + team_response = await client.post( + "/team/new", + json={ + "team_id": team_id, + "models": [model_name] + }, + headers=headers, + ) + assert team_response.status_code == 200 + + # Generate key for team with model access + key_response = await client.post( + "/key/generate", + json={ + "team_id": team_id, + "models": [model_name] + }, + headers=headers, + ) + assert key_response.status_code == 200 + team_key = key_response.json()["key"] + + # Verify key info shows correct permissions + key_info_response = await client.get( + "/key/info", + headers={"Authorization": f"Bearer {team_key}"}, + ) + assert key_info_response.status_code == 200 + key_info = key_info_response.json() + + # Verify key has correct team and model permissions + assert team_id == key_info.get("team_id"), "Key should be associated with correct team" + assert model_name in key_info.get("models", []), "Key should have access to team model" From ef6f0bfa1db7453e0be461193edd01fae709b70f Mon Sep 17 00:00:00 2001 From: NisanthChsr Date: Thu, 13 Mar 2025 20:20:36 -0400 Subject: [PATCH 2/4] fix: ensure user_id is not reset during key updation calls --- .../key_management_endpoints.py | 34 +----- .../proxy/test_key_management_endpoints.py | 106 ++++++++++++++++++ .../test_team_models.py | 51 --------- 3 files changed, 112 insertions(+), 79 deletions(-) create mode 100644 tests/litellm/proxy/test_key_management_endpoints.py diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 986e4eb8ed..3ea2903957 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -627,8 +627,7 @@ def prepare_metadata_fields( def prepare_key_update_data( data: Union[UpdateKeyRequest, RegenerateKeyRequest], - existing_key_row, - user_api_key_dict: Optional[UserAPIKeyAuth] = None + existing_key_row ): data_json: dict = data.model_dump(exclude_unset=True) data_json.pop("key", None) @@ -638,30 +637,10 @@ def prepare_key_update_data( continue non_default_values[k] = v - # Check if user is trying to modify user_id or team_id - if user_api_key_dict is not None and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value: - # For internal users, prevent changing user_id to null or to another user's id - if "user_id" in non_default_values: - new_user_id = non_default_values["user_id"] - if new_user_id is None or (user_api_key_dict.user_id is not None and new_user_id != user_api_key_dict.user_id): - raise HTTPException( - status_code=403, - detail={"error": "You are not authorized to change the user_id of this key"} - ) - - # For internal users, prevent changing team_id to null or to a team they don't belong to - if "team_id" in non_default_values: - new_team_id = non_default_values["team_id"] - user_teams = user_api_key_dict.teams or [] - if new_team_id is None or (len(user_teams) > 0 and new_team_id not in user_teams): - raise HTTPException( - status_code=403, - detail={"error": "You are not authorized to change the team_id to a team you don't belong to"} - ) - - # Preserve user_id if not explicitly provided - if "user_id" not in non_default_values and existing_key_row.user_id: - non_default_values["user_id"] = existing_key_row.user_id + # Ensure user_id is preserved from existing key and not set to null + if existing_key_row.user_id: + if "user_id" not in non_default_values or non_default_values["user_id"] is None: + non_default_values["user_id"] = existing_key_row.user_id if "duration" in non_default_values: duration = non_default_values.pop("duration") @@ -793,7 +772,6 @@ async def update_key_fn( non_default_values = prepare_key_update_data( data=data, existing_key_row=existing_key_row, - user_api_key_dict=user_api_key_dict ) await _enforce_unique_key_alias( @@ -1790,7 +1768,7 @@ async def regenerate_key_fn( if data is not None: # Update with any provided parameters from GenerateKeyRequest non_default_values = prepare_key_update_data( - data=data, existing_key_row=_key_in_db, user_api_key_dict=user_api_key_dict + data=data, existing_key_row=_key_in_db ) verbose_proxy_logger.debug("non_default_values: %s", non_default_values) diff --git a/tests/litellm/proxy/test_key_management_endpoints.py b/tests/litellm/proxy/test_key_management_endpoints.py new file mode 100644 index 0000000000..97a9a646de --- /dev/null +++ b/tests/litellm/proxy/test_key_management_endpoints.py @@ -0,0 +1,106 @@ +import json +import os +import sys +import pytest +from fastapi.testclient import TestClient +from litellm.proxy._types import LiteLLM_VerificationToken, LitellmUserRoles, UserAPIKeyAuth +from litellm.proxy.proxy_server import app + +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path + +class MockPrismaClient: + def __init__(self): + self.db = self + self.litellm_verificationtoken = self + + async def find_unique(self, where): + return LiteLLM_VerificationToken( + token="sk-existing", + user_id="user-123", + team_id=None, + key_name="test-key", + key_alias="test-alias" + ) + + async def find_first(self, where): + # Used by _enforce_unique_key_alias to check for duplicate key aliases + return None + + async def update(self, where, data): + self.last_update_data = data + return LiteLLM_VerificationToken( + token="sk-existing", + user_id=data.get("user_id", "user-123"), + team_id=None, + key_name="test-key", + key_alias=data.get("key_alias", "test-alias") + ) + + async def get_data(self, token, table_name, query_type="find_unique"): + return await self.find_unique({"token": token}) + + async def update_data(self, token, data): + updated_token = await self.update({"token": token}, data) + # Return in the format expected by the update_key_fn + return { + "data": { + "token": updated_token.token, + "user_id": updated_token.user_id, + "team_id": updated_token.team_id, + "key_name": updated_token.key_name, + "key_alias": updated_token.key_alias + } + } + +@pytest.fixture +def test_client(): + return TestClient(app) + +@pytest.fixture +def mock_prisma(): + return MockPrismaClient() + +@pytest.fixture(autouse=True) +def mock_user_auth(mocker): + return mocker.patch( + "litellm.proxy.auth.user_api_key_auth", + return_value=UserAPIKeyAuth( + api_key="sk-auth", + user_id="user-123", + team_id=None, + user_role=LitellmUserRoles.PROXY_ADMIN.value # Use the correct enum value + ) + ) + +def test_user_id_not_reset_on_key_update(test_client, mock_prisma, mocker): + mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma) + + response = test_client.post( + "/key/update", + headers={"Authorization": "Bearer sk-auth"}, + json={ + "key": "sk-existing", + "key_alias": "new-alias" + } + ) + + assert response.status_code == 200 + assert mock_prisma.last_update_data["user_id"] == "user-123" + +def test_user_id_explicit_none_prevented(test_client, mock_prisma, mocker): + mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma) + + response = test_client.post( + "/key/update", + headers={"Authorization": "Bearer sk-auth"}, + json={ + "key": "sk-existing", + "key_alias": "new-alias", + "user_id": None + } + ) + + assert response.status_code == 200 + assert mock_prisma.last_update_data["user_id"] == "user-123" diff --git a/tests/store_model_in_db_tests/test_team_models.py b/tests/store_model_in_db_tests/test_team_models.py index e27b971aa9..0faa01c8ee 100644 --- a/tests/store_model_in_db_tests/test_team_models.py +++ b/tests/store_model_in_db_tests/test_team_models.py @@ -310,54 +310,3 @@ async def test_team_model_visibility_in_model_info_endpoint(): # Cleanup model_id = model_response.json()["model_info"]["id"] await client.post("/model/delete", json={"id": model_id}, headers=headers) - -@pytest.mark.asyncio -async def test_key_management_with_team_models(): - """ - Test key management with team models: - 1. Create a team (mock) - 2. Verify team model permissions - 3. Test key generation with team model access - 4. Verify key info shows correct permissions - """ - client = AsyncClient(base_url=PROXY_BASE_URL) - headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"} - - # Use a fixed team ID and model name for testing - team_id = "test_team_123" - model_name = "gpt-4-team-test" - - # Mock team creation response - team_response = await client.post( - "/team/new", - json={ - "team_id": team_id, - "models": [model_name] - }, - headers=headers, - ) - assert team_response.status_code == 200 - - # Generate key for team with model access - key_response = await client.post( - "/key/generate", - json={ - "team_id": team_id, - "models": [model_name] - }, - headers=headers, - ) - assert key_response.status_code == 200 - team_key = key_response.json()["key"] - - # Verify key info shows correct permissions - key_info_response = await client.get( - "/key/info", - headers={"Authorization": f"Bearer {team_key}"}, - ) - assert key_info_response.status_code == 200 - key_info = key_info_response.json() - - # Verify key has correct team and model permissions - assert team_id == key_info.get("team_id"), "Key should be associated with correct team" - assert model_name in key_info.get("models", []), "Key should have access to team model" From cd88a8c80dfa6154b9c264eb0a2251e2d7b59ee1 Mon Sep 17 00:00:00 2001 From: NisanthChsr Date: Thu, 13 Mar 2025 20:24:03 -0400 Subject: [PATCH 3/4] fix: nits --- .../proxy/management_endpoints/key_management_endpoints.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 3ea2903957..82ca7622b7 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -626,8 +626,7 @@ def prepare_metadata_fields( def prepare_key_update_data( - data: Union[UpdateKeyRequest, RegenerateKeyRequest], - existing_key_row + data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row ): data_json: dict = data.model_dump(exclude_unset=True) data_json.pop("key", None) @@ -770,8 +769,7 @@ async def update_key_fn( ) non_default_values = prepare_key_update_data( - data=data, - existing_key_row=existing_key_row, + data=data, existing_key_row=existing_key_row ) await _enforce_unique_key_alias( From d47bc1d11896519d9efce956f12926a41c0d3525 Mon Sep 17 00:00:00 2001 From: NisanthChsr Date: Sat, 15 Mar 2025 17:19:56 -0400 Subject: [PATCH 4/4] fix: admin can unset the user_id of key --- .../key_management_endpoints.py | 18 +++++--- .../proxy/test_key_management_endpoints.py | 45 ++----------------- .../src/components/key_info_view.tsx | 5 +++ 3 files changed, 22 insertions(+), 46 deletions(-) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 82ca7622b7..3afdc78c7e 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -636,11 +636,6 @@ def prepare_key_update_data( continue non_default_values[k] = v - # Ensure user_id is preserved from existing key and not set to null - if existing_key_row.user_id: - if "user_id" not in non_default_values or non_default_values["user_id"] is None: - non_default_values["user_id"] = existing_key_row.user_id - if "duration" in non_default_values: duration = non_default_values.pop("duration") if duration and (isinstance(duration, str)) and len(duration) > 0: @@ -772,6 +767,19 @@ async def update_key_fn( data=data, existing_key_row=existing_key_row ) + is_admin = ( + user_api_key_dict.user_role is not None + and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value + ) + + if not is_admin: + # Ensure user_id is preserved when not specified by admin + if non_default_values.get("user_id", None) is None: + non_default_values["user_id"] = existing_key_row.user_id + elif "user_id" not in non_default_values: + # preserve user_id from existing key only when admin does not specify to unset it + non_default_values["user_id"] = existing_key_row.user_id + await _enforce_unique_key_alias( key_alias=non_default_values.get("key_alias", None), prisma_client=prisma_client, diff --git a/tests/litellm/proxy/test_key_management_endpoints.py b/tests/litellm/proxy/test_key_management_endpoints.py index 97a9a646de..f89d842c76 100644 --- a/tests/litellm/proxy/test_key_management_endpoints.py +++ b/tests/litellm/proxy/test_key_management_endpoints.py @@ -20,39 +20,18 @@ class MockPrismaClient: token="sk-existing", user_id="user-123", team_id=None, - key_name="test-key", - key_alias="test-alias" + key_name="test-key" ) async def find_first(self, where): - # Used by _enforce_unique_key_alias to check for duplicate key aliases return None - async def update(self, where, data): - self.last_update_data = data - return LiteLLM_VerificationToken( - token="sk-existing", - user_id=data.get("user_id", "user-123"), - team_id=None, - key_name="test-key", - key_alias=data.get("key_alias", "test-alias") - ) - async def get_data(self, token, table_name, query_type="find_unique"): return await self.find_unique({"token": token}) async def update_data(self, token, data): - updated_token = await self.update({"token": token}, data) - # Return in the format expected by the update_key_fn - return { - "data": { - "token": updated_token.token, - "user_id": updated_token.user_id, - "team_id": updated_token.team_id, - "key_name": updated_token.key_name, - "key_alias": updated_token.key_alias - } - } + self.last_update_data = data # Store the update data for test verification + return {"data": data} @pytest.fixture def test_client(): @@ -70,7 +49,7 @@ def mock_user_auth(mocker): api_key="sk-auth", user_id="user-123", team_id=None, - user_role=LitellmUserRoles.PROXY_ADMIN.value # Use the correct enum value + user_role=LitellmUserRoles.INTERNAL_USER.value ) ) @@ -88,19 +67,3 @@ def test_user_id_not_reset_on_key_update(test_client, mock_prisma, mocker): assert response.status_code == 200 assert mock_prisma.last_update_data["user_id"] == "user-123" - -def test_user_id_explicit_none_prevented(test_client, mock_prisma, mocker): - mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma) - - response = test_client.post( - "/key/update", - headers={"Authorization": "Bearer sk-auth"}, - json={ - "key": "sk-existing", - "key_alias": "new-alias", - "user_id": None - } - ) - - assert response.status_code == 200 - assert mock_prisma.last_update_data["user_id"] == "user-123" diff --git a/ui/litellm-dashboard/src/components/key_info_view.tsx b/ui/litellm-dashboard/src/components/key_info_view.tsx index 9d50be6cf7..387840a2a5 100644 --- a/ui/litellm-dashboard/src/components/key_info_view.tsx +++ b/ui/litellm-dashboard/src/components/key_info_view.tsx @@ -62,6 +62,11 @@ export default function KeyInfoView({ keyId, onClose, keyData, accessToken, user const currentKey = formValues.token; formValues.key = currentKey; + // Explicitly set user_id to null if not present + // if (!('user_id' in formValues)) { + // formValues.user_id = null; + // } + // Convert metadata back to an object if it exists and is a string if (formValues.metadata && typeof formValues.metadata === "string") { try {