feat(internal_user_endpoints.py): support sort by on /user/list

This commit is contained in:
Krrish Dholakia 2025-04-22 14:38:37 -07:00
parent 3262e26817
commit 5969e8b650
2 changed files with 53 additions and 27 deletions

View file

@ -902,6 +902,42 @@ async def get_user_key_counts(
return result return result
def _validate_sort_params(
sort_by: Optional[str], sort_order: str
) -> Optional[Dict[str, str]]:
order_by: Dict[str, str] = {}
if sort_by is None:
return None
# Validate sort_by is a valid column
valid_columns = [
"user_id",
"user_email",
"created_at",
"spend",
"user_alias",
"user_role",
]
if sort_by not in valid_columns:
raise HTTPException(
status_code=400,
detail={
"error": f"Invalid sort column. Must be one of: {', '.join(valid_columns)}"
},
)
# Validate sort_order
if sort_order.lower() not in ["asc", "desc"]:
raise HTTPException(
status_code=400,
detail={"error": "Invalid sort order. Must be 'asc' or 'desc'"},
)
order_by[sort_by] = sort_order.lower()
return order_by
@router.get( @router.get(
"/user/list", "/user/list",
tags=["Internal User management"], tags=["Internal User management"],
@ -996,33 +1032,7 @@ async def get_users(
where_conditions = {k: v for k, v in where_conditions.items() if v is not None} where_conditions = {k: v for k, v in where_conditions.items() if v is not None}
# Build order_by conditions # Build order_by conditions
order_by: Dict[str, str] = {} order_by: Optional[Dict[str, str]] = _validate_sort_params(sort_by, sort_order)
if sort_by:
# Validate sort_by is a valid column
valid_columns = [
"user_id",
"user_email",
"created_at",
"spend",
"user_alias",
"user_role",
]
if sort_by not in valid_columns:
raise HTTPException(
status_code=400,
detail={
"error": f"Invalid sort column. Must be one of: {', '.join(valid_columns)}"
},
)
# Validate sort_order
if sort_order.lower() not in ["asc", "desc"]:
raise HTTPException(
status_code=400,
detail={"error": "Invalid sort order. Must be 'asc' or 'desc'"},
)
order_by[sort_by] = sort_order.lower()
users = await prisma_client.db.litellm_usertable.find_many( users = await prisma_client.db.litellm_usertable.find_many(
where=where_conditions, where=where_conditions,

View file

@ -153,3 +153,19 @@ async def test_get_users_includes_timestamps(mocker):
assert user_response.created_at == mock_user_data["created_at"] assert user_response.created_at == mock_user_data["created_at"]
assert user_response.updated_at == mock_user_data["updated_at"] assert user_response.updated_at == mock_user_data["updated_at"]
assert user_response.key_count == 0 assert user_response.key_count == 0
def test_validate_sort_params():
"""
Test that validate_sort_params returns None if sort_by is None
"""
from litellm.proxy.management_endpoints.internal_user_endpoints import (
_validate_sort_params,
)
assert _validate_sort_params(None, "asc") is None
assert _validate_sort_params(None, "desc") is None
assert _validate_sort_params("user_id", "asc") == {"user_id": "asc"}
assert _validate_sort_params("user_id", "desc") == {"user_id": "desc"}
with pytest.raises(Exception):
_validate_sort_params("user_id", "invalid")