mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
feat(internal_user_endpoints.py): support sort by on /user/list
This commit is contained in:
parent
3262e26817
commit
5969e8b650
2 changed files with 53 additions and 27 deletions
|
@ -902,6 +902,42 @@ async def get_user_key_counts(
|
|||
return result
|
||||
|
||||
|
||||
def _validate_sort_params(
|
||||
sort_by: Optional[str], sort_order: str
|
||||
) -> Optional[Dict[str, str]]:
|
||||
order_by: Dict[str, str] = {}
|
||||
|
||||
if sort_by is None:
|
||||
return None
|
||||
# Validate sort_by is a valid column
|
||||
valid_columns = [
|
||||
"user_id",
|
||||
"user_email",
|
||||
"created_at",
|
||||
"spend",
|
||||
"user_alias",
|
||||
"user_role",
|
||||
]
|
||||
if sort_by not in valid_columns:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"Invalid sort column. Must be one of: {', '.join(valid_columns)}"
|
||||
},
|
||||
)
|
||||
|
||||
# Validate sort_order
|
||||
if sort_order.lower() not in ["asc", "desc"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": "Invalid sort order. Must be 'asc' or 'desc'"},
|
||||
)
|
||||
|
||||
order_by[sort_by] = sort_order.lower()
|
||||
|
||||
return order_by
|
||||
|
||||
|
||||
@router.get(
|
||||
"/user/list",
|
||||
tags=["Internal User management"],
|
||||
|
@ -996,33 +1032,7 @@ async def get_users(
|
|||
where_conditions = {k: v for k, v in where_conditions.items() if v is not None}
|
||||
|
||||
# Build order_by conditions
|
||||
order_by: Dict[str, str] = {}
|
||||
if sort_by:
|
||||
# Validate sort_by is a valid column
|
||||
valid_columns = [
|
||||
"user_id",
|
||||
"user_email",
|
||||
"created_at",
|
||||
"spend",
|
||||
"user_alias",
|
||||
"user_role",
|
||||
]
|
||||
if sort_by not in valid_columns:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"Invalid sort column. Must be one of: {', '.join(valid_columns)}"
|
||||
},
|
||||
)
|
||||
|
||||
# Validate sort_order
|
||||
if sort_order.lower() not in ["asc", "desc"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": "Invalid sort order. Must be 'asc' or 'desc'"},
|
||||
)
|
||||
|
||||
order_by[sort_by] = sort_order.lower()
|
||||
order_by: Optional[Dict[str, str]] = _validate_sort_params(sort_by, sort_order)
|
||||
|
||||
users = await prisma_client.db.litellm_usertable.find_many(
|
||||
where=where_conditions,
|
||||
|
|
|
@ -153,3 +153,19 @@ async def test_get_users_includes_timestamps(mocker):
|
|||
assert user_response.created_at == mock_user_data["created_at"]
|
||||
assert user_response.updated_at == mock_user_data["updated_at"]
|
||||
assert user_response.key_count == 0
|
||||
|
||||
|
||||
def test_validate_sort_params():
|
||||
"""
|
||||
Test that validate_sort_params returns None if sort_by is None
|
||||
"""
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
||||
_validate_sort_params,
|
||||
)
|
||||
|
||||
assert _validate_sort_params(None, "asc") is None
|
||||
assert _validate_sort_params(None, "desc") is None
|
||||
assert _validate_sort_params("user_id", "asc") == {"user_id": "asc"}
|
||||
assert _validate_sort_params("user_id", "desc") == {"user_id": "desc"}
|
||||
with pytest.raises(Exception):
|
||||
_validate_sort_params("user_id", "invalid")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue