[SSO-UI] Set new sso users as internal_view role users (#5824)

* use /user/list endpoint on admin ui

* sso insert user with role when user does not exist

* add sso sign in test

* linting fix

* rename self serve doc

* add doc for self serve flow

* test - sso sign in default values

* add test for /user/list endpoint
This commit is contained in:
Ishaan Jaff 2024-09-21 16:43:52 -07:00 committed by GitHub
parent a9caba33ef
commit d100b32573
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 404 additions and 102 deletions

View file

@ -2,7 +2,7 @@ import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Self-Serve
# Internal User Self-Serve
## Allow users to create their own keys on [Proxy UI](./ui.md).
@ -191,6 +191,29 @@ This budget only applies to personal keys created by that user - seen under `Def
This budget does not apply to keys created under non-default teams.
### Set max budget for teams
[**Go Here**](./team_budgets.md)
[**Go Here**](./team_budgets.md)
## **All Settings for Self Serve / SSO Flow**
```yaml
litellm_settings:
max_internal_user_budget: 10 # max budget for internal users
internal_user_budget_duration: "1mo" # reset every month
default_internal_user_params: # Default Params used when a new user signs in Via SSO
user_role: "internal_user" # one of "internal_user", "internal_user_viewer", "proxy_admin", "proxy_admin_viewer". New SSO users not in litellm will be created as this user
max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user
upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on
max_budget: 100 # Optional[float], optional): upperbound of $100, for all /key/generate requests
budget_duration: "10d" # Optional[str], optional): upperbound of 10 days for budget_duration values
duration: "30d" # Optional[str], optional): upperbound of 30 days for all /key/generate requests
max_parallel_requests: 1000 # (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None.
tpm_limit: 1000 #(Optional[int], optional): Tpm limit. Defaults to None.
rpm_limit: 1000 #(Optional[int], optional): Rpm limit. Defaults to None.
```

View file

@ -257,7 +257,7 @@ s3_callback_params: Optional[Dict] = None
generic_logger_headers: Optional[Dict] = None
default_key_generate_params: Optional[Dict] = None
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
default_user_params: Optional[Dict] = None
default_internal_user_params: Optional[Dict] = None
default_team_settings: Optional[List] = None
max_user_budget: Optional[float] = None
default_max_internal_user_budget: Optional[float] = None

View file

@ -8,6 +8,7 @@ These are members of a Team on LiteLLM
/user/update
/user/delete
/user/info
/user/list
"""
import asyncio
@ -298,10 +299,6 @@ async def user_info(
user_id: Optional[str] = fastapi.Query(
default=None, description="User ID in the request parameters"
),
view_all: bool = fastapi.Query(
default=False,
description="set to true to View all users. When using view_all, don't pass user_id",
),
page: Optional[int] = fastapi.Query(
default=0,
description="Page number for pagination. Only use when view_all is true",
@ -335,17 +332,6 @@ async def user_info(
## GET USER ROW ##
if user_id is not None:
user_info = await prisma_client.get_data(user_id=user_id)
elif view_all is True:
if page is None:
page = 0
if page_size is None:
page_size = 25
offset = (page) * page_size # default is 0
limit = page_size # default is 10
user_info = await prisma_client.get_data(
table_name="user", query_type="find_all", offset=offset, limit=limit
)
return user_info
else:
user_info = None
## GET ALL TEAMS ##
@ -732,16 +718,22 @@ async def user_get_requests():
tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)],
)
@router.get(
"/user/list",
tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)],
)
async def get_users(
role: str = fastapi.Query(
default=None,
description="Either 'proxy_admin', 'proxy_viewer', 'app_owner', 'app_user'",
)
role: Optional[str] = fastapi.Query(
default=None, description="Filter users by role"
),
page: int = fastapi.Query(default=1, ge=1, description="Page number"),
page_size: int = fastapi.Query(
default=25, ge=1, le=100, description="Number of items per page"
),
):
"""
[BETA] This could change without notice. Give feedback - https://github.com/BerriAI/litellm/issues
Get all users who are a specific `user_role`.
Get a paginated list of users, optionally filtered by role.
Used by the UI to populate the user lists.
@ -754,11 +746,36 @@ async def get_users(
status_code=500,
detail={"error": f"No db connected. prisma client={prisma_client}"},
)
all_users = await prisma_client.get_data(
table_name="user", query_type="find_all", key_val={"user_role": role}
# Calculate skip and take for pagination
skip = (page - 1) * page_size
take = page_size
# Prepare the query
query = {}
if role:
query["user_role"] = role
# Get total count
total_count = await prisma_client.db.litellm_usertable.count(where=query) # type: ignore
# Get paginated users
users = await prisma_client.db.litellm_usertable.find_many(
where=query, # type: ignore
skip=skip,
take=take,
)
return all_users
# Calculate total pages
total_pages = -(-total_count // page_size) # Ceiling division
return {
"users": users,
"total": total_count,
"page": page,
"page_size": page_size,
"total_pages": total_pages,
}
@router.post(

View file

@ -17,9 +17,11 @@ import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
LitellmUserRoles,
NewUserRequest,
ProxyErrorTypes,
ProxyException,
SSOUserDefinedValues,
UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.admin_ui_utils import (
@ -27,6 +29,7 @@ from litellm.proxy.common_utils.admin_ui_utils import (
html_form,
show_missing_vars_in_env,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
from litellm.secret_managers.main import str_to_bool
router = APIRouter()
@ -459,7 +462,7 @@ async def auth_callback(request: Request):
if prisma_client is not None:
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
verbose_proxy_logger.debug(
f"user_info: {user_info}; litellm.default_user_params: {litellm.default_user_params}"
f"user_info: {user_info}; litellm.default_internal_user_params: {litellm.default_internal_user_params}"
)
if user_info is None:
## check if user-email in db ##
@ -487,24 +490,11 @@ async def auth_callback(request: Request):
await prisma_client.db.litellm_usertable.update_many(
where={"user_email": user_email}, data={"user_id": user_id} # type: ignore
)
elif litellm.default_user_params is not None and isinstance(
litellm.default_user_params, dict
):
user_defined_values = {
"models": litellm.default_user_params.get("models", user_id_models),
"user_id": litellm.default_user_params.get("user_id", user_id),
"user_email": litellm.default_user_params.get(
"user_email", user_email
),
"user_role": litellm.default_user_params.get("user_role", None),
"max_budget": litellm.default_user_params.get(
"max_budget", max_internal_user_budget
),
"budget_duration": litellm.default_user_params.get(
"budget_duration", internal_user_budget_duration
),
}
else:
# user not in DB, insert User into LiteLLM DB
user_role = await insert_sso_user(
user_defined_values=user_defined_values,
)
except Exception as e:
pass
@ -513,26 +503,6 @@ async def auth_callback(request: Request):
"Unable to map user identity to known values. 'user_defined_values' is None. File an issue - https://github.com/BerriAI/litellm/issues"
)
is_internal_user = False
if (
user_defined_values["user_role"] is not None
and user_defined_values["user_role"] == LitellmUserRoles.INTERNAL_USER.value
):
is_internal_user = True
if (
is_internal_user is True
and user_defined_values["max_budget"] is None
and litellm.max_internal_user_budget is not None
):
user_defined_values["max_budget"] = litellm.max_internal_user_budget
if (
is_internal_user is True
and user_defined_values["budget_duration"] is None
and litellm.internal_user_budget_duration is not None
):
user_defined_values["budget_duration"] = litellm.internal_user_budget_duration
verbose_proxy_logger.info(
f"user_defined_values for creating ui key: {user_defined_values}"
)
@ -541,7 +511,9 @@ async def auth_callback(request: Request):
default_ui_key_values["request_type"] = "key"
response = await generate_key_helper_fn(
**default_ui_key_values, # type: ignore
table_name="key",
)
key = response["token"] # type: ignore
user_id = response["user_id"] # type: ignore
@ -549,19 +521,22 @@ async def auth_callback(request: Request):
# User_id on SSO == user_id in the LiteLLM_VerificationToken Table
assert user_id == _user_id_from_sso
litellm_dashboard_ui = "/ui/"
user_role = user_role or "app_owner"
user_role = user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
if (
os.getenv("PROXY_ADMIN_ID", None) is not None
and os.environ["PROXY_ADMIN_ID"] == user_id
):
# checks if user is admin
user_role = "app_admin"
user_role = LitellmUserRoles.PROXY_ADMIN.value
verbose_proxy_logger.debug(
f"user_role: {user_role}; ui_access_mode: {ui_access_mode}"
)
## CHECK IF ROLE ALLOWED TO USE PROXY ##
if ui_access_mode == "admin_only" and "admin" not in user_role:
if ui_access_mode == "admin_only" and (
user_role != LitellmUserRoles.PROXY_ADMIN.value
or user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value
):
verbose_proxy_logger.debug("EXCEPTION RAISED")
raise HTTPException(
status_code=401,
@ -594,6 +569,47 @@ async def auth_callback(request: Request):
return redirect_response
async def insert_sso_user(
user_defined_values: Optional[SSOUserDefinedValues] = None,
) -> str:
"""
Helper function to create a New User in LiteLLM DB after a successful SSO login
"""
verbose_proxy_logger.debug(
f"Inserting SSO user into DB. User values: {user_defined_values}"
)
if user_defined_values is None:
raise ValueError("user_defined_values is None")
if litellm.default_internal_user_params:
user_defined_values.update(litellm.default_internal_user_params) # type: ignore
# Set budget for internal users
if user_defined_values.get("user_role") == LitellmUserRoles.INTERNAL_USER.value:
if user_defined_values.get("max_budget") is None:
user_defined_values["max_budget"] = litellm.max_internal_user_budget
if user_defined_values.get("budget_duration") is None:
user_defined_values["budget_duration"] = (
litellm.internal_user_budget_duration
)
if user_defined_values["user_role"] is None:
user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
new_user_request = NewUserRequest(
user_id=user_defined_values["user_id"],
user_email=user_defined_values["user_email"],
user_role=user_defined_values["user_role"], # type: ignore
max_budget=user_defined_values["max_budget"],
budget_duration=user_defined_values["budget_duration"],
)
await new_user(data=new_user_request, user_api_key_dict=UserAPIKeyAuth())
return user_defined_values["user_role"] or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
@router.get(
"/sso/get/ui_settings",
tags=["experimental"],

View file

@ -34,7 +34,7 @@ from litellm.proxy.utils import PrismaClient
def get_new_internal_user_defaults(
user_id: str, user_email: Optional[str] = None
) -> dict:
user_info = litellm.default_user_params or {}
user_info = litellm.default_internal_user_params or {}
returned_dict: SSOUserDefinedValues = {
"models": user_info.get("models", None),
@ -277,7 +277,7 @@ def management_endpoint_wrapper(func):
@wraps(func)
async def wrapper(*args, **kwargs):
start_time = datetime.now()
_http_request: Optional[Request] = None
try:
result = await func(*args, **kwargs)
end_time = datetime.now()
@ -293,8 +293,7 @@ def management_endpoint_wrapper(func):
user_api_key_dict=user_api_key_dict,
function_name=func.__name__,
)
_http_request: Request = kwargs.get("http_request")
_http_request = kwargs.get("http_request", None)
parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None)
if parent_otel_span is not None:
from litellm.proxy.proxy_server import open_telemetry_logger
@ -315,7 +314,7 @@ def management_endpoint_wrapper(func):
end_time=end_time,
)
await open_telemetry_logger.async_management_endpoint_success_hook(
await open_telemetry_logger.async_management_endpoint_success_hook( # type: ignore
logging_payload=logging_payload,
parent_otel_span=parent_otel_span,
)
@ -344,7 +343,7 @@ def management_endpoint_wrapper(func):
from litellm.proxy.proxy_server import open_telemetry_logger
if open_telemetry_logger is not None:
_http_request: Request = kwargs.get("http_request")
_http_request = kwargs.get("http_request")
if _http_request:
_route = _http_request.url.path
_request_body: dict = await _read_request_body(
@ -359,7 +358,7 @@ def management_endpoint_wrapper(func):
exception=e,
)
await open_telemetry_logger.async_management_endpoint_failure_hook(
await open_telemetry_logger.async_management_endpoint_failure_hook( # type: ignore
logging_payload=logging_payload,
parent_otel_span=parent_otel_span,
)

View file

@ -23,13 +23,13 @@ import asyncio
import logging
import pytest
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy.management_endpoints.internal_user_endpoints import (
new_user,
user_info,
user_update,
get_users,
)
from litellm.proxy.management_endpoints.key_management_endpoints import (
delete_key_fn,
@ -322,3 +322,58 @@ async def test_regenerate_key_ui(prisma_client):
),
)
print("response from regenerate_key_fn", new_key)
@pytest.mark.asyncio
async def test_get_users(prisma_client):
"""
Tests /users/list endpoint
Admin UI calls this endpoint to list all Internal Users
"""
litellm.set_verbose = True
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
# Create some test users
test_users = [
NewUserRequest(
user_id=f"test_user_{i}",
user_role=(
LitellmUserRoles.INTERNAL_USER.value
if i % 2 == 0
else LitellmUserRoles.PROXY_ADMIN.value
),
)
for i in range(5)
]
for user in test_users:
await new_user(
user,
UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="admin",
),
)
# Test get_users without filters
result = await get_users(
role=None,
page=1,
page_size=20,
)
print("get users result", result)
assert "users" in result
for user in result["users"]:
user = user.model_dump()
assert "user_id" in user
assert "spend" in user
assert "user_email" in user
assert "user_role" in user
# Clean up test users
for user in test_users:
await prisma_client.db.litellm_usertable.delete(where={"user_id": user.user_id})

View file

@ -0,0 +1,189 @@
import pytest
from fastapi.testclient import TestClient
from fastapi import Request, Header
from unittest.mock import patch, MagicMock, AsyncMock
import sys
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm.proxy.proxy_server import app
from litellm.proxy.utils import PrismaClient, ProxyLogging
from litellm.proxy.management_endpoints.ui_sso import auth_callback
from litellm.proxy._types import LitellmUserRoles
import os
import jwt
import time
from litellm.caching import DualCache
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
@pytest.fixture
def mock_env_vars(monkeypatch):
monkeypatch.setenv("GOOGLE_CLIENT_ID", "mock_google_client_id")
monkeypatch.setenv("GOOGLE_CLIENT_SECRET", "mock_google_client_secret")
monkeypatch.setenv("PROXY_BASE_URL", "http://testserver")
monkeypatch.setenv("LITELLM_MASTER_KEY", "mock_master_key")
@pytest.fixture
def prisma_client():
from litellm.proxy.proxy_cli import append_query_params
### add connection pool + pool timeout args
params = {"connection_limit": 100, "pool_timeout": 60}
database_url = os.getenv("DATABASE_URL")
modified_url = append_query_params(database_url, params)
os.environ["DATABASE_URL"] = modified_url
# Assuming DBClient is a class that needs to be instantiated
prisma_client = PrismaClient(
database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
)
# Reset litellm.proxy.proxy_server.prisma_client to None
litellm.proxy.proxy_server.custom_db_client = None
litellm.proxy.proxy_server.litellm_proxy_budget_name = (
f"litellm-proxy-budget-{time.time()}"
)
litellm.proxy.proxy_server.user_custom_key_generate = None
return prisma_client
@patch("fastapi_sso.sso.google.GoogleSSO")
@pytest.mark.asyncio
async def test_auth_callback_new_user(mock_google_sso, mock_env_vars, prisma_client):
"""
Tests that a new SSO Sign In user is by default given an 'INTERNAL_USER_VIEW_ONLY' role
"""
import uuid
# Generate a unique user ID
unique_user_id = str(uuid.uuid4())
try:
# Set up the prisma client
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
await litellm.proxy.proxy_server.prisma_client.connect()
# Set up the master key
litellm.proxy.proxy_server.master_key = "mock_master_key"
# Mock the GoogleSSO verify_and_process method
mock_sso_result = MagicMock()
mock_sso_result.email = "newuser@example.com"
mock_sso_result.id = unique_user_id
mock_google_sso.return_value.verify_and_process = AsyncMock(
return_value=mock_sso_result
)
# Create a mock Request object
mock_request = Request(
scope={
"type": "http",
"method": "GET",
"scheme": "http",
"server": ("testserver", 80),
"path": "/sso/callback",
"query_string": b"",
"headers": {},
}
)
# Call the auth_callback function directly
response = await auth_callback(request=mock_request)
# Assert the response
assert response.status_code == 303
assert response.headers["location"].startswith(f"/ui/?userID={unique_user_id}")
# Verify that the user was added to the database
user = await prisma_client.db.litellm_usertable.find_first(
where={"user_id": unique_user_id}
)
print("inserted user from SSO", user)
assert user is not None
assert user.user_email == "newuser@example.com"
assert user.user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
finally:
# Clean up: Delete the user from the database
await prisma_client.db.litellm_usertable.delete(
where={"user_id": unique_user_id}
)
@patch("fastapi_sso.sso.google.GoogleSSO")
@pytest.mark.asyncio
async def test_auth_callback_new_user_with_sso_default(
mock_google_sso, mock_env_vars, prisma_client
):
"""
When litellm_settings.default_internal_user_params.user_role = 'INTERNAL_USER'
Tests that a new SSO Sign In user is by default given an 'INTERNAL_USER' role
"""
import uuid
# Generate a unique user ID
unique_user_id = str(uuid.uuid4())
try:
# Set up the prisma client
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
litellm.default_internal_user_params = {
"user_role": LitellmUserRoles.INTERNAL_USER.value
}
await litellm.proxy.proxy_server.prisma_client.connect()
# Set up the master key
litellm.proxy.proxy_server.master_key = "mock_master_key"
# Mock the GoogleSSO verify_and_process method
mock_sso_result = MagicMock()
mock_sso_result.email = "newuser@example.com"
mock_sso_result.id = unique_user_id
mock_google_sso.return_value.verify_and_process = AsyncMock(
return_value=mock_sso_result
)
# Create a mock Request object
mock_request = Request(
scope={
"type": "http",
"method": "GET",
"scheme": "http",
"server": ("testserver", 80),
"path": "/sso/callback",
"query_string": b"",
"headers": {},
}
)
# Call the auth_callback function directly
response = await auth_callback(request=mock_request)
# Assert the response
assert response.status_code == 303
assert response.headers["location"].startswith(f"/ui/?userID={unique_user_id}")
# Verify that the user was added to the database
user = await prisma_client.db.litellm_usertable.find_first(
where={"user_id": unique_user_id}
)
print("inserted user from SSO", user)
assert user is not None
assert user.user_email == "newuser@example.com"
assert user.user_role == LitellmUserRoles.INTERNAL_USER
finally:
# Clean up: Delete the user from the database
await prisma_client.db.litellm_usertable.delete(
where={"user_id": unique_user_id}
)
litellm.default_internal_user_params = None

View file

@ -215,10 +215,13 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
const fetchProxyAdminInfo = async () => {
if (accessToken != null) {
const combinedList: any[] = [];
const proxyViewers = await userGetAllUsersCall(
const response = await userGetAllUsersCall(
accessToken,
"proxy_admin_viewer"
);
console.log("proxy admin viewer response: ", response);
const proxyViewers: User[] = response["users"];
console.log(`proxy viewers response: ${proxyViewers}`);
proxyViewers.forEach((viewer: User) => {
combinedList.push({
user_role: viewer.user_role,
@ -229,11 +232,13 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
console.log(`proxy viewers: ${proxyViewers}`);
const proxyAdmins = await userGetAllUsersCall(
const response2 = await userGetAllUsersCall(
accessToken,
"proxy_admin"
);
const proxyAdmins: User[] = response2["users"];
proxyAdmins.forEach((admins: User) => {
combinedList.push({
user_role: admins.user_role,

View file

@ -560,24 +560,24 @@ export const userInfoCall = async (
page_size: number | null
) => {
try {
let url = proxyBaseUrl ? `${proxyBaseUrl}/user/info` : `/user/info`;
if (userRole == "App Owner" && userID) {
url = `${url}?user_id=${userID}`;
let url: string;
if (viewAll) {
// Use /user/list endpoint when viewAll is true
url = proxyBaseUrl ? `${proxyBaseUrl}/user/list` : `/user/list`;
const queryParams = new URLSearchParams();
if (page != null) queryParams.append('page', page.toString());
if (page_size != null) queryParams.append('page_size', page_size.toString());
url += `?${queryParams.toString()}`;
} else {
// Use /user/info endpoint for individual user info
url = proxyBaseUrl ? `${proxyBaseUrl}/user/info` : `/user/info`;
if (userID) {
url += `?user_id=${userID}`;
}
}
if (userRole == "App User" && userID) {
url = `${url}?user_id=${userID}`;
}
if (
(userRole == "Internal User" || userRole == "Internal Viewer") &&
userID
) {
url = `${url}?user_id=${userID}`;
}
console.log("in userInfoCall viewAll=", viewAll);
if (viewAll && page_size && page != null && page != undefined) {
url = `${url}?view_all=true&page=${page}&page_size=${page_size}`;
}
//message.info("Requesting user data");
console.log("Requesting user data from:", url);
const response = await fetch(url, {
method: "GET",
headers: {
@ -594,11 +594,9 @@ export const userInfoCall = async (
const data = await response.json();
console.log("API Response:", data);
//message.info("Received user data");
return data;
// Handle success - you might want to update some state or UI based on the created key
} catch (error) {
console.error("Failed to create key:", error);
console.error("Failed to fetch user data:", error);
throw error;
}
};

View file

@ -69,7 +69,7 @@ const ViewUserDashboard: React.FC<ViewUserDashboardProps> = ({
}) => {
const [userData, setUserData] = useState<null | any[]>(null);
const [endUsers, setEndUsers] = useState<null | any[]>(null);
const [currentPage, setCurrentPage] = useState(0);
const [currentPage, setCurrentPage] = useState(1);
const [openDialogId, setOpenDialogId] = React.useState<null | number>(null);
const [selectedItem, setSelectedItem] = useState<null | any>(null);
const [editModalVisible, setEditModalVisible] = useState(false);
@ -124,7 +124,7 @@ const ViewUserDashboard: React.FC<ViewUserDashboardProps> = ({
defaultPageSize
);
console.log("user data response:", userDataResponse);
setUserData(userDataResponse);
setUserData(userDataResponse.users);
const availableUserRoles = await getPossibleUserRoles(accessToken);
setPossibleUIRoles(availableUserRoles);