mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
feat(internal_user_endpoints.py): support sort by on /user/list
This commit is contained in:
parent
3262e26817
commit
5969e8b650
2 changed files with 53 additions and 27 deletions
|
@ -902,6 +902,42 @@ async def get_user_key_counts(
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_sort_params(
|
||||||
|
sort_by: Optional[str], sort_order: str
|
||||||
|
) -> Optional[Dict[str, str]]:
|
||||||
|
order_by: Dict[str, str] = {}
|
||||||
|
|
||||||
|
if sort_by is None:
|
||||||
|
return None
|
||||||
|
# Validate sort_by is a valid column
|
||||||
|
valid_columns = [
|
||||||
|
"user_id",
|
||||||
|
"user_email",
|
||||||
|
"created_at",
|
||||||
|
"spend",
|
||||||
|
"user_alias",
|
||||||
|
"user_role",
|
||||||
|
]
|
||||||
|
if sort_by not in valid_columns:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": f"Invalid sort column. Must be one of: {', '.join(valid_columns)}"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate sort_order
|
||||||
|
if sort_order.lower() not in ["asc", "desc"]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={"error": "Invalid sort order. Must be 'asc' or 'desc'"},
|
||||||
|
)
|
||||||
|
|
||||||
|
order_by[sort_by] = sort_order.lower()
|
||||||
|
|
||||||
|
return order_by
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/user/list",
|
"/user/list",
|
||||||
tags=["Internal User management"],
|
tags=["Internal User management"],
|
||||||
|
@ -996,33 +1032,7 @@ async def get_users(
|
||||||
where_conditions = {k: v for k, v in where_conditions.items() if v is not None}
|
where_conditions = {k: v for k, v in where_conditions.items() if v is not None}
|
||||||
|
|
||||||
# Build order_by conditions
|
# Build order_by conditions
|
||||||
order_by: Dict[str, str] = {}
|
order_by: Optional[Dict[str, str]] = _validate_sort_params(sort_by, sort_order)
|
||||||
if sort_by:
|
|
||||||
# Validate sort_by is a valid column
|
|
||||||
valid_columns = [
|
|
||||||
"user_id",
|
|
||||||
"user_email",
|
|
||||||
"created_at",
|
|
||||||
"spend",
|
|
||||||
"user_alias",
|
|
||||||
"user_role",
|
|
||||||
]
|
|
||||||
if sort_by not in valid_columns:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail={
|
|
||||||
"error": f"Invalid sort column. Must be one of: {', '.join(valid_columns)}"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate sort_order
|
|
||||||
if sort_order.lower() not in ["asc", "desc"]:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail={"error": "Invalid sort order. Must be 'asc' or 'desc'"},
|
|
||||||
)
|
|
||||||
|
|
||||||
order_by[sort_by] = sort_order.lower()
|
|
||||||
|
|
||||||
users = await prisma_client.db.litellm_usertable.find_many(
|
users = await prisma_client.db.litellm_usertable.find_many(
|
||||||
where=where_conditions,
|
where=where_conditions,
|
||||||
|
|
|
@ -153,3 +153,19 @@ async def test_get_users_includes_timestamps(mocker):
|
||||||
assert user_response.created_at == mock_user_data["created_at"]
|
assert user_response.created_at == mock_user_data["created_at"]
|
||||||
assert user_response.updated_at == mock_user_data["updated_at"]
|
assert user_response.updated_at == mock_user_data["updated_at"]
|
||||||
assert user_response.key_count == 0
|
assert user_response.key_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_sort_params():
|
||||||
|
"""
|
||||||
|
Test that validate_sort_params returns None if sort_by is None
|
||||||
|
"""
|
||||||
|
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
||||||
|
_validate_sort_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert _validate_sort_params(None, "asc") is None
|
||||||
|
assert _validate_sort_params(None, "desc") is None
|
||||||
|
assert _validate_sort_params("user_id", "asc") == {"user_id": "asc"}
|
||||||
|
assert _validate_sort_params("user_id", "desc") == {"user_id": "desc"}
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
_validate_sort_params("user_id", "invalid")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue