diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index c449a21b02..6b2569eb3c 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -787,6 +787,7 @@ class NewUserResponse(GenerateKeyResponse): ] = None teams: Optional[list] = None user_alias: Optional[str] = None + model_max_budget: Optional[dict] = None class UpdateUserRequest(GenerateRequestBase): diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index d9f64ea752..8c8086cf81 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -142,7 +142,6 @@ async def new_user( data_json = data.json() # type: ignore data_json = _update_internal_new_user_params(data_json, data) response = await generate_key_helper_fn(request_type="user", **data_json) - # Admin UI Logic # Add User to Team and Organization # if team_id passed add this user to the team @@ -220,6 +219,7 @@ async def new_user( tpm_limit=response.get("tpm_limit", None), rpm_limit=response.get("rpm_limit", None), budget_duration=response.get("budget_duration", None), + model_max_budget=response.get("model_max_budget", None), ) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index b9b462a4e8..e1efa23df6 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -65,7 +65,7 @@ def _get_user_in_team( return None -def _is_allowed_to_create_key( +def _is_allowed_to_make_key_request( user_api_key_dict: UserAPIKeyAuth, user_id: Optional[str], team_id: Optional[str] ) -> bool: """ @@ -266,6 +266,40 @@ def key_generation_check( ) +def common_key_access_checks( + user_api_key_dict: UserAPIKeyAuth, + data: Union[GenerateKeyRequest, UpdateKeyRequest], + llm_router: Optional[Router], + premium_user: bool, +) -> Literal[True]: + """ + Check if user is allowed to make a key request, for this key + """ + try: + _is_allowed_to_make_key_request( + user_api_key_dict=user_api_key_dict, + user_id=data.user_id, + team_id=data.team_id, + ) + except AssertionError as e: + raise HTTPException( + status_code=403, + detail=str(e), + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=str(e), + ) + + _check_model_access_group( + models=data.models, + llm_router=llm_router, + premium_user=premium_user, + ) + return True + + router = APIRouter() @@ -381,25 +415,9 @@ async def generate_key_fn( # noqa: PLR0915 data=data, ) - try: - _is_allowed_to_create_key( - user_api_key_dict=user_api_key_dict, - user_id=data.user_id, - team_id=data.team_id, - ) - except AssertionError as e: - raise HTTPException( - status_code=403, - detail=str(e), - ) - except Exception as e: - raise HTTPException( - status_code=500, - detail=str(e), - ) - - _check_model_access_group( - models=data.models, + common_key_access_checks( + user_api_key_dict=user_api_key_dict, + data=data, llm_router=llm_router, premium_user=premium_user, ) @@ -684,6 +702,8 @@ async def update_key_fn( ``` """ from litellm.proxy.proxy_server import ( + llm_router, + premium_user, prisma_client, proxy_logging_obj, user_api_key_cache, @@ -692,10 +712,18 @@ async def update_key_fn( try: data_json: dict = data.model_dump(exclude_unset=True, exclude_none=True) key = data_json.pop("key") + # get the row from db if prisma_client is None: raise Exception("Not connected to DB!") + common_key_access_checks( + user_api_key_dict=user_api_key_dict, + data=data, + llm_router=llm_router, + premium_user=premium_user, + ) + existing_key_row = await prisma_client.get_data( token=data.key, table_name="key", query_type="find_unique" ) @@ -1412,6 +1440,13 @@ async def delete_verification_tokens( ): await prisma_client.delete_data(tokens=[key.token]) deleted_tokens.append(key.token) + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "error": "You are not authorized to delete this key" + }, + ) tasks.append(_delete_key(key)) await asyncio.gather(*tasks) diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py index 12c50ac0cc..ae80b05b70 100644 --- a/tests/proxy_admin_ui_tests/test_key_management.py +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -842,7 +842,15 @@ async def test_key_update_with_model_specific_params(prisma_client): "litellm_budget_table": None, "token": token_hash, } - await update_key_fn(request=request, data=UpdateKeyRequest(**args)) + await update_key_fn( + request=request, + data=UpdateKeyRequest(**args), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) @pytest.mark.asyncio diff --git a/tests/proxy_unit_tests/test_key_generate_prisma.py b/tests/proxy_unit_tests/test_key_generate_prisma.py index 30c2be9911..538cd2aeee 100644 --- a/tests/proxy_unit_tests/test_key_generate_prisma.py +++ b/tests/proxy_unit_tests/test_key_generate_prisma.py @@ -1314,6 +1314,11 @@ def test_generate_and_update_key(prisma_client): budget_duration="1mo", max_budget=100, ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), ) print("response1=", response1) @@ -1322,6 +1327,11 @@ def test_generate_and_update_key(prisma_client): response2 = await update_key_fn( request=Request, data=UpdateKeyRequest(key=generated_key, team_id=_team_2), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), ) print("response2=", response2) @@ -2956,7 +2966,11 @@ async def test_generate_key_with_model_tpm_limit(prisma_client): _request = Request(scope={"type": "http"}) _request._url = URL(url="/update/key") - await update_key_fn(data=request, request=_request) + await update_key_fn( + data=request, + request=_request, + user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), + ) result = await info_key_fn( key=generated_key, user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), @@ -3017,7 +3031,11 @@ async def test_generate_key_with_guardrails(prisma_client): _request = Request(scope={"type": "http"}) _request._url = URL(url="/update/key") - await update_key_fn(data=request, request=_request) + await update_key_fn( + data=request, + request=_request, + user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), + ) result = await info_key_fn( key=generated_key, user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), @@ -3710,6 +3728,11 @@ async def test_key_alias_uniqueness(prisma_client): await update_key_fn( data=UpdateKeyRequest(key=key3.key, key_alias=unique_alias), request=Request(scope={"type": "http"}), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), ) pytest.fail("Should not be able to update a key to use an existing alias") except Exception as e: @@ -3719,6 +3742,11 @@ async def test_key_alias_uniqueness(prisma_client): updated_key = await update_key_fn( data=UpdateKeyRequest(key=key1.key, key_alias=unique_alias), request=Request(scope={"type": "http"}), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), ) assert updated_key is not None diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index 3f0b127af4..36f9b6652f 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -1216,14 +1216,14 @@ def test_litellm_verification_token_view_response_with_budget_table( ) -def test_is_allowed_to_create_key(): +def test_is_allowed_to_make_key_request(): from litellm.proxy._types import LitellmUserRoles from litellm.proxy.management_endpoints.key_management_endpoints import ( - _is_allowed_to_create_key, + _is_allowed_to_make_key_request, ) assert ( - _is_allowed_to_create_key( + _is_allowed_to_make_key_request( user_api_key_dict=UserAPIKeyAuth( user_id="test_user_id", user_role=LitellmUserRoles.PROXY_ADMIN ), @@ -1234,7 +1234,7 @@ def test_is_allowed_to_create_key(): ) assert ( - _is_allowed_to_create_key( + _is_allowed_to_make_key_request( user_api_key_dict=UserAPIKeyAuth( user_id="test_user_id", user_role=LitellmUserRoles.INTERNAL_USER, @@ -1553,6 +1553,7 @@ async def test_spend_logs_cleanup_after_error(): mock_client.spend_log_transactions == original_logs[100:] ), "Should remove processed logs even after error" + def test_provider_specific_header(): from litellm.proxy.litellm_pre_call_utils import ( add_provider_specific_headers_to_request, diff --git a/tests/test_users.py b/tests/test_users.py index 812783681c..f2923d2c8d 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -315,3 +315,142 @@ async def test_user_model_access(): key=key, model="groq/claude-3-5-haiku-20241022", ) + + +import json +import uuid +import pytest +import aiohttp +from typing import Dict, Tuple + + +async def setup_test_users(session: aiohttp.ClientSession) -> Tuple[Dict, Dict]: + """ + Create two test users and an additional key for the first user. + Returns tuple of (user1_data, user2_data) where each contains user info and keys. + """ + # Create two test users + user1 = await new_user( + session=session, + i=0, + budget=100, + budget_duration="30d", + models=["anthropic.claude-3-5-sonnet-20240620-v1:0"], + ) + + user2 = await new_user( + session=session, + i=1, + budget=100, + budget_duration="30d", + models=["anthropic.claude-3-5-sonnet-20240620-v1:0"], + ) + + print("\nCreated two test users:") + print(f"User 1 ID: {user1['user_id']}") + print(f"User 2 ID: {user2['user_id']}") + + # Create an additional key for user1 + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {user1['key']}", + } + + key_payload = { + "user_id": user1["user_id"], + "duration": "7d", + "key_alias": f"test_key_{uuid.uuid4()}", + "models": ["anthropic.claude-3-5-sonnet-20240620-v1:0"], + } + + print("\nGenerating additional key for user1...") + key_response = await session.post( + f"http://0.0.0.0:4000/key/generate", headers=headers, json=key_payload + ) + + assert key_response.status == 200, "Failed to generate additional key for user1" + user1_additional_key = await key_response.json() + + print(f"\nGenerated key details:") + print(json.dumps(user1_additional_key, indent=2)) + + # Return both users' data including the additional key + return { + "user_data": user1, + "additional_key": user1_additional_key, + "headers": headers, + }, { + "user_data": user2, + "headers": { + "Content-Type": "application/json", + "Authorization": f"Bearer {user2['key']}", + }, + } + + +async def print_response_details(response: aiohttp.ClientResponse) -> None: + """Helper function to print response details""" + print("\nResponse Details:") + print(f"Status Code: {response.status}") + print("\nResponse Content:") + try: + formatted_json = json.dumps(await response.json(), indent=2) + print(formatted_json) + except json.JSONDecodeError: + print(await response.text()) + + +@pytest.mark.asyncio +async def test_key_update_user_isolation(): + """Test that a user cannot update a key that belongs to another user""" + async with aiohttp.ClientSession() as session: + user1_data, user2_data = await setup_test_users(session) + + # Try to update the key to belong to user2 + update_payload = { + "key": user1_data["additional_key"]["key"], + "user_id": user2_data["user_data"][ + "user_id" + ], # Attempting to change ownership + "metadata": {"purpose": "testing_user_isolation", "environment": "test"}, + } + + print("\nAttempting to update key ownership to user2...") + update_response = await session.post( + f"http://0.0.0.0:4000/key/update", + headers=user1_data["headers"], # Using user1's headers + json=update_payload, + ) + + await print_response_details(update_response) + + # Verify update attempt was rejected + assert ( + update_response.status == 403 + ), "Request should have been rejected with 403 status code" + + +@pytest.mark.asyncio +async def test_key_delete_user_isolation(): + """Test that a user cannot delete a key that belongs to another user""" + async with aiohttp.ClientSession() as session: + user1_data, user2_data = await setup_test_users(session) + + # Try to delete user1's additional key using user2's credentials + delete_payload = { + "keys": [user1_data["additional_key"]["key"]], + } + + print("\nAttempting to delete user1's key using user2's credentials...") + delete_response = await session.post( + f"http://0.0.0.0:4000/key/delete", + headers=user2_data["headers"], + json=delete_payload, + ) + + await print_response_details(delete_response) + + # Verify delete attempt was rejected + assert ( + delete_response.status == 403 + ), "Request should have been rejected with 403 status code"