This commit is contained in:
Sebastian Sosa 2025-04-24 00:58:15 -07:00 committed by GitHub
commit 30c263aa75
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 352 additions and 125 deletions

View file

@ -400,6 +400,11 @@ async def generate_key_fn( # noqa: PLR0915
user_api_key_cache,
user_custom_key_generate,
)
from litellm.proxy.management_helpers.utils import validate_entity_exists
await validate_entity_exists(prisma_client, "user", data.user_id)
# await validate_entity_exists(prisma_client, "team", data.team_id) # TODO: causing substantial regressions in tests/proxy_admin_ui_tests/
await validate_entity_exists(prisma_client, "budget", data.budget_id)
verbose_proxy_logger.debug("entered /key/generate")

View file

@ -3,7 +3,7 @@
import uuid
from datetime import datetime
from functools import wraps
from typing import Optional, Tuple
from typing import Literal, Optional, Tuple, Any
from fastapi import HTTPException, Request
@ -371,3 +371,44 @@ def management_endpoint_wrapper(func):
raise e
return wrapper
async def validate_entity_exists(
prisma_client: Any,
table_name: Literal["budget", "team", "user"],
entity_id: Optional[str],
) -> None:
"""
Validates if an entity exists in the database.
Args:
prisma_client: Database client
entity_id: ID of the entity to check
entity_type: Human-readable name of the entity type (e.g., "User", "Team")
table_name: Database table name to check
id_field: Field name for the ID in the table (defaults to entity_type_id format)
Raises:
Exception: If the entity doesn't exist
"""
if entity_id is None:
return
if prisma_client is None:
raise Exception("Database client is not available")
if table_name == "budget":
existing_entity = await prisma_client.db.litellm_budgettable.find_unique(
where={"budget_id": entity_id}, # type: ignore
)
elif table_name == "team":
existing_entity = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": entity_id}, # type: ignore
)
elif table_name == "user":
existing_entity = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": entity_id}, # type: ignore
)
if existing_entity is None:
raise Exception(f"'{table_name}' with the id '{entity_id}' does not exist")

View file

@ -0,0 +1,108 @@
import os
import uuid
import pytest
import asyncio
import aiohttp
import json
from typing import Dict, Optional, Tuple, List
@pytest.mark.asyncio
async def test_key_entity_validation():
"""
End-to-end test for validating entity existence in key generation:
1. Create a new user with a random UUID using POST /user/new
2. Try to create a key with an incorrect user_id (should fail)
3. Create a key with the correct user_id (should succeed)
4. Clean up by deleting both the key and user regardless of test results
"""
# Set up base URL and auth
base_url = "http://localhost:4000"
master_key = "sk-1234" # This should match your proxy's master key
headers = {
"Authorization": f"Bearer {master_key}",
"Content-Type": "application/json"
}
# Variables to store created resources for cleanup
user_id = str(uuid.uuid4())
invalid_user_id = str(uuid.uuid4())
key_value = None
async with aiohttp.ClientSession() as session:
try:
# Step 1: Create a new user
user_data = {
"user_id": user_id,
"user_email": f"test-{user_id[:8]}@example.com",
"max_budget": 100,
"user_role": "internal_user"
}
async with session.post(
f"{base_url}/user/new",
headers=headers,
json=user_data
) as response:
assert response.status == 200, f"Failed to create user: {await response.text()}"
user_response = await response.json()
print(f"Successfully created user: {user_id}")
# Step 2: Try to create a key with an incorrect user_id (should fail)
invalid_key_data = {
"user_id": invalid_user_id,
"models": ["gpt-3.5-turbo"],
"max_budget": 50
}
async with session.post(
f"{base_url}/key/generate",
headers=headers,
json=invalid_key_data
) as response:
response_text = await response.text()
print(f"Response for invalid user ID: {response_text}")
# This should fail with a 400 status code and error message about user not existing
assert response.status != 200, "Key generation with invalid user_id should fail"
assert "'user' with the id" in response_text and "does not exist" in response_text
# Step 3: Create a key with the correct user_id (should succeed)
valid_key_data = {
"user_id": user_id,
"models": ["gpt-3.5-turbo"],
"max_budget": 50
}
async with session.post(
f"{base_url}/key/generate",
headers=headers,
json=valid_key_data
) as response:
assert response.status == 200, f"Failed to create key: {await response.text()}"
key_response = await response.json()
key_value = key_response.get("key")
print(f"Successfully created key: {key_value}")
assert key_value is not None, "Response should contain a key"
assert key_value.startswith("sk-"), "Key should start with 'sk-'"
finally:
# Step 4: Clean up - Delete key and user regardless of test results
if key_value:
async with session.delete(
f"{base_url}/key/delete",
headers=headers,
json={"keys": [key_value]}
) as response:
if response.status == 200:
print(f"Successfully deleted key: {key_value}")
else:
print(f"Warning: Failed to delete key: {await response.text()}")
async with session.post(
f"{base_url}/user/delete",
headers=headers,
json={"user_ids": [user_id]}
) as response:
if response.status == 200:
print(f"Successfully deleted user: {user_id}")
else:
print(f"Warning: Failed to delete user: {await response.text()}")

View file

@ -922,143 +922,216 @@ async def test_list_key_helper(prisma_client):
from litellm.proxy.management_endpoints.key_management_endpoints import (
_list_key_helper,
)
from litellm.proxy.management_endpoints.key_management_endpoints import (
delete_key_fn,
)
import aiohttp
# Setup - create multiple test keys
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
# Create test data
test_user_id = f"test_user_{uuid.uuid4()}"
test_team_id = f"test_team_{uuid.uuid4()}"
test_key_alias = f"test_alias_{uuid.uuid4()}"
base_url = "http://localhost:4000"
master_key = "sk-1234" # This should match your proxy's master key
headers = {
"Authorization": f"Bearer {master_key}",
"Content-Type": "application/json"
}
# Create test data with clear patterns
test_keys = []
async with aiohttp.ClientSession() as session:
try:
# Create test data
test_user_id = f"test_user_{uuid.uuid4()}"
test_other_user_id = f"test_other_user_{uuid.uuid4()}"
test_team_id = f"test_team_{uuid.uuid4()}"
test_key_alias = f"test_alias_{uuid.uuid4()}"
# 1. Create 2 keys for test user + test team
for i in range(2):
key = await generate_key_fn(
data=GenerateKeyRequest(
user_data = {
"user_id": test_user_id,
}
other_user_data = {
"user_id": test_other_user_id,
}
team_data = {
"team_id": test_team_id,
}
async with session.post(
f"{base_url}/user/new",
headers=headers,
json=user_data
) as response:
assert response.status == 200, f"Failed to create user: {await response.text()}"
user_response = await response.json()
print(f"Successfully created user: {test_user_id}")
async with session.post(
f"{base_url}/user/new",
headers=headers,
json=other_user_data
) as response:
assert response.status == 200, f"Failed to create user: {await response.text()}"
user_response = await response.json()
print(f"Successfully created user: {test_other_user_id}")
async with session.post(
f"{base_url}/team/new",
headers=headers,
json=team_data
) as response:
assert response.status == 200, f"Failed to create team: {await response.text()}"
team_response = await response.json()
print(f"Successfully created team: {test_team_id}")
# Create test data with clear patterns
test_keys = []
# 1. Create 2 keys for test user + test team
for i in range(2):
key = await generate_key_fn(
data=GenerateKeyRequest(
user_id=test_user_id,
team_id=test_team_id,
key_alias=f"team_key_{uuid.uuid4()}", # Make unique with UUID
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="admin",
),
)
test_keys.append(key)
# 2. Create 1 key for test user (no team)
key = await generate_key_fn(
data=GenerateKeyRequest(
user_id=test_user_id,
key_alias=test_key_alias, # Already unique from earlier UUID generation
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="admin",
),
)
test_keys.append(key)
# 3. Create 2 keys for other users
for i in range(2):
key = await generate_key_fn(
data=GenerateKeyRequest(
user_id=test_other_user_id,
key_alias=f"other_key_{uuid.uuid4()}", # Make unique with UUID
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="admin",
),
)
test_keys.append(key)
# Test 1: Basic pagination
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=2,
user_id=None,
team_id=None,
key_alias=None,
organization_id=None,
)
assert len(result["keys"]) == 2, "Should return exactly 2 keys"
assert result["total_count"] >= 5, "Should have at least 5 total keys"
assert result["current_page"] == 1
assert isinstance(result["keys"][0], str), "Should return token strings by default"
# Test 2: Filter by user_id
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=10,
user_id=test_user_id,
team_id=None,
key_alias=None,
organization_id=None,
)
assert len(result["keys"]) == 4, "Should return exactly 4 keys for test user (1 default key + 3 keys created)"
# Test 3: Filter by team_id
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=10,
user_id=None,
team_id=test_team_id,
key_alias=f"team_key_{uuid.uuid4()}", # Make unique with UUID
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="admin",
),
)
test_keys.append(key)
key_alias=None,
organization_id=None,
)
assert len(result["keys"]) == 2, "Should return exactly 2 keys for test team"
# 2. Create 1 key for test user (no team)
key = await generate_key_fn(
data=GenerateKeyRequest(
user_id=test_user_id,
key_alias=test_key_alias, # Already unique from earlier UUID generation
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="admin",
),
)
test_keys.append(key)
# Test 4: Filter by key_alias
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=10,
user_id=None,
team_id=None,
key_alias=test_key_alias,
organization_id=None,
)
assert len(result["keys"]) == 1, "Should return exactly 1 key with test alias"
# 3. Create 2 keys for other users
for i in range(2):
key = await generate_key_fn(
data=GenerateKeyRequest(
user_id=f"other_user_{i}",
key_alias=f"other_key_{uuid.uuid4()}", # Make unique with UUID
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="admin",
),
)
test_keys.append(key)
# Test 5: Return full object
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=10,
user_id=test_user_id,
team_id=None,
key_alias=None,
return_full_object=True,
organization_id=None,
)
assert all(
isinstance(key, UserAPIKeyAuth) for key in result["keys"]
), "Should return UserAPIKeyAuth objects"
assert len(result["keys"]) == 4, "Should return exactly 4 keys for test user (1 default key + 3 keys created)"
# Test 1: Basic pagination
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=2,
user_id=None,
team_id=None,
key_alias=None,
organization_id=None,
)
assert len(result["keys"]) == 2, "Should return exactly 2 keys"
assert result["total_count"] >= 5, "Should have at least 5 total keys"
assert result["current_page"] == 1
assert isinstance(result["keys"][0], str), "Should return token strings by default"
# Test 2: Filter by user_id
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=10,
user_id=test_user_id,
team_id=None,
key_alias=None,
organization_id=None,
)
assert len(result["keys"]) == 3, "Should return exactly 3 keys for test user"
# Test 3: Filter by team_id
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=10,
user_id=None,
team_id=test_team_id,
key_alias=None,
organization_id=None,
)
assert len(result["keys"]) == 2, "Should return exactly 2 keys for test team"
# Test 4: Filter by key_alias
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=10,
user_id=None,
team_id=None,
key_alias=test_key_alias,
organization_id=None,
)
assert len(result["keys"]) == 1, "Should return exactly 1 key with test alias"
# Test 5: Return full object
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=10,
user_id=test_user_id,
team_id=None,
key_alias=None,
return_full_object=True,
organization_id=None,
)
assert all(
isinstance(key, UserAPIKeyAuth) for key in result["keys"]
), "Should return UserAPIKeyAuth objects"
assert len(result["keys"]) == 3, "Should return exactly 3 keys for test user"
# Clean up test keys
for key in test_keys:
await delete_key_fn(
data=KeyRequest(keys=[key.key]),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="admin",
),
)
# Clean up test keys
for key in test_keys:
await delete_key_fn(
data=KeyRequest(keys=[key.key]),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="admin",
),
)
finally:
async with session.post(
f"{base_url}/user/delete",
headers=headers,
json={"user_ids": [test_user_id]}
) as response:
if response.status == 200:
print(f"Successfully deleted user: {test_user_id}")
else:
print(f"Warning: Failed to delete user: {await response.text()}")
async with session.post(
f"{base_url}/team/delete",
headers=headers,
json={"team_ids": [test_team_id]}
) as response:
if response.status == 200:
print(f"Successfully deleted team: {test_team_id}")
else:
print(f"Warning: Failed to delete team: {await response.text()}")
@pytest.mark.asyncio
async def test_list_key_helper_team_filtering(prisma_client):