mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 10:14:26 +00:00
Merge 25693de37f
into b82af5b826
This commit is contained in:
commit
30c263aa75
4 changed files with 352 additions and 125 deletions
|
@ -400,6 +400,11 @@ async def generate_key_fn( # noqa: PLR0915
|
|||
user_api_key_cache,
|
||||
user_custom_key_generate,
|
||||
)
|
||||
from litellm.proxy.management_helpers.utils import validate_entity_exists
|
||||
|
||||
await validate_entity_exists(prisma_client, "user", data.user_id)
|
||||
# await validate_entity_exists(prisma_client, "team", data.team_id) # TODO: causing substantial regressions in tests/proxy_admin_ui_tests/
|
||||
await validate_entity_exists(prisma_client, "budget", data.budget_id)
|
||||
|
||||
verbose_proxy_logger.debug("entered /key/generate")
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import Optional, Tuple
|
||||
from typing import Literal, Optional, Tuple, Any
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
|
@ -371,3 +371,44 @@ def management_endpoint_wrapper(func):
|
|||
raise e
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
async def validate_entity_exists(
|
||||
prisma_client: Any,
|
||||
table_name: Literal["budget", "team", "user"],
|
||||
entity_id: Optional[str],
|
||||
) -> None:
|
||||
"""
|
||||
Validates if an entity exists in the database.
|
||||
|
||||
Args:
|
||||
prisma_client: Database client
|
||||
entity_id: ID of the entity to check
|
||||
entity_type: Human-readable name of the entity type (e.g., "User", "Team")
|
||||
table_name: Database table name to check
|
||||
id_field: Field name for the ID in the table (defaults to entity_type_id format)
|
||||
|
||||
Raises:
|
||||
Exception: If the entity doesn't exist
|
||||
"""
|
||||
if entity_id is None:
|
||||
return
|
||||
|
||||
if prisma_client is None:
|
||||
raise Exception("Database client is not available")
|
||||
|
||||
if table_name == "budget":
|
||||
existing_entity = await prisma_client.db.litellm_budgettable.find_unique(
|
||||
where={"budget_id": entity_id}, # type: ignore
|
||||
)
|
||||
elif table_name == "team":
|
||||
existing_entity = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": entity_id}, # type: ignore
|
||||
)
|
||||
elif table_name == "user":
|
||||
existing_entity = await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": entity_id}, # type: ignore
|
||||
)
|
||||
|
||||
if existing_entity is None:
|
||||
raise Exception(f"'{table_name}' with the id '{entity_id}' does not exist")
|
||||
|
|
108
tests/proxy_admin_ui_tests/test_key_entity_validation.py
Normal file
108
tests/proxy_admin_ui_tests/test_key_entity_validation.py
Normal file
|
@ -0,0 +1,108 @@
|
|||
import os
|
||||
import uuid
|
||||
import pytest
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_key_entity_validation():
|
||||
"""
|
||||
End-to-end test for validating entity existence in key generation:
|
||||
1. Create a new user with a random UUID using POST /user/new
|
||||
2. Try to create a key with an incorrect user_id (should fail)
|
||||
3. Create a key with the correct user_id (should succeed)
|
||||
4. Clean up by deleting both the key and user regardless of test results
|
||||
"""
|
||||
# Set up base URL and auth
|
||||
base_url = "http://localhost:4000"
|
||||
master_key = "sk-1234" # This should match your proxy's master key
|
||||
headers = {
|
||||
"Authorization": f"Bearer {master_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# Variables to store created resources for cleanup
|
||||
user_id = str(uuid.uuid4())
|
||||
invalid_user_id = str(uuid.uuid4())
|
||||
key_value = None
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
# Step 1: Create a new user
|
||||
user_data = {
|
||||
"user_id": user_id,
|
||||
"user_email": f"test-{user_id[:8]}@example.com",
|
||||
"max_budget": 100,
|
||||
"user_role": "internal_user"
|
||||
}
|
||||
|
||||
async with session.post(
|
||||
f"{base_url}/user/new",
|
||||
headers=headers,
|
||||
json=user_data
|
||||
) as response:
|
||||
assert response.status == 200, f"Failed to create user: {await response.text()}"
|
||||
user_response = await response.json()
|
||||
print(f"Successfully created user: {user_id}")
|
||||
|
||||
# Step 2: Try to create a key with an incorrect user_id (should fail)
|
||||
invalid_key_data = {
|
||||
"user_id": invalid_user_id,
|
||||
"models": ["gpt-3.5-turbo"],
|
||||
"max_budget": 50
|
||||
}
|
||||
|
||||
async with session.post(
|
||||
f"{base_url}/key/generate",
|
||||
headers=headers,
|
||||
json=invalid_key_data
|
||||
) as response:
|
||||
response_text = await response.text()
|
||||
print(f"Response for invalid user ID: {response_text}")
|
||||
# This should fail with a 400 status code and error message about user not existing
|
||||
assert response.status != 200, "Key generation with invalid user_id should fail"
|
||||
assert "'user' with the id" in response_text and "does not exist" in response_text
|
||||
|
||||
# Step 3: Create a key with the correct user_id (should succeed)
|
||||
valid_key_data = {
|
||||
"user_id": user_id,
|
||||
"models": ["gpt-3.5-turbo"],
|
||||
"max_budget": 50
|
||||
}
|
||||
|
||||
async with session.post(
|
||||
f"{base_url}/key/generate",
|
||||
headers=headers,
|
||||
json=valid_key_data
|
||||
) as response:
|
||||
assert response.status == 200, f"Failed to create key: {await response.text()}"
|
||||
key_response = await response.json()
|
||||
key_value = key_response.get("key")
|
||||
print(f"Successfully created key: {key_value}")
|
||||
assert key_value is not None, "Response should contain a key"
|
||||
assert key_value.startswith("sk-"), "Key should start with 'sk-'"
|
||||
|
||||
finally:
|
||||
# Step 4: Clean up - Delete key and user regardless of test results
|
||||
if key_value:
|
||||
async with session.delete(
|
||||
f"{base_url}/key/delete",
|
||||
headers=headers,
|
||||
json={"keys": [key_value]}
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
print(f"Successfully deleted key: {key_value}")
|
||||
else:
|
||||
print(f"Warning: Failed to delete key: {await response.text()}")
|
||||
|
||||
async with session.post(
|
||||
f"{base_url}/user/delete",
|
||||
headers=headers,
|
||||
json={"user_ids": [user_id]}
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
print(f"Successfully deleted user: {user_id}")
|
||||
else:
|
||||
print(f"Warning: Failed to delete user: {await response.text()}")
|
|
@ -922,143 +922,216 @@ async def test_list_key_helper(prisma_client):
|
|||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
_list_key_helper,
|
||||
)
|
||||
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
delete_key_fn,
|
||||
)
|
||||
import aiohttp
|
||||
# Setup - create multiple test keys
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
|
||||
# Create test data
|
||||
test_user_id = f"test_user_{uuid.uuid4()}"
|
||||
test_team_id = f"test_team_{uuid.uuid4()}"
|
||||
test_key_alias = f"test_alias_{uuid.uuid4()}"
|
||||
base_url = "http://localhost:4000"
|
||||
master_key = "sk-1234" # This should match your proxy's master key
|
||||
headers = {
|
||||
"Authorization": f"Bearer {master_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# Create test data with clear patterns
|
||||
test_keys = []
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
# Create test data
|
||||
test_user_id = f"test_user_{uuid.uuid4()}"
|
||||
test_other_user_id = f"test_other_user_{uuid.uuid4()}"
|
||||
test_team_id = f"test_team_{uuid.uuid4()}"
|
||||
test_key_alias = f"test_alias_{uuid.uuid4()}"
|
||||
|
||||
# 1. Create 2 keys for test user + test team
|
||||
for i in range(2):
|
||||
key = await generate_key_fn(
|
||||
data=GenerateKeyRequest(
|
||||
user_data = {
|
||||
"user_id": test_user_id,
|
||||
}
|
||||
other_user_data = {
|
||||
"user_id": test_other_user_id,
|
||||
|
||||
}
|
||||
team_data = {
|
||||
"team_id": test_team_id,
|
||||
}
|
||||
|
||||
async with session.post(
|
||||
f"{base_url}/user/new",
|
||||
headers=headers,
|
||||
json=user_data
|
||||
) as response:
|
||||
assert response.status == 200, f"Failed to create user: {await response.text()}"
|
||||
user_response = await response.json()
|
||||
print(f"Successfully created user: {test_user_id}")
|
||||
|
||||
async with session.post(
|
||||
f"{base_url}/user/new",
|
||||
headers=headers,
|
||||
json=other_user_data
|
||||
) as response:
|
||||
assert response.status == 200, f"Failed to create user: {await response.text()}"
|
||||
user_response = await response.json()
|
||||
print(f"Successfully created user: {test_other_user_id}")
|
||||
|
||||
async with session.post(
|
||||
f"{base_url}/team/new",
|
||||
headers=headers,
|
||||
json=team_data
|
||||
) as response:
|
||||
assert response.status == 200, f"Failed to create team: {await response.text()}"
|
||||
team_response = await response.json()
|
||||
print(f"Successfully created team: {test_team_id}")
|
||||
|
||||
|
||||
|
||||
# Create test data with clear patterns
|
||||
test_keys = []
|
||||
|
||||
# 1. Create 2 keys for test user + test team
|
||||
for i in range(2):
|
||||
key = await generate_key_fn(
|
||||
data=GenerateKeyRequest(
|
||||
user_id=test_user_id,
|
||||
team_id=test_team_id,
|
||||
key_alias=f"team_key_{uuid.uuid4()}", # Make unique with UUID
|
||||
),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="admin",
|
||||
),
|
||||
)
|
||||
test_keys.append(key)
|
||||
|
||||
# 2. Create 1 key for test user (no team)
|
||||
key = await generate_key_fn(
|
||||
data=GenerateKeyRequest(
|
||||
user_id=test_user_id,
|
||||
key_alias=test_key_alias, # Already unique from earlier UUID generation
|
||||
),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="admin",
|
||||
),
|
||||
)
|
||||
test_keys.append(key)
|
||||
|
||||
# 3. Create 2 keys for other users
|
||||
for i in range(2):
|
||||
key = await generate_key_fn(
|
||||
data=GenerateKeyRequest(
|
||||
user_id=test_other_user_id,
|
||||
key_alias=f"other_key_{uuid.uuid4()}", # Make unique with UUID
|
||||
),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="admin",
|
||||
),
|
||||
)
|
||||
test_keys.append(key)
|
||||
|
||||
# Test 1: Basic pagination
|
||||
result = await _list_key_helper(
|
||||
prisma_client=prisma_client,
|
||||
page=1,
|
||||
size=2,
|
||||
user_id=None,
|
||||
team_id=None,
|
||||
key_alias=None,
|
||||
organization_id=None,
|
||||
)
|
||||
assert len(result["keys"]) == 2, "Should return exactly 2 keys"
|
||||
assert result["total_count"] >= 5, "Should have at least 5 total keys"
|
||||
assert result["current_page"] == 1
|
||||
assert isinstance(result["keys"][0], str), "Should return token strings by default"
|
||||
|
||||
# Test 2: Filter by user_id
|
||||
result = await _list_key_helper(
|
||||
prisma_client=prisma_client,
|
||||
page=1,
|
||||
size=10,
|
||||
user_id=test_user_id,
|
||||
team_id=None,
|
||||
key_alias=None,
|
||||
organization_id=None,
|
||||
)
|
||||
assert len(result["keys"]) == 4, "Should return exactly 4 keys for test user (1 default key + 3 keys created)"
|
||||
|
||||
# Test 3: Filter by team_id
|
||||
result = await _list_key_helper(
|
||||
prisma_client=prisma_client,
|
||||
page=1,
|
||||
size=10,
|
||||
user_id=None,
|
||||
team_id=test_team_id,
|
||||
key_alias=f"team_key_{uuid.uuid4()}", # Make unique with UUID
|
||||
),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="admin",
|
||||
),
|
||||
)
|
||||
test_keys.append(key)
|
||||
key_alias=None,
|
||||
organization_id=None,
|
||||
)
|
||||
assert len(result["keys"]) == 2, "Should return exactly 2 keys for test team"
|
||||
|
||||
# 2. Create 1 key for test user (no team)
|
||||
key = await generate_key_fn(
|
||||
data=GenerateKeyRequest(
|
||||
user_id=test_user_id,
|
||||
key_alias=test_key_alias, # Already unique from earlier UUID generation
|
||||
),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="admin",
|
||||
),
|
||||
)
|
||||
test_keys.append(key)
|
||||
# Test 4: Filter by key_alias
|
||||
result = await _list_key_helper(
|
||||
prisma_client=prisma_client,
|
||||
page=1,
|
||||
size=10,
|
||||
user_id=None,
|
||||
team_id=None,
|
||||
key_alias=test_key_alias,
|
||||
organization_id=None,
|
||||
)
|
||||
assert len(result["keys"]) == 1, "Should return exactly 1 key with test alias"
|
||||
|
||||
# 3. Create 2 keys for other users
|
||||
for i in range(2):
|
||||
key = await generate_key_fn(
|
||||
data=GenerateKeyRequest(
|
||||
user_id=f"other_user_{i}",
|
||||
key_alias=f"other_key_{uuid.uuid4()}", # Make unique with UUID
|
||||
),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="admin",
|
||||
),
|
||||
)
|
||||
test_keys.append(key)
|
||||
# Test 5: Return full object
|
||||
result = await _list_key_helper(
|
||||
prisma_client=prisma_client,
|
||||
page=1,
|
||||
size=10,
|
||||
user_id=test_user_id,
|
||||
team_id=None,
|
||||
key_alias=None,
|
||||
return_full_object=True,
|
||||
organization_id=None,
|
||||
)
|
||||
assert all(
|
||||
isinstance(key, UserAPIKeyAuth) for key in result["keys"]
|
||||
), "Should return UserAPIKeyAuth objects"
|
||||
assert len(result["keys"]) == 4, "Should return exactly 4 keys for test user (1 default key + 3 keys created)"
|
||||
|
||||
# Test 1: Basic pagination
|
||||
result = await _list_key_helper(
|
||||
prisma_client=prisma_client,
|
||||
page=1,
|
||||
size=2,
|
||||
user_id=None,
|
||||
team_id=None,
|
||||
key_alias=None,
|
||||
organization_id=None,
|
||||
)
|
||||
assert len(result["keys"]) == 2, "Should return exactly 2 keys"
|
||||
assert result["total_count"] >= 5, "Should have at least 5 total keys"
|
||||
assert result["current_page"] == 1
|
||||
assert isinstance(result["keys"][0], str), "Should return token strings by default"
|
||||
|
||||
# Test 2: Filter by user_id
|
||||
result = await _list_key_helper(
|
||||
prisma_client=prisma_client,
|
||||
page=1,
|
||||
size=10,
|
||||
user_id=test_user_id,
|
||||
team_id=None,
|
||||
key_alias=None,
|
||||
organization_id=None,
|
||||
)
|
||||
assert len(result["keys"]) == 3, "Should return exactly 3 keys for test user"
|
||||
|
||||
# Test 3: Filter by team_id
|
||||
result = await _list_key_helper(
|
||||
prisma_client=prisma_client,
|
||||
page=1,
|
||||
size=10,
|
||||
user_id=None,
|
||||
team_id=test_team_id,
|
||||
key_alias=None,
|
||||
organization_id=None,
|
||||
)
|
||||
assert len(result["keys"]) == 2, "Should return exactly 2 keys for test team"
|
||||
|
||||
# Test 4: Filter by key_alias
|
||||
result = await _list_key_helper(
|
||||
prisma_client=prisma_client,
|
||||
page=1,
|
||||
size=10,
|
||||
user_id=None,
|
||||
team_id=None,
|
||||
key_alias=test_key_alias,
|
||||
organization_id=None,
|
||||
)
|
||||
assert len(result["keys"]) == 1, "Should return exactly 1 key with test alias"
|
||||
|
||||
# Test 5: Return full object
|
||||
result = await _list_key_helper(
|
||||
prisma_client=prisma_client,
|
||||
page=1,
|
||||
size=10,
|
||||
user_id=test_user_id,
|
||||
team_id=None,
|
||||
key_alias=None,
|
||||
return_full_object=True,
|
||||
organization_id=None,
|
||||
)
|
||||
assert all(
|
||||
isinstance(key, UserAPIKeyAuth) for key in result["keys"]
|
||||
), "Should return UserAPIKeyAuth objects"
|
||||
assert len(result["keys"]) == 3, "Should return exactly 3 keys for test user"
|
||||
|
||||
# Clean up test keys
|
||||
for key in test_keys:
|
||||
await delete_key_fn(
|
||||
data=KeyRequest(keys=[key.key]),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="admin",
|
||||
),
|
||||
)
|
||||
# Clean up test keys
|
||||
for key in test_keys:
|
||||
await delete_key_fn(
|
||||
data=KeyRequest(keys=[key.key]),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="admin",
|
||||
),
|
||||
)
|
||||
finally:
|
||||
async with session.post(
|
||||
f"{base_url}/user/delete",
|
||||
headers=headers,
|
||||
json={"user_ids": [test_user_id]}
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
print(f"Successfully deleted user: {test_user_id}")
|
||||
else:
|
||||
print(f"Warning: Failed to delete user: {await response.text()}")
|
||||
|
||||
async with session.post(
|
||||
f"{base_url}/team/delete",
|
||||
headers=headers,
|
||||
json={"team_ids": [test_team_id]}
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
print(f"Successfully deleted team: {test_team_id}")
|
||||
else:
|
||||
print(f"Warning: Failed to delete team: {await response.text()}")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_key_helper_team_filtering(prisma_client):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue