test(test_key_management_endpoints.py): add unit test

This commit is contained in:
Krrish Dholakia 2025-04-21 14:42:18 -07:00
parent e738a77d4a
commit e4c88ce68d

View file

@ -46,3 +46,54 @@ async def test_list_keys():
assert json.dumps({"team_id": {"not": "litellm-dashboard"}}) in json.dumps(
where_condition
)
@pytest.mark.asyncio
async def test_key_token_handling(monkeypatch):
"""
Test that token handling in key generation follows the expected behavior:
1. token field should not equal key field
2. if token_id exists, it should equal token field
"""
mock_prisma_client = AsyncMock()
mock_insert_data = AsyncMock(
return_value=MagicMock(token="hashed_token_123", litellm_budget_table=None)
)
mock_prisma_client.insert_data = mock_insert_data
mock_prisma_client.db = MagicMock()
mock_prisma_client.db.litellm_verificationtoken = MagicMock()
mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock(
return_value=None
)
mock_prisma_client.db.litellm_verificationtoken.find_many = AsyncMock(
return_value=[]
)
mock_prisma_client.db.litellm_verificationtoken.count = AsyncMock(return_value=0)
mock_prisma_client.db.litellm_verificationtoken.update = AsyncMock(
return_value=MagicMock(token="hashed_token_123", litellm_budget_table=None)
)
from litellm.proxy._types import GenerateKeyRequest, LitellmUserRoles
from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth
from litellm.proxy.management_endpoints.key_management_endpoints import (
generate_key_fn,
)
from litellm.proxy.proxy_server import prisma_client
# Use monkeypatch to set the prisma_client
monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
# Test key generation
response = await generate_key_fn(
data=GenerateKeyRequest(),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234"
),
)
# Verify token handling
assert response.key != response.token, "Token should not equal key"
if hasattr(response, "token_id"):
assert (
response.token == response.token_id
), "Token should equal token_id if token_id exists"