diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index f6fd1a4b86..d2f20e56cf 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1578,6 +1578,10 @@ class LiteLLM_UserTableFiltered(BaseModel): # done to avoid exposing sensitive user_email: str +class LiteLLM_UserTableWithKeyCount(LiteLLM_UserTable): + key_count: int = 0 + + class LiteLLM_EndUserTable(LiteLLMPydanticObjectBase): user_id: str blocked: bool diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index 1d381ab145..a414f48847 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -753,6 +753,9 @@ async def get_users( role: Optional[str] = fastapi.Query( default=None, description="Filter users by role" ), + user_ids: Optional[str] = fastapi.Query( + default=None, description="Get list of users by user_ids" + ), page: int = fastapi.Query(default=1, ge=1, description="Page number"), page_size: int = fastapi.Query( default=25, ge=1, le=100, description="Number of items per page" @@ -770,12 +773,19 @@ async def get_users( - proxy_admin_viewer - internal_user - internal_user_viewer + user_ids: Optional[str] + Get list of users by user_ids. Comma separated list of user_ids. page: int The page number to return page_size: int The number of items per page Currently - admin-only endpoint. + + Example curl: + ``` + http://0.0.0.0:4000/user/list?user_ids=default_user_id,693c1a4a-1cc0-4c7c-afe8-b5d2c8d52e17 + ``` """ from litellm.proxy.proxy_server import prisma_client @@ -787,49 +797,69 @@ async def get_users( # Calculate skip and take for pagination skip = (page - 1) * page_size - take = page_size # Prepare the query conditions - where_clause = "" + # Build where conditions based on provided parameters + where_conditions: Dict[str, Any] = {} + if role: - where_clause = f"""WHERE "user_role" = '{role}'""" + where_conditions["user_role"] = { + "contains": role, + "mode": "insensitive", # Case-insensitive search + } - # Single optimized SQL query that gets both users and total count - sql_query = f""" - WITH total_users AS ( - SELECT COUNT(*) AS total_number_internal_users - FROM "LiteLLM_UserTable" - ), - paginated_users AS ( - SELECT - u.*, - ( - SELECT COUNT(*) - FROM "LiteLLM_VerificationToken" vt - WHERE vt."user_id" = u."user_id" - ) AS key_count - FROM "LiteLLM_UserTable" u - {where_clause} - LIMIT {take} OFFSET {skip} + if user_ids and isinstance(user_ids, str): + user_id_list = [uid.strip() for uid in user_ids.split(",") if uid.strip()] + where_conditions["user_id"] = { + "in": user_id_list, # Now passing a list of strings as required by Prisma + } + + users: Optional[List[LiteLLM_UserTable]] = ( + await prisma_client.db.litellm_usertable.find_many( + where=where_conditions, + skip=skip, + take=page_size, + order={"created_at": "desc"}, + ) ) - SELECT - (SELECT total_number_internal_users FROM total_users), - * - FROM paginated_users; - """ - # Execute the query - results = await prisma_client.db.query_raw(sql_query) - # Get total count from the first row (if results exist) - total_count = 0 - if len(results) > 0: - total_count = results[0].get("total_number_internal_users") + # Get total count of user rows + total_count = await prisma_client.db.litellm_usertable.count( + where=where_conditions # type: ignore + ) + + # Get key count for each user + if users is not None: + user_keys = await prisma_client.db.litellm_verificationtoken.group_by( + by=["user_id"], + count={"user_id": True}, + where={"user_id": {"in": [user.user_id for user in users]}}, + ) + user_key_counts = { + item["user_id"]: item["_count"]["user_id"] for item in user_keys + } + else: + user_key_counts = {} + + verbose_proxy_logger.debug(f"Total count of users: {total_count}") # Calculate total pages total_pages = -(-total_count // page_size) # Ceiling division + # Prepare response + user_list: List[LiteLLM_UserTableWithKeyCount] = [] + if users is not None: + for user in users: + user_list.append( + LiteLLM_UserTableWithKeyCount( + **user.model_dump(), key_count=user_key_counts.get(user.user_id, 0) + ) + ) # Return full key object + else: + user_list = [] + return { - "users": results, + "users": user_list, "total": total_count, "page": page, "page_size": page_size, diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py index 6935344070..9d6c24db0e 100644 --- a/tests/proxy_admin_ui_tests/test_key_management.py +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -370,11 +370,7 @@ async def test_get_users(prisma_client): assert "users" in result for user in result["users"]: - assert "user_id" in user - assert "spend" in user - assert "user_email" in user - assert "user_role" in user - assert "key_count" in user + assert isinstance(user, LiteLLM_UserTable) # Clean up test users for user in test_users: @@ -397,12 +393,12 @@ async def test_get_users_key_count(prisma_client): assert len(initial_users["users"]) > 0, "No users found to test with" test_user = initial_users["users"][0] - initial_key_count = test_user["key_count"] + initial_key_count = test_user.key_count # Create a new key for the selected user new_key = await generate_key_fn( data=GenerateKeyRequest( - user_id=test_user["user_id"], + user_id=test_user.user_id, key_alias=f"test_key_{uuid.uuid4()}", models=["fake-model"], ), @@ -418,8 +414,8 @@ async def test_get_users_key_count(prisma_client): print("updated_users", updated_users) updated_key_count = None for user in updated_users["users"]: - if user["user_id"] == test_user["user_id"]: - updated_key_count = user["key_count"] + if user.user_id == test_user.user_id: + updated_key_count = user.key_count break assert updated_key_count is not None, "Test user not found in updated users list" diff --git a/ui/litellm-dashboard/src/components/all_keys_table.tsx b/ui/litellm-dashboard/src/components/all_keys_table.tsx index 008ac1c878..274567b732 100644 --- a/ui/litellm-dashboard/src/components/all_keys_table.tsx +++ b/ui/litellm-dashboard/src/components/all_keys_table.tsx @@ -1,5 +1,5 @@ "use client"; -import React, { useState } from "react"; +import React, { useEffect, useState } from "react"; import { ColumnDef, Row } from "@tanstack/react-table"; import { DataTable } from "./view_logs/table"; import { Select, SelectItem } from "@tremor/react" @@ -9,7 +9,7 @@ import { Tooltip } from "antd"; import { Team, KeyResponse } from "./key_team_helpers/key_list"; import FilterComponent from "./common_components/filter"; import { FilterOption } from "./common_components/filter"; -import { Organization } from "./networking"; +import { Organization, userListCall } from "./networking"; import { createTeamSearchFunction } from "./key_team_helpers/team_search_fn"; import { createOrgSearchFunction } from "./key_team_helpers/organization_search_fn"; interface AllKeysTableProps { @@ -34,6 +34,12 @@ interface AllKeysTableProps { // Define columns similar to our logs table +interface UserResponse { + user_id: string; + user_email: string; + user_role: string; +} + const TeamFilter = ({ teams, selectedTeam, @@ -99,6 +105,18 @@ export function AllKeysTable({ 'Team ID': '', 'Organization ID': '' }); + const [userList, setUserList] = useState([]); + + useEffect(() => { + if (accessToken) { + const user_IDs = keys.map(key => key.user_id).filter(id => id !== null); + const fetchUserList = async () => { + const userListData = await userListCall(accessToken, user_IDs, 1, 100); + setUserList(userListData.users); + }; + fetchUserList(); + } + }, [accessToken, keys]); const handleFilterChange = (newFilters: Record) => { // Update filters state @@ -163,7 +181,7 @@ export function AllKeysTable({ className="font-mono text-blue-500 bg-blue-50 hover:bg-blue-100 text-xs font-normal px-2 py-0.5 text-left overflow-hidden truncate max-w-[200px]" onClick={() => setSelectedKeyId(info.getValue() as string)} > - {info.getValue() ? `${(info.getValue() as string).slice(0, 7)}...` : "Not Set"} + {info.getValue() ? `${(info.getValue() as string).slice(0, 7)}...` : "-"} @@ -186,31 +204,28 @@ export function AllKeysTable({ { header: "Team ID", accessorKey: "team_id", - cell: (info) => {info.getValue() ? `${(info.getValue() as string).slice(0, 7)}...` : "Not Set"} + cell: (info) => {info.getValue() ? `${(info.getValue() as string).slice(0, 7)}...` : "-"} }, { header: "Key Alias", accessorKey: "key_alias", - cell: (info) => {info.getValue() ? `${(info.getValue() as string).slice(0, 7)}...` : "Not Set"} + cell: (info) => {info.getValue() ? `${(info.getValue() as string).slice(0, 7)}...` : "-"} }, { header: "Organization ID", accessorKey: "organization_id", - cell: (info) => {info.getValue() ? `${(info.getValue() as string).slice(0, 7)}...` : "Not Set"} + cell: (info) => info.getValue() ? info.renderValue() : "-", + }, + { + header: "User Email", + accessorKey: "user_id", + cell: (info) => { + const userId = info.getValue() as string; + const user = userList.find(u => u.user_id === userId); + return user?.user_email ? user.user_email : "-"; + }, }, - // { - // header: "User Email", - // accessorKey: "user_id", - // cell: (info) => { - // const userId = info.getValue() as string; - // return userId ? ( - // - // {userId.slice(0, 5)}... - // - // ) : "Not Set"; - // }, - // }, { header: "User ID", accessorKey: "user_id", @@ -218,9 +233,9 @@ export function AllKeysTable({ const userId = info.getValue() as string; return userId ? ( - {userId.slice(0, 5)}... + {userId.slice(0, 7)}... - ) : "Not Set"; + ) : "-"; }, }, { diff --git a/ui/litellm-dashboard/src/components/create_user_button.tsx b/ui/litellm-dashboard/src/components/create_user_button.tsx index 87cb8cc3d1..6cae581af5 100644 --- a/ui/litellm-dashboard/src/components/create_user_button.tsx +++ b/ui/litellm-dashboard/src/components/create_user_button.tsx @@ -21,6 +21,8 @@ import { } from "./networking"; import BulkCreateUsers from "./bulk_create_users_button"; const { Option } = Select; +import { Tooltip } from "antd"; +import { InfoCircleOutlined } from '@ant-design/icons'; interface CreateuserProps { userID: string; @@ -258,7 +260,15 @@ const Createuser: React.FC = ({ - + + Global Proxy Role{' '} + + + + + } + name="user_role"> {possibleUIRoles && Object.entries(possibleUIRoles).map( @@ -278,7 +288,7 @@ const Createuser: React.FC = ({ )} - +