From fa396c4b0df8997b15636c5b0f2651bbcebed61a Mon Sep 17 00:00:00 2001 From: Sebastian Sosa <1sebastian1sosa1@gmail.com> Date: Thu, 13 Mar 2025 19:47:41 -0400 Subject: [PATCH 1/3] token foreign entity validation for /key/generate --- .../key_management_endpoints.py | 5 +++ litellm/proxy/management_helpers/utils.py | 43 ++++++++++++++++++- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 9141d9d14a..a615818b2f 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -392,6 +392,11 @@ async def generate_key_fn( # noqa: PLR0915 user_api_key_cache, user_custom_key_generate, ) + from litellm.proxy.management_helpers.utils import validate_entity_exists + + await validate_entity_exists(prisma_client, "user", data.user_id) + await validate_entity_exists(prisma_client, "team", data.team_id) + await validate_entity_exists(prisma_client, "budget", data.budget_id) verbose_proxy_logger.debug("entered /key/generate") diff --git a/litellm/proxy/management_helpers/utils.py b/litellm/proxy/management_helpers/utils.py index 69a5cf9141..a616690ca4 100644 --- a/litellm/proxy/management_helpers/utils.py +++ b/litellm/proxy/management_helpers/utils.py @@ -3,7 +3,7 @@ import uuid from datetime import datetime from functools import wraps -from typing import Optional, Tuple +from typing import Literal, Optional, Tuple, Any from fastapi import HTTPException, Request @@ -372,3 +372,44 @@ def management_endpoint_wrapper(func): raise e return wrapper + + +async def validate_entity_exists( + prisma_client: Any, + table_name: Literal["budget", "team", "user"], + entity_id: Optional[str], +) -> None: + """ + Validates if an entity exists in the database. + + Args: + prisma_client: Database client + entity_id: ID of the entity to check + entity_type: Human-readable name of the entity type (e.g., "User", "Team") + table_name: Database table name to check + id_field: Field name for the ID in the table (defaults to entity_type_id format) + + Raises: + Exception: If the entity doesn't exist + """ + if entity_id is None: + return + + if prisma_client is None: + raise Exception("Database client is not available") + + if table_name == "budget": + existing_entity = await prisma_client.db.litellm_budgettable.find_unique( + where={"budget_id": entity_id}, # type: ignore + ) + elif table_name == "team": + existing_entity = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": entity_id}, # type: ignore + ) + elif table_name == "user": + existing_entity = await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": entity_id}, # type: ignore + ) + + if existing_entity is None: + raise Exception(f"'{table_name}' with the id '{entity_id}' does not exist") From 9102eff0f3c3139a59e3570da39875073ce75e63 Mon Sep 17 00:00:00 2001 From: Sebastian Sosa <1sebastian1sosa1@gmail.com> Date: Fri, 14 Mar 2025 01:12:41 -0400 Subject: [PATCH 2/3] e2e key generation validation test --- .../test_key_entity_validation.py | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 tests/proxy_admin_ui_tests/test_key_entity_validation.py diff --git a/tests/proxy_admin_ui_tests/test_key_entity_validation.py b/tests/proxy_admin_ui_tests/test_key_entity_validation.py new file mode 100644 index 0000000000..0ee258420e --- /dev/null +++ b/tests/proxy_admin_ui_tests/test_key_entity_validation.py @@ -0,0 +1,108 @@ +import os +import uuid +import pytest +import asyncio +import aiohttp +import json +from typing import Dict, Optional, Tuple, List + +@pytest.mark.asyncio +async def test_key_entity_validation(): + """ + End-to-end test for validating entity existence in key generation: + 1. Create a new user with a random UUID using POST /user/new + 2. Try to create a key with an incorrect user_id (should fail) + 3. Create a key with the correct user_id (should succeed) + 4. Clean up by deleting both the key and user regardless of test results + """ + # Set up base URL and auth + base_url = "http://localhost:4000" + master_key = "sk-1234" # This should match your proxy's master key + headers = { + "Authorization": f"Bearer {master_key}", + "Content-Type": "application/json" + } + + # Variables to store created resources for cleanup + user_id = str(uuid.uuid4()) + invalid_user_id = str(uuid.uuid4()) + key_value = None + + async with aiohttp.ClientSession() as session: + try: + # Step 1: Create a new user + user_data = { + "user_id": user_id, + "user_email": f"test-{user_id[:8]}@example.com", + "max_budget": 100, + "user_role": "internal_user" + } + + async with session.post( + f"{base_url}/user/new", + headers=headers, + json=user_data + ) as response: + assert response.status == 200, f"Failed to create user: {await response.text()}" + user_response = await response.json() + print(f"Successfully created user: {user_id}") + + # Step 2: Try to create a key with an incorrect user_id (should fail) + invalid_key_data = { + "user_id": invalid_user_id, + "models": ["gpt-3.5-turbo"], + "max_budget": 50 + } + + async with session.post( + f"{base_url}/key/generate", + headers=headers, + json=invalid_key_data + ) as response: + response_text = await response.text() + print(f"Response for invalid user ID: {response_text}") + # This should fail with a 400 status code and error message about user not existing + assert response.status != 200, "Key generation with invalid user_id should fail" + assert "'user' with the id" in response_text and "does not exist" in response_text + + # Step 3: Create a key with the correct user_id (should succeed) + valid_key_data = { + "user_id": user_id, + "models": ["gpt-3.5-turbo"], + "max_budget": 50 + } + + async with session.post( + f"{base_url}/key/generate", + headers=headers, + json=valid_key_data + ) as response: + assert response.status == 200, f"Failed to create key: {await response.text()}" + key_response = await response.json() + key_value = key_response.get("key") + print(f"Successfully created key: {key_value}") + assert key_value is not None, "Response should contain a key" + assert key_value.startswith("sk-"), "Key should start with 'sk-'" + + finally: + # Step 4: Clean up - Delete key and user regardless of test results + if key_value: + async with session.delete( + f"{base_url}/key/delete", + headers=headers, + json={"keys": [key_value]} + ) as response: + if response.status == 200: + print(f"Successfully deleted key: {key_value}") + else: + print(f"Warning: Failed to delete key: {await response.text()}") + + async with session.post( + f"{base_url}/user/delete", + headers=headers, + json={"user_ids": [user_id]} + ) as response: + if response.status == 200: + print(f"Successfully deleted user: {user_id}") + else: + print(f"Warning: Failed to delete user: {await response.text()}") \ No newline at end of file From 25693de37f51594b59e5a35d611f27824ad24ec8 Mon Sep 17 00:00:00 2001 From: Sebastian Sosa <1sebastian1sosa1@gmail.com> Date: Fri, 14 Mar 2025 03:02:48 -0400 Subject: [PATCH 3/3] remove team validation due to test regression & update test to ensure user entities exist --- .../key_management_endpoints.py | 2 +- .../test_key_management.py | 321 +++++++++++------- 2 files changed, 198 insertions(+), 125 deletions(-) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index a615818b2f..deec23ae7d 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -395,7 +395,7 @@ async def generate_key_fn( # noqa: PLR0915 from litellm.proxy.management_helpers.utils import validate_entity_exists await validate_entity_exists(prisma_client, "user", data.user_id) - await validate_entity_exists(prisma_client, "team", data.team_id) + # await validate_entity_exists(prisma_client, "team", data.team_id) # TODO: causing substantial regressions in tests/proxy_admin_ui_tests/ await validate_entity_exists(prisma_client, "budget", data.budget_id) verbose_proxy_logger.debug("entered /key/generate") diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py index 0852d46831..9ed374e1a2 100644 --- a/tests/proxy_admin_ui_tests/test_key_management.py +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -907,143 +907,216 @@ async def test_list_key_helper(prisma_client): from litellm.proxy.management_endpoints.key_management_endpoints import ( _list_key_helper, ) - + from litellm.proxy.management_endpoints.key_management_endpoints import ( + delete_key_fn, + ) + import aiohttp # Setup - create multiple test keys setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") await litellm.proxy.proxy_server.prisma_client.connect() + - # Create test data - test_user_id = f"test_user_{uuid.uuid4()}" - test_team_id = f"test_team_{uuid.uuid4()}" - test_key_alias = f"test_alias_{uuid.uuid4()}" + base_url = "http://localhost:4000" + master_key = "sk-1234" # This should match your proxy's master key + headers = { + "Authorization": f"Bearer {master_key}", + "Content-Type": "application/json" + } - # Create test data with clear patterns - test_keys = [] + async with aiohttp.ClientSession() as session: + try: + # Create test data + test_user_id = f"test_user_{uuid.uuid4()}" + test_other_user_id = f"test_other_user_{uuid.uuid4()}" + test_team_id = f"test_team_{uuid.uuid4()}" + test_key_alias = f"test_alias_{uuid.uuid4()}" - # 1. Create 2 keys for test user + test team - for i in range(2): - key = await generate_key_fn( - data=GenerateKeyRequest( + user_data = { + "user_id": test_user_id, + } + other_user_data = { + "user_id": test_other_user_id, + + } + team_data = { + "team_id": test_team_id, + } + + async with session.post( + f"{base_url}/user/new", + headers=headers, + json=user_data + ) as response: + assert response.status == 200, f"Failed to create user: {await response.text()}" + user_response = await response.json() + print(f"Successfully created user: {test_user_id}") + + async with session.post( + f"{base_url}/user/new", + headers=headers, + json=other_user_data + ) as response: + assert response.status == 200, f"Failed to create user: {await response.text()}" + user_response = await response.json() + print(f"Successfully created user: {test_other_user_id}") + + async with session.post( + f"{base_url}/team/new", + headers=headers, + json=team_data + ) as response: + assert response.status == 200, f"Failed to create team: {await response.text()}" + team_response = await response.json() + print(f"Successfully created team: {test_team_id}") + + + + # Create test data with clear patterns + test_keys = [] + + # 1. Create 2 keys for test user + test team + for i in range(2): + key = await generate_key_fn( + data=GenerateKeyRequest( + user_id=test_user_id, + team_id=test_team_id, + key_alias=f"team_key_{uuid.uuid4()}", # Make unique with UUID + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="admin", + ), + ) + test_keys.append(key) + + # 2. Create 1 key for test user (no team) + key = await generate_key_fn( + data=GenerateKeyRequest( + user_id=test_user_id, + key_alias=test_key_alias, # Already unique from earlier UUID generation + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="admin", + ), + ) + test_keys.append(key) + + # 3. Create 2 keys for other users + for i in range(2): + key = await generate_key_fn( + data=GenerateKeyRequest( + user_id=test_other_user_id, + key_alias=f"other_key_{uuid.uuid4()}", # Make unique with UUID + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="admin", + ), + ) + test_keys.append(key) + + # Test 1: Basic pagination + result = await _list_key_helper( + prisma_client=prisma_client, + page=1, + size=2, + user_id=None, + team_id=None, + key_alias=None, + organization_id=None, + ) + assert len(result["keys"]) == 2, "Should return exactly 2 keys" + assert result["total_count"] >= 5, "Should have at least 5 total keys" + assert result["current_page"] == 1 + assert isinstance(result["keys"][0], str), "Should return token strings by default" + + # Test 2: Filter by user_id + result = await _list_key_helper( + prisma_client=prisma_client, + page=1, + size=10, user_id=test_user_id, + team_id=None, + key_alias=None, + organization_id=None, + ) + assert len(result["keys"]) == 4, "Should return exactly 4 keys for test user (1 default key + 3 keys created)" + + # Test 3: Filter by team_id + result = await _list_key_helper( + prisma_client=prisma_client, + page=1, + size=10, + user_id=None, team_id=test_team_id, - key_alias=f"team_key_{uuid.uuid4()}", # Make unique with UUID - ), - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="admin", - ), - ) - test_keys.append(key) + key_alias=None, + organization_id=None, + ) + assert len(result["keys"]) == 2, "Should return exactly 2 keys for test team" - # 2. Create 1 key for test user (no team) - key = await generate_key_fn( - data=GenerateKeyRequest( - user_id=test_user_id, - key_alias=test_key_alias, # Already unique from earlier UUID generation - ), - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="admin", - ), - ) - test_keys.append(key) + # Test 4: Filter by key_alias + result = await _list_key_helper( + prisma_client=prisma_client, + page=1, + size=10, + user_id=None, + team_id=None, + key_alias=test_key_alias, + organization_id=None, + ) + assert len(result["keys"]) == 1, "Should return exactly 1 key with test alias" - # 3. Create 2 keys for other users - for i in range(2): - key = await generate_key_fn( - data=GenerateKeyRequest( - user_id=f"other_user_{i}", - key_alias=f"other_key_{uuid.uuid4()}", # Make unique with UUID - ), - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="admin", - ), - ) - test_keys.append(key) + # Test 5: Return full object + result = await _list_key_helper( + prisma_client=prisma_client, + page=1, + size=10, + user_id=test_user_id, + team_id=None, + key_alias=None, + return_full_object=True, + organization_id=None, + ) + assert all( + isinstance(key, UserAPIKeyAuth) for key in result["keys"] + ), "Should return UserAPIKeyAuth objects" + assert len(result["keys"]) == 4, "Should return exactly 4 keys for test user (1 default key + 3 keys created)" - # Test 1: Basic pagination - result = await _list_key_helper( - prisma_client=prisma_client, - page=1, - size=2, - user_id=None, - team_id=None, - key_alias=None, - organization_id=None, - ) - assert len(result["keys"]) == 2, "Should return exactly 2 keys" - assert result["total_count"] >= 5, "Should have at least 5 total keys" - assert result["current_page"] == 1 - assert isinstance(result["keys"][0], str), "Should return token strings by default" - - # Test 2: Filter by user_id - result = await _list_key_helper( - prisma_client=prisma_client, - page=1, - size=10, - user_id=test_user_id, - team_id=None, - key_alias=None, - organization_id=None, - ) - assert len(result["keys"]) == 3, "Should return exactly 3 keys for test user" - - # Test 3: Filter by team_id - result = await _list_key_helper( - prisma_client=prisma_client, - page=1, - size=10, - user_id=None, - team_id=test_team_id, - key_alias=None, - organization_id=None, - ) - assert len(result["keys"]) == 2, "Should return exactly 2 keys for test team" - - # Test 4: Filter by key_alias - result = await _list_key_helper( - prisma_client=prisma_client, - page=1, - size=10, - user_id=None, - team_id=None, - key_alias=test_key_alias, - organization_id=None, - ) - assert len(result["keys"]) == 1, "Should return exactly 1 key with test alias" - - # Test 5: Return full object - result = await _list_key_helper( - prisma_client=prisma_client, - page=1, - size=10, - user_id=test_user_id, - team_id=None, - key_alias=None, - return_full_object=True, - organization_id=None, - ) - assert all( - isinstance(key, UserAPIKeyAuth) for key in result["keys"] - ), "Should return UserAPIKeyAuth objects" - assert len(result["keys"]) == 3, "Should return exactly 3 keys for test user" - - # Clean up test keys - for key in test_keys: - await delete_key_fn( - data=KeyRequest(keys=[key.key]), - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="admin", - ), - ) + # Clean up test keys + for key in test_keys: + await delete_key_fn( + data=KeyRequest(keys=[key.key]), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="admin", + ), + ) + finally: + async with session.post( + f"{base_url}/user/delete", + headers=headers, + json={"user_ids": [test_user_id]} + ) as response: + if response.status == 200: + print(f"Successfully deleted user: {test_user_id}") + else: + print(f"Warning: Failed to delete user: {await response.text()}") + async with session.post( + f"{base_url}/team/delete", + headers=headers, + json={"team_ids": [test_team_id]} + ) as response: + if response.status == 200: + print(f"Successfully deleted team: {test_team_id}") + else: + print(f"Warning: Failed to delete team: {await response.text()}") @pytest.mark.asyncio async def test_list_key_helper_team_filtering(prisma_client):