diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index c68f585405..491c64611f 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -650,9 +650,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase): allowed_cache_controls: Optional[list] = [] config: Optional[dict] = {} permissions: Optional[dict] = {} - model_max_budget: Optional[ - dict - ] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} + model_max_budget: Optional[dict] = ( + {} + ) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} model_config = ConfigDict(protected_namespaces=()) model_rpm_limit: Optional[dict] = None @@ -908,12 +908,12 @@ class NewCustomerRequest(BudgetNewRequest): alias: Optional[str] = None # human-friendly alias blocked: bool = False # allow/disallow requests for this end-user budget_id: Optional[str] = None # give either a budget_id or max_budget - allowed_model_region: Optional[ - AllowedModelRegion - ] = None # require all user requests to use models in this specific region - default_model: Optional[ - str - ] = None # if no equivalent model in allowed region - default all requests to this model + allowed_model_region: Optional[AllowedModelRegion] = ( + None # require all user requests to use models in this specific region + ) + default_model: Optional[str] = ( + None # if no equivalent model in allowed region - default all requests to this model + ) @model_validator(mode="before") @classmethod @@ -935,12 +935,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase): blocked: bool = False # allow/disallow requests for this end-user max_budget: Optional[float] = None budget_id: Optional[str] = None # give either a budget_id or max_budget - allowed_model_region: Optional[ - AllowedModelRegion - ] = None # require all user requests to use models in this specific region - default_model: Optional[ - str - ] = None # if no equivalent model in allowed region - default all requests to this model + allowed_model_region: Optional[AllowedModelRegion] = ( + None # require all user requests to use models in this specific region + ) + default_model: Optional[str] = ( + None # if no equivalent model in allowed region - default all requests to this model + ) class DeleteCustomerRequest(LiteLLMPydanticObjectBase): @@ -1076,9 +1076,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase): class AddTeamCallback(LiteLLMPydanticObjectBase): callback_name: str - callback_type: Optional[ - Literal["success", "failure", "success_and_failure"] - ] = "success_and_failure" + callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = ( + "success_and_failure" + ) callback_vars: Dict[str, str] @model_validator(mode="before") @@ -1335,9 +1335,9 @@ class ConfigList(LiteLLMPydanticObjectBase): stored_in_db: Optional[bool] field_default_value: Any premium_field: bool = False - nested_fields: Optional[ - List[FieldDetail] - ] = None # For nested dictionary or Pydantic fields + nested_fields: Optional[List[FieldDetail]] = ( + None # For nested dictionary or Pydantic fields + ) class ConfigGeneralSettings(LiteLLMPydanticObjectBase): @@ -1604,9 +1604,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase): budget_id: Optional[str] = None created_at: datetime updated_at: datetime - user: Optional[ - Any - ] = None # You might want to replace 'Any' with a more specific type if available + user: Optional[Any] = ( + None # You might want to replace 'Any' with a more specific type if available + ) litellm_budget_table: Optional[LiteLLM_BudgetTable] = None model_config = ConfigDict(protected_namespaces=()) @@ -1671,6 +1671,8 @@ class LiteLLM_UserTable(LiteLLMPydanticObjectBase): budget_duration: Optional[str] = None budget_reset_at: Optional[datetime] = None metadata: Optional[dict] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None @model_validator(mode="before") @classmethod @@ -2352,9 +2354,9 @@ class TeamModelDeleteRequest(BaseModel): # Organization Member Requests class OrganizationMemberAddRequest(OrgMemberAddRequest): organization_id: str - max_budget_in_organization: Optional[ - float - ] = None # Users max budget within the organization + max_budget_in_organization: Optional[float] = ( + None # Users max budget within the organization + ) class OrganizationMemberDeleteRequest(MemberDeleteRequest): @@ -2543,9 +2545,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase): Maps provider names to their budget configs. """ - providers: Dict[ - str, ProviderBudgetResponseObject - ] = {} # Dictionary mapping provider names to their budget configurations + providers: Dict[str, ProviderBudgetResponseObject] = ( + {} + ) # Dictionary mapping provider names to their budget configurations class ProxyStateVariables(TypedDict): @@ -2673,9 +2675,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): enforce_rbac: bool = False roles_jwt_field: Optional[str] = None # v2 on role mappings role_mappings: Optional[List[RoleMapping]] = None - object_id_jwt_field: Optional[ - str - ] = None # can be either user / team, inferred from the role mapping + object_id_jwt_field: Optional[str] = ( + None # can be either user / team, inferred from the role mapping + ) scope_mappings: Optional[List[ScopeMapping]] = None enforce_scope_based_access: bool = False enforce_team_based_model_access: bool = False diff --git a/tests/litellm/proxy/management_endpoints/test_internal_user_endpoints.py b/tests/litellm/proxy/management_endpoints/test_internal_user_endpoints.py index 54fda943eb..73ed8e1a30 100644 --- a/tests/litellm/proxy/management_endpoints/test_internal_user_endpoints.py +++ b/tests/litellm/proxy/management_endpoints/test_internal_user_endpoints.py @@ -1,6 +1,7 @@ import json import os import sys +from datetime import datetime, timezone import pytest from fastapi.testclient import TestClient @@ -10,7 +11,12 @@ sys.path.insert( ) # Adds the parent directory to the system path from litellm.proxy._types import LiteLLM_UserTableFiltered, UserAPIKeyAuth -from litellm.proxy.management_endpoints.internal_user_endpoints import ui_view_users +from litellm.proxy.management_endpoints.internal_user_endpoints import ( + LiteLLM_UserTableWithKeyCount, + get_user_key_counts, + get_users, + ui_view_users, +) from litellm.proxy.proxy_server import app client = TestClient(app) @@ -82,3 +88,68 @@ def test_user_daily_activity_types(): assert not hasattr( daily_spend_metadata, field ), f"Field {field} is reported in DailySpendMetadata" + + +@pytest.mark.asyncio +async def test_get_users_includes_timestamps(mocker): + """ + Test that /user/list endpoint returns users with created_at and updated_at fields. + """ + # Mock the prisma client + mock_prisma_client = mocker.MagicMock() + + # Create mock user data with timestamps + mock_user_data = { + "user_id": "test-user-timestamps", + "user_email": "timestamps@example.com", + "user_role": "internal_user", + "created_at": datetime.now(timezone.utc), + "updated_at": datetime.now(timezone.utc), + } + mock_user_row = mocker.MagicMock() + mock_user_row.model_dump.return_value = mock_user_data + + # Setup the mock find_many response as an async function + async def mock_find_many(*args, **kwargs): + return [mock_user_row] + + # Setup the mock count response as an async function + async def mock_count(*args, **kwargs): + return 1 + + mock_prisma_client.db.litellm_usertable.find_many = mock_find_many + mock_prisma_client.db.litellm_usertable.count = mock_count + + # Patch the prisma client import in the endpoint + mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + + # Mock the helper function get_user_key_counts + async def mock_get_user_key_counts(*args, **kwargs): + return {"test-user-timestamps": 0} + + mocker.patch( + "litellm.proxy.management_endpoints.internal_user_endpoints.get_user_key_counts", + mock_get_user_key_counts, + ) + + # Call get_users function directly + response = await get_users(page=1, page_size=1) + + print("user /list response: ", response) + + # Assertions + assert response is not None + assert "users" in response + assert "total" in response + assert response["total"] == 1 + assert len(response["users"]) == 1 + + user_response = response["users"][0] + assert user_response.user_id == "test-user-timestamps" + assert user_response.created_at is not None + assert isinstance(user_response.created_at, datetime) + assert user_response.updated_at is not None + assert isinstance(user_response.updated_at, datetime) + assert user_response.created_at == mock_user_data["created_at"] + assert user_response.updated_at == mock_user_data["updated_at"] + assert user_response.key_count == 0