UI - Users page - Enable global sorting (allows finding users with highest spend) (#10211)

* fix(view_users.tsx): add time tracking logic to debounce search - prevent new queries from being overwritten by previous ones

* fix(internal_user_endpoints.py): add sort functionality to user list endpoint

* feat(internal_user_endpoints.py): support sort by on `/user/list`

* fix(view_users.tsx): enable global sorting

allows finding user with highest spend

* feat(view_users.tsx): support filtering by sso user id

* test(search_users.spec.ts): add tests to ensure filtering works

* test: add more unit testing
This commit is contained in:
Krish Dholakia 2025-04-22 19:59:53 -07:00 committed by Christian Owusu
parent 7c8a9e216b
commit aba37e9f56
6 changed files with 287 additions and 17 deletions

View file

@ -902,6 +902,42 @@ async def get_user_key_counts(
return result return result
def _validate_sort_params(
sort_by: Optional[str], sort_order: str
) -> Optional[Dict[str, str]]:
order_by: Dict[str, str] = {}
if sort_by is None:
return None
# Validate sort_by is a valid column
valid_columns = [
"user_id",
"user_email",
"created_at",
"spend",
"user_alias",
"user_role",
]
if sort_by not in valid_columns:
raise HTTPException(
status_code=400,
detail={
"error": f"Invalid sort column. Must be one of: {', '.join(valid_columns)}"
},
)
# Validate sort_order
if sort_order.lower() not in ["asc", "desc"]:
raise HTTPException(
status_code=400,
detail={"error": "Invalid sort order. Must be 'asc' or 'desc'"},
)
order_by[sort_by] = sort_order.lower()
return order_by
@router.get( @router.get(
"/user/list", "/user/list",
tags=["Internal User management"], tags=["Internal User management"],
@ -915,6 +951,9 @@ async def get_users(
user_ids: Optional[str] = fastapi.Query( user_ids: Optional[str] = fastapi.Query(
default=None, description="Get list of users by user_ids" default=None, description="Get list of users by user_ids"
), ),
sso_user_ids: Optional[str] = fastapi.Query(
default=None, description="Get list of users by sso_user_id"
),
user_email: Optional[str] = fastapi.Query( user_email: Optional[str] = fastapi.Query(
default=None, description="Filter users by partial email match" default=None, description="Filter users by partial email match"
), ),
@ -925,9 +964,16 @@ async def get_users(
page_size: int = fastapi.Query( page_size: int = fastapi.Query(
default=25, ge=1, le=100, description="Number of items per page" default=25, ge=1, le=100, description="Number of items per page"
), ),
sort_by: Optional[str] = fastapi.Query(
default=None,
description="Column to sort by (e.g. 'user_id', 'user_email', 'created_at', 'spend')",
),
sort_order: str = fastapi.Query(
default="asc", description="Sort order ('asc' or 'desc')"
),
): ):
""" """
Get a paginated list of users with filtering options. Get a paginated list of users with filtering and sorting options.
Parameters: Parameters:
role: Optional[str] role: Optional[str]
@ -938,6 +984,8 @@ async def get_users(
- internal_user_viewer - internal_user_viewer
user_ids: Optional[str] user_ids: Optional[str]
Get list of users by user_ids. Comma separated list of user_ids. Get list of users by user_ids. Comma separated list of user_ids.
sso_ids: Optional[str]
Get list of users by sso_ids. Comma separated list of sso_ids.
user_email: Optional[str] user_email: Optional[str]
Filter users by partial email match Filter users by partial email match
team: Optional[str] team: Optional[str]
@ -946,9 +994,10 @@ async def get_users(
The page number to return The page number to return
page_size: int page_size: int
The number of items per page The number of items per page
sort_by: Optional[str]
Returns: Column to sort by (e.g. 'user_id', 'user_email', 'created_at', 'spend')
UserListResponse with filtered and paginated users sort_order: Optional[str]
Sort order ('asc' or 'desc')
""" """
from litellm.proxy.proxy_server import prisma_client from litellm.proxy.proxy_server import prisma_client
@ -984,13 +1033,25 @@ async def get_users(
"has": team # Array contains for string arrays in Prisma "has": team # Array contains for string arrays in Prisma
} }
if sso_user_ids is not None and isinstance(sso_user_ids, str):
sso_id_list = [sid.strip() for sid in sso_user_ids.split(",") if sid.strip()]
where_conditions["sso_user_id"] = {
"in": sso_id_list,
}
## Filter any none fastapi.Query params - e.g. where_conditions: {'user_email': {'contains': Query(None), 'mode': 'insensitive'}, 'teams': {'has': Query(None)}} ## Filter any none fastapi.Query params - e.g. where_conditions: {'user_email': {'contains': Query(None), 'mode': 'insensitive'}, 'teams': {'has': Query(None)}}
where_conditions = {k: v for k, v in where_conditions.items() if v is not None} where_conditions = {k: v for k, v in where_conditions.items() if v is not None}
# Build order_by conditions
order_by: Optional[Dict[str, str]] = _validate_sort_params(sort_by, sort_order)
users = await prisma_client.db.litellm_usertable.find_many( users = await prisma_client.db.litellm_usertable.find_many(
where=where_conditions, where=where_conditions,
skip=skip, skip=skip,
take=page_size, take=page_size,
order={"created_at": "desc"}, order=order_by
if order_by
else {"created_at": "desc"}, # Default to created_at desc if no sort specified
) )
# Get total count of user rows # Get total count of user rows

View file

@ -153,3 +153,19 @@ async def test_get_users_includes_timestamps(mocker):
assert user_response.created_at == mock_user_data["created_at"] assert user_response.created_at == mock_user_data["created_at"]
assert user_response.updated_at == mock_user_data["updated_at"] assert user_response.updated_at == mock_user_data["updated_at"]
assert user_response.key_count == 0 assert user_response.key_count == 0
def test_validate_sort_params():
"""
Test that validate_sort_params returns None if sort_by is None
"""
from litellm.proxy.management_endpoints.internal_user_endpoints import (
_validate_sort_params,
)
assert _validate_sort_params(None, "asc") is None
assert _validate_sort_params(None, "desc") is None
assert _validate_sort_params("user_id", "asc") == {"user_id": "asc"}
assert _validate_sort_params("user_id", "desc") == {"user_id": "desc"}
with pytest.raises(Exception):
_validate_sort_params("user_id", "invalid")

View file

@ -7,6 +7,7 @@ Tests:
2. Verify search input exists 2. Verify search input exists
3. Test search functionality 3. Test search functionality
4. Verify results update 4. Verify results update
5. Test filtering by email, user ID, and SSO user ID
*/ */
import { test, expect } from "@playwright/test"; import { test, expect } from "@playwright/test";
@ -61,7 +62,7 @@ test("user search test", async ({ page }) => {
console.log("Clicked Internal User tab"); console.log("Clicked Internal User tab");
// Wait for the page to load and table to be visible // Wait for the page to load and table to be visible
await page.waitForSelector("tbody tr", { timeout: 10000 }); await page.waitForSelector("tbody tr", { timeout: 30000 });
await page.waitForTimeout(2000); // Additional wait for table to stabilize await page.waitForTimeout(2000); // Additional wait for table to stabilize
console.log("Table is visible"); console.log("Table is visible");
@ -117,3 +118,97 @@ test("user search test", async ({ page }) => {
expect(resetUserCount).toBe(initialUserCount); expect(resetUserCount).toBe(initialUserCount);
}); });
test("user filter test", async ({ page }) => {
// Set a longer timeout for the entire test
test.setTimeout(60000);
// Enable console logging
page.on("console", (msg) => console.log("PAGE LOG:", msg.text()));
// Login first
await page.goto("http://localhost:4000/ui");
console.log("Navigated to login page");
// Wait for login form to be visible
await page.waitForSelector('input[name="username"]', { timeout: 10000 });
console.log("Login form is visible");
await page.fill('input[name="username"]', "admin");
await page.fill('input[name="password"]', "gm");
console.log("Filled login credentials");
const loginButton = page.locator('input[type="submit"]');
await expect(loginButton).toBeEnabled();
await loginButton.click();
console.log("Clicked login button");
// Wait for navigation to complete and dashboard to load
await page.waitForLoadState("networkidle");
console.log("Page loaded after login");
// Navigate to Internal Users tab
const internalUserTab = page.locator("span.ant-menu-title-content", {
hasText: "Internal User",
});
await internalUserTab.waitFor({ state: "visible", timeout: 10000 });
await internalUserTab.click();
console.log("Clicked Internal User tab");
// Wait for the page to load and table to be visible
await page.waitForSelector("tbody tr", { timeout: 30000 });
await page.waitForTimeout(2000); // Additional wait for table to stabilize
console.log("Table is visible");
// Get initial user count
const initialUserCount = await page.locator("tbody tr").count();
console.log(`Initial user count: ${initialUserCount}`);
// Click the filter button to show additional filters
const filterButton = page.getByRole("button", {
name: "Filters",
exact: true,
});
await filterButton.click();
console.log("Clicked filter button");
await page.waitForTimeout(500); // Wait for filters to appear
// Test user ID filter
const userIdInput = page.locator('input[placeholder="Filter by User ID"]');
await expect(userIdInput).toBeVisible();
console.log("User ID filter is visible");
await userIdInput.fill("user");
console.log("Filled user ID filter");
await page.waitForTimeout(1000);
const userIdFilteredCount = await page.locator("tbody tr").count();
console.log(`User ID filtered count: ${userIdFilteredCount}`);
expect(userIdFilteredCount).toBeLessThan(initialUserCount);
// Clear user ID filter
await userIdInput.clear();
await page.waitForTimeout(1000);
console.log("Cleared user ID filter");
// Test SSO user ID filter
const ssoUserIdInput = page.locator('input[placeholder="Filter by SSO ID"]');
await expect(ssoUserIdInput).toBeVisible();
console.log("SSO user ID filter is visible");
await ssoUserIdInput.fill("sso");
console.log("Filled SSO user ID filter");
await page.waitForTimeout(1000);
const ssoUserIdFilteredCount = await page.locator("tbody tr").count();
console.log(`SSO user ID filtered count: ${ssoUserIdFilteredCount}`);
expect(ssoUserIdFilteredCount).toBeLessThan(initialUserCount);
// Clear SSO user ID filter
await ssoUserIdInput.clear();
await page.waitForTimeout(5000);
console.log("Cleared SSO user ID filter");
// Verify count returns to initial after clearing all filters
const finalUserCount = await page.locator("tbody tr").count();
console.log(`Final user count: ${finalUserCount}`);
expect(finalUserCount).toBe(initialUserCount);
});

View file

@ -679,6 +679,9 @@ export const userListCall = async (
userEmail: string | null = null, userEmail: string | null = null,
userRole: string | null = null, userRole: string | null = null,
team: string | null = null, team: string | null = null,
sso_user_id: string | null = null,
sortBy: string | null = null,
sortOrder: 'asc' | 'desc' | null = null,
) => { ) => {
/** /**
* Get all available teams on proxy * Get all available teams on proxy
@ -714,6 +717,18 @@ export const userListCall = async (
queryParams.append('team', team); queryParams.append('team', team);
} }
if (sso_user_id) {
queryParams.append('sso_user_ids', sso_user_id);
}
if (sortBy) {
queryParams.append('sort_by', sortBy);
}
if (sortOrder) {
queryParams.append('sort_order', sortOrder);
}
const queryString = queryParams.toString(); const queryString = queryParams.toString();
if (queryString) { if (queryString) {
url += `?${queryString}`; url += `?${queryString}`;

View file

@ -84,10 +84,13 @@ interface FilterState {
email: string; email: string;
user_id: string; user_id: string;
user_role: string; user_role: string;
sso_user_id: string;
team: string; team: string;
model: string; model: string;
min_spend: number | null; min_spend: number | null;
max_spend: number | null; max_spend: number | null;
sort_by: string;
sort_order: 'asc' | 'desc';
} }
const isLocal = process.env.NODE_ENV === "development"; const isLocal = process.env.NODE_ENV === "development";
@ -124,15 +127,19 @@ const ViewUserDashboard: React.FC<ViewUserDashboardProps> = ({
email: "", email: "",
user_id: "", user_id: "",
user_role: "", user_role: "",
sso_user_id: "",
team: "", team: "",
model: "", model: "",
min_spend: null, min_spend: null,
max_spend: null max_spend: null,
sort_by: "created_at",
sort_order: "desc"
}); });
const [showFilters, setShowFilters] = useState(false); const [showFilters, setShowFilters] = useState(false);
const [showColumnDropdown, setShowColumnDropdown] = useState(false); const [showColumnDropdown, setShowColumnDropdown] = useState(false);
const [selectedFilter, setSelectedFilter] = useState("Email"); const [selectedFilter, setSelectedFilter] = useState("Email");
const filtersRef = useRef(null); const filtersRef = useRef(null);
const lastSearchTimestamp = useRef(0);
// check if window is not undefined // check if window is not undefined
if (typeof window !== "undefined") { if (typeof window !== "undefined") {
@ -150,6 +157,17 @@ const ViewUserDashboard: React.FC<ViewUserDashboardProps> = ({
const handleFilterChange = (key: keyof FilterState, value: string | number | null) => { const handleFilterChange = (key: keyof FilterState, value: string | number | null) => {
const newFilters = { ...filters, [key]: value }; const newFilters = { ...filters, [key]: value };
setFilters(newFilters); setFilters(newFilters);
console.log("called from handleFilterChange - newFilters:", JSON.stringify(newFilters));
debouncedSearch(newFilters);
};
const handleSortChange = (sortBy: string, sortOrder: 'asc' | 'desc') => {
const newFilters = {
...filters,
sort_by: sortBy,
sort_order: sortOrder
};
setFilters(newFilters);
debouncedSearch(newFilters); debouncedSearch(newFilters);
}; };
@ -159,6 +177,10 @@ const ViewUserDashboard: React.FC<ViewUserDashboardProps> = ({
if (!accessToken || !token || !userRole || !userID) { if (!accessToken || !token || !userRole || !userID) {
return; return;
} }
const currentTimestamp = Date.now();
lastSearchTimestamp.current = currentTimestamp;
try { try {
// Make the API call using userListCall with all filter parameters // Make the API call using userListCall with all filter parameters
const data = await userListCall( const data = await userListCall(
@ -168,12 +190,19 @@ const ViewUserDashboard: React.FC<ViewUserDashboardProps> = ({
defaultPageSize, defaultPageSize,
filters.email || null, filters.email || null,
filters.user_role || null, filters.user_role || null,
filters.team || null filters.team || null,
filters.sso_user_id || null,
filters.sort_by,
filters.sort_order
); );
if (data) { // Only update state if this is the most recent search
setUserListResponse(data); if (currentTimestamp === lastSearchTimestamp.current) {
console.log("called from debouncedSearch"); if (data) {
setUserListResponse(data);
console.log("called from debouncedSearch filters:", JSON.stringify(filters));
console.log("called from debouncedSearch data:", JSON.stringify(data));
}
} }
} catch (error) { } catch (error) {
console.error("Error searching users:", error); console.error("Error searching users:", error);
@ -252,6 +281,7 @@ const ViewUserDashboard: React.FC<ViewUserDashboardProps> = ({
}; };
const refreshUserData = async () => { const refreshUserData = async () => {
console.log("called from refreshUserData");
if (!accessToken || !token || !userRole || !userID) { if (!accessToken || !token || !userRole || !userID) {
return; return;
} }
@ -291,7 +321,10 @@ const ViewUserDashboard: React.FC<ViewUserDashboardProps> = ({
defaultPageSize, defaultPageSize,
filters.email || null, filters.email || null,
filters.user_role || null, filters.user_role || null,
filters.team || null filters.team || null,
filters.sso_user_id || null,
filters.sort_by,
filters.sort_order
); );
// Update session storage with new data // Update session storage with new data
@ -328,7 +361,10 @@ const ViewUserDashboard: React.FC<ViewUserDashboardProps> = ({
defaultPageSize, defaultPageSize,
filters.email || null, filters.email || null,
filters.user_role || null, filters.user_role || null,
filters.team || null filters.team || null,
filters.sso_user_id || null,
filters.sort_by,
filters.sort_order
); );
// Store in session storage // Store in session storage
@ -462,9 +498,12 @@ const ViewUserDashboard: React.FC<ViewUserDashboardProps> = ({
user_id: "", user_id: "",
user_role: "", user_role: "",
team: "", team: "",
sso_user_id: "",
model: "", model: "",
min_spend: null, min_spend: null,
max_spend: null max_spend: null,
sort_by: "created_at",
sort_order: "desc"
}); });
}} }}
> >
@ -541,6 +580,17 @@ const ViewUserDashboard: React.FC<ViewUserDashboardProps> = ({
))} ))}
</Select> </Select>
</div> </div>
{/* SSO ID Search */}
<div className="relative w-64">
<input
type="text"
placeholder="Filter by SSO ID"
className="w-full px-3 py-2 pl-8 border rounded-md text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
value={filters.sso_user_id}
onChange={(e) => handleFilterChange('sso_user_id', e.target.value)}
/>
</div>
</div> </div>
)} )}
@ -591,9 +641,14 @@ const ViewUserDashboard: React.FC<ViewUserDashboardProps> = ({
</div> </div>
<UserDataTable <UserDataTable
data={userListResponse.users || []} data={userListResponse?.users || []}
columns={tableColumns} columns={tableColumns}
isLoading={!userListResponse} isLoading={!userListResponse}
onSortChange={handleSortChange}
currentSort={{
sortBy: filters.sort_by,
sortOrder: filters.sort_order
}}
/> />
</div> </div>
</TabPanel> </TabPanel>

View file

@ -23,15 +23,25 @@ interface UserDataTableProps {
data: UserInfo[]; data: UserInfo[];
columns: ColumnDef<UserInfo, any>[]; columns: ColumnDef<UserInfo, any>[];
isLoading?: boolean; isLoading?: boolean;
onSortChange?: (sortBy: string, sortOrder: 'asc' | 'desc') => void;
currentSort?: {
sortBy: string;
sortOrder: 'asc' | 'desc';
};
} }
export function UserDataTable({ export function UserDataTable({
data = [], data = [],
columns, columns,
isLoading = false, isLoading = false,
onSortChange,
currentSort,
}: UserDataTableProps) { }: UserDataTableProps) {
const [sorting, setSorting] = React.useState<SortingState>([ const [sorting, setSorting] = React.useState<SortingState>([
{ id: "created_at", desc: true } {
id: currentSort?.sortBy || "created_at",
desc: currentSort?.sortOrder === "desc"
}
]); ]);
const table = useReactTable({ const table = useReactTable({
@ -40,12 +50,30 @@ export function UserDataTable({
state: { state: {
sorting, sorting,
}, },
onSortingChange: setSorting, onSortingChange: (newSorting: any) => {
setSorting(newSorting);
if (newSorting.length > 0) {
const sortState = newSorting[0];
const sortBy = sortState.id;
const sortOrder = sortState.desc ? 'desc' : 'asc';
onSortChange?.(sortBy, sortOrder);
}
},
getCoreRowModel: getCoreRowModel(), getCoreRowModel: getCoreRowModel(),
getSortedRowModel: getSortedRowModel(), getSortedRowModel: getSortedRowModel(),
enableSorting: true, enableSorting: true,
}); });
// Update local sorting state when currentSort prop changes
React.useEffect(() => {
if (currentSort) {
setSorting([{
id: currentSort.sortBy,
desc: currentSort.sortOrder === 'desc'
}]);
}
}, [currentSort]);
return ( return (
<div className="rounded-lg custom-border relative"> <div className="rounded-lg custom-border relative">
<div className="overflow-x-auto"> <div className="overflow-x-auto">