diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index 65bc4ebdaf..b69969bdee 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -902,6 +902,42 @@ async def get_user_key_counts( return result +def _validate_sort_params( + sort_by: Optional[str], sort_order: str +) -> Optional[Dict[str, str]]: + order_by: Dict[str, str] = {} + + if sort_by is None: + return None + # Validate sort_by is a valid column + valid_columns = [ + "user_id", + "user_email", + "created_at", + "spend", + "user_alias", + "user_role", + ] + if sort_by not in valid_columns: + raise HTTPException( + status_code=400, + detail={ + "error": f"Invalid sort column. Must be one of: {', '.join(valid_columns)}" + }, + ) + + # Validate sort_order + if sort_order.lower() not in ["asc", "desc"]: + raise HTTPException( + status_code=400, + detail={"error": "Invalid sort order. Must be 'asc' or 'desc'"}, + ) + + order_by[sort_by] = sort_order.lower() + + return order_by + + @router.get( "/user/list", tags=["Internal User management"], @@ -996,33 +1032,7 @@ async def get_users( where_conditions = {k: v for k, v in where_conditions.items() if v is not None} # Build order_by conditions - order_by: Dict[str, str] = {} - if sort_by: - # Validate sort_by is a valid column - valid_columns = [ - "user_id", - "user_email", - "created_at", - "spend", - "user_alias", - "user_role", - ] - if sort_by not in valid_columns: - raise HTTPException( - status_code=400, - detail={ - "error": f"Invalid sort column. Must be one of: {', '.join(valid_columns)}" - }, - ) - - # Validate sort_order - if sort_order.lower() not in ["asc", "desc"]: - raise HTTPException( - status_code=400, - detail={"error": "Invalid sort order. Must be 'asc' or 'desc'"}, - ) - - order_by[sort_by] = sort_order.lower() + order_by: Optional[Dict[str, str]] = _validate_sort_params(sort_by, sort_order) users = await prisma_client.db.litellm_usertable.find_many( where=where_conditions, diff --git a/tests/litellm/proxy/management_endpoints/test_internal_user_endpoints.py b/tests/litellm/proxy/management_endpoints/test_internal_user_endpoints.py index deef94c15a..360f21f171 100644 --- a/tests/litellm/proxy/management_endpoints/test_internal_user_endpoints.py +++ b/tests/litellm/proxy/management_endpoints/test_internal_user_endpoints.py @@ -153,3 +153,19 @@ async def test_get_users_includes_timestamps(mocker): assert user_response.created_at == mock_user_data["created_at"] assert user_response.updated_at == mock_user_data["updated_at"] assert user_response.key_count == 0 + + +def test_validate_sort_params(): + """ + Test that validate_sort_params returns None if sort_by is None + """ + from litellm.proxy.management_endpoints.internal_user_endpoints import ( + _validate_sort_params, + ) + + assert _validate_sort_params(None, "asc") is None + assert _validate_sort_params(None, "desc") is None + assert _validate_sort_params("user_id", "asc") == {"user_id": "asc"} + assert _validate_sort_params("user_id", "desc") == {"user_id": "desc"} + with pytest.raises(Exception): + _validate_sort_params("user_id", "invalid")