forked from phoenix/litellm-mirror
[SSO-UI] Set new sso users as internal_view role users (#5824)
* use /user/list endpoint on admin ui * sso insert user with role when user does not exist * add sso sign in test * linting fix * rename self serve doc * add doc for self serve flow * test - sso sign in default values * add test for /user/list endpoint
This commit is contained in:
parent
a9caba33ef
commit
d100b32573
10 changed files with 404 additions and 102 deletions
|
@ -2,7 +2,7 @@ import Image from '@theme/IdealImage';
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Self-Serve
|
||||
# Internal User Self-Serve
|
||||
|
||||
## Allow users to create their own keys on [Proxy UI](./ui.md).
|
||||
|
||||
|
@ -191,6 +191,29 @@ This budget only applies to personal keys created by that user - seen under `Def
|
|||
|
||||
This budget does not apply to keys created under non-default teams.
|
||||
|
||||
|
||||
### Set max budget for teams
|
||||
|
||||
[**Go Here**](./team_budgets.md)
|
||||
[**Go Here**](./team_budgets.md)
|
||||
|
||||
## **All Settings for Self Serve / SSO Flow**
|
||||
|
||||
```yaml
|
||||
litellm_settings:
|
||||
max_internal_user_budget: 10 # max budget for internal users
|
||||
internal_user_budget_duration: "1mo" # reset every month
|
||||
|
||||
default_internal_user_params: # Default Params used when a new user signs in Via SSO
|
||||
user_role: "internal_user" # one of "internal_user", "internal_user_viewer", "proxy_admin", "proxy_admin_viewer". New SSO users not in litellm will be created as this user
|
||||
max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user
|
||||
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user
|
||||
|
||||
|
||||
upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on
|
||||
max_budget: 100 # Optional[float], optional): upperbound of $100, for all /key/generate requests
|
||||
budget_duration: "10d" # Optional[str], optional): upperbound of 10 days for budget_duration values
|
||||
duration: "30d" # Optional[str], optional): upperbound of 30 days for all /key/generate requests
|
||||
max_parallel_requests: 1000 # (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None.
|
||||
tpm_limit: 1000 #(Optional[int], optional): Tpm limit. Defaults to None.
|
||||
rpm_limit: 1000 #(Optional[int], optional): Rpm limit. Defaults to None.
|
||||
```
|
|
@ -257,7 +257,7 @@ s3_callback_params: Optional[Dict] = None
|
|||
generic_logger_headers: Optional[Dict] = None
|
||||
default_key_generate_params: Optional[Dict] = None
|
||||
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
|
||||
default_user_params: Optional[Dict] = None
|
||||
default_internal_user_params: Optional[Dict] = None
|
||||
default_team_settings: Optional[List] = None
|
||||
max_user_budget: Optional[float] = None
|
||||
default_max_internal_user_budget: Optional[float] = None
|
||||
|
|
|
@ -8,6 +8,7 @@ These are members of a Team on LiteLLM
|
|||
/user/update
|
||||
/user/delete
|
||||
/user/info
|
||||
/user/list
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
@ -298,10 +299,6 @@ async def user_info(
|
|||
user_id: Optional[str] = fastapi.Query(
|
||||
default=None, description="User ID in the request parameters"
|
||||
),
|
||||
view_all: bool = fastapi.Query(
|
||||
default=False,
|
||||
description="set to true to View all users. When using view_all, don't pass user_id",
|
||||
),
|
||||
page: Optional[int] = fastapi.Query(
|
||||
default=0,
|
||||
description="Page number for pagination. Only use when view_all is true",
|
||||
|
@ -335,17 +332,6 @@ async def user_info(
|
|||
## GET USER ROW ##
|
||||
if user_id is not None:
|
||||
user_info = await prisma_client.get_data(user_id=user_id)
|
||||
elif view_all is True:
|
||||
if page is None:
|
||||
page = 0
|
||||
if page_size is None:
|
||||
page_size = 25
|
||||
offset = (page) * page_size # default is 0
|
||||
limit = page_size # default is 10
|
||||
user_info = await prisma_client.get_data(
|
||||
table_name="user", query_type="find_all", offset=offset, limit=limit
|
||||
)
|
||||
return user_info
|
||||
else:
|
||||
user_info = None
|
||||
## GET ALL TEAMS ##
|
||||
|
@ -732,16 +718,22 @@ async def user_get_requests():
|
|||
tags=["Internal User management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
@router.get(
|
||||
"/user/list",
|
||||
tags=["Internal User management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def get_users(
|
||||
role: str = fastapi.Query(
|
||||
default=None,
|
||||
description="Either 'proxy_admin', 'proxy_viewer', 'app_owner', 'app_user'",
|
||||
)
|
||||
role: Optional[str] = fastapi.Query(
|
||||
default=None, description="Filter users by role"
|
||||
),
|
||||
page: int = fastapi.Query(default=1, ge=1, description="Page number"),
|
||||
page_size: int = fastapi.Query(
|
||||
default=25, ge=1, le=100, description="Number of items per page"
|
||||
),
|
||||
):
|
||||
"""
|
||||
[BETA] This could change without notice. Give feedback - https://github.com/BerriAI/litellm/issues
|
||||
|
||||
Get all users who are a specific `user_role`.
|
||||
Get a paginated list of users, optionally filtered by role.
|
||||
|
||||
Used by the UI to populate the user lists.
|
||||
|
||||
|
@ -754,11 +746,36 @@ async def get_users(
|
|||
status_code=500,
|
||||
detail={"error": f"No db connected. prisma client={prisma_client}"},
|
||||
)
|
||||
all_users = await prisma_client.get_data(
|
||||
table_name="user", query_type="find_all", key_val={"user_role": role}
|
||||
|
||||
# Calculate skip and take for pagination
|
||||
skip = (page - 1) * page_size
|
||||
take = page_size
|
||||
|
||||
# Prepare the query
|
||||
query = {}
|
||||
if role:
|
||||
query["user_role"] = role
|
||||
|
||||
# Get total count
|
||||
total_count = await prisma_client.db.litellm_usertable.count(where=query) # type: ignore
|
||||
|
||||
# Get paginated users
|
||||
users = await prisma_client.db.litellm_usertable.find_many(
|
||||
where=query, # type: ignore
|
||||
skip=skip,
|
||||
take=take,
|
||||
)
|
||||
|
||||
return all_users
|
||||
# Calculate total pages
|
||||
total_pages = -(-total_count // page_size) # Ceiling division
|
||||
|
||||
return {
|
||||
"users": users,
|
||||
"total": total_count,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total_pages": total_pages,
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
|
|
|
@ -17,9 +17,11 @@ import litellm
|
|||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import (
|
||||
LitellmUserRoles,
|
||||
NewUserRequest,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
SSOUserDefinedValues,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.admin_ui_utils import (
|
||||
|
@ -27,6 +29,7 @@ from litellm.proxy.common_utils.admin_ui_utils import (
|
|||
html_form,
|
||||
show_missing_vars_in_env,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
|
||||
router = APIRouter()
|
||||
|
@ -459,7 +462,7 @@ async def auth_callback(request: Request):
|
|||
if prisma_client is not None:
|
||||
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_info: {user_info}; litellm.default_user_params: {litellm.default_user_params}"
|
||||
f"user_info: {user_info}; litellm.default_internal_user_params: {litellm.default_internal_user_params}"
|
||||
)
|
||||
if user_info is None:
|
||||
## check if user-email in db ##
|
||||
|
@ -487,24 +490,11 @@ async def auth_callback(request: Request):
|
|||
await prisma_client.db.litellm_usertable.update_many(
|
||||
where={"user_email": user_email}, data={"user_id": user_id} # type: ignore
|
||||
)
|
||||
elif litellm.default_user_params is not None and isinstance(
|
||||
litellm.default_user_params, dict
|
||||
):
|
||||
user_defined_values = {
|
||||
"models": litellm.default_user_params.get("models", user_id_models),
|
||||
"user_id": litellm.default_user_params.get("user_id", user_id),
|
||||
"user_email": litellm.default_user_params.get(
|
||||
"user_email", user_email
|
||||
),
|
||||
"user_role": litellm.default_user_params.get("user_role", None),
|
||||
"max_budget": litellm.default_user_params.get(
|
||||
"max_budget", max_internal_user_budget
|
||||
),
|
||||
"budget_duration": litellm.default_user_params.get(
|
||||
"budget_duration", internal_user_budget_duration
|
||||
),
|
||||
}
|
||||
|
||||
else:
|
||||
# user not in DB, insert User into LiteLLM DB
|
||||
user_role = await insert_sso_user(
|
||||
user_defined_values=user_defined_values,
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
@ -513,26 +503,6 @@ async def auth_callback(request: Request):
|
|||
"Unable to map user identity to known values. 'user_defined_values' is None. File an issue - https://github.com/BerriAI/litellm/issues"
|
||||
)
|
||||
|
||||
is_internal_user = False
|
||||
if (
|
||||
user_defined_values["user_role"] is not None
|
||||
and user_defined_values["user_role"] == LitellmUserRoles.INTERNAL_USER.value
|
||||
):
|
||||
is_internal_user = True
|
||||
if (
|
||||
is_internal_user is True
|
||||
and user_defined_values["max_budget"] is None
|
||||
and litellm.max_internal_user_budget is not None
|
||||
):
|
||||
user_defined_values["max_budget"] = litellm.max_internal_user_budget
|
||||
|
||||
if (
|
||||
is_internal_user is True
|
||||
and user_defined_values["budget_duration"] is None
|
||||
and litellm.internal_user_budget_duration is not None
|
||||
):
|
||||
user_defined_values["budget_duration"] = litellm.internal_user_budget_duration
|
||||
|
||||
verbose_proxy_logger.info(
|
||||
f"user_defined_values for creating ui key: {user_defined_values}"
|
||||
)
|
||||
|
@ -541,7 +511,9 @@ async def auth_callback(request: Request):
|
|||
default_ui_key_values["request_type"] = "key"
|
||||
response = await generate_key_helper_fn(
|
||||
**default_ui_key_values, # type: ignore
|
||||
table_name="key",
|
||||
)
|
||||
|
||||
key = response["token"] # type: ignore
|
||||
user_id = response["user_id"] # type: ignore
|
||||
|
||||
|
@ -549,19 +521,22 @@ async def auth_callback(request: Request):
|
|||
# User_id on SSO == user_id in the LiteLLM_VerificationToken Table
|
||||
assert user_id == _user_id_from_sso
|
||||
litellm_dashboard_ui = "/ui/"
|
||||
user_role = user_role or "app_owner"
|
||||
user_role = user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
|
||||
if (
|
||||
os.getenv("PROXY_ADMIN_ID", None) is not None
|
||||
and os.environ["PROXY_ADMIN_ID"] == user_id
|
||||
):
|
||||
# checks if user is admin
|
||||
user_role = "app_admin"
|
||||
user_role = LitellmUserRoles.PROXY_ADMIN.value
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_role: {user_role}; ui_access_mode: {ui_access_mode}"
|
||||
)
|
||||
## CHECK IF ROLE ALLOWED TO USE PROXY ##
|
||||
if ui_access_mode == "admin_only" and "admin" not in user_role:
|
||||
if ui_access_mode == "admin_only" and (
|
||||
user_role != LitellmUserRoles.PROXY_ADMIN.value
|
||||
or user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value
|
||||
):
|
||||
verbose_proxy_logger.debug("EXCEPTION RAISED")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
|
@ -594,6 +569,47 @@ async def auth_callback(request: Request):
|
|||
return redirect_response
|
||||
|
||||
|
||||
async def insert_sso_user(
|
||||
user_defined_values: Optional[SSOUserDefinedValues] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Helper function to create a New User in LiteLLM DB after a successful SSO login
|
||||
"""
|
||||
verbose_proxy_logger.debug(
|
||||
f"Inserting SSO user into DB. User values: {user_defined_values}"
|
||||
)
|
||||
|
||||
if user_defined_values is None:
|
||||
raise ValueError("user_defined_values is None")
|
||||
|
||||
if litellm.default_internal_user_params:
|
||||
user_defined_values.update(litellm.default_internal_user_params) # type: ignore
|
||||
|
||||
# Set budget for internal users
|
||||
if user_defined_values.get("user_role") == LitellmUserRoles.INTERNAL_USER.value:
|
||||
if user_defined_values.get("max_budget") is None:
|
||||
user_defined_values["max_budget"] = litellm.max_internal_user_budget
|
||||
if user_defined_values.get("budget_duration") is None:
|
||||
user_defined_values["budget_duration"] = (
|
||||
litellm.internal_user_budget_duration
|
||||
)
|
||||
|
||||
if user_defined_values["user_role"] is None:
|
||||
user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||
|
||||
new_user_request = NewUserRequest(
|
||||
user_id=user_defined_values["user_id"],
|
||||
user_email=user_defined_values["user_email"],
|
||||
user_role=user_defined_values["user_role"], # type: ignore
|
||||
max_budget=user_defined_values["max_budget"],
|
||||
budget_duration=user_defined_values["budget_duration"],
|
||||
)
|
||||
|
||||
await new_user(data=new_user_request, user_api_key_dict=UserAPIKeyAuth())
|
||||
|
||||
return user_defined_values["user_role"] or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sso/get/ui_settings",
|
||||
tags=["experimental"],
|
||||
|
|
|
@ -34,7 +34,7 @@ from litellm.proxy.utils import PrismaClient
|
|||
def get_new_internal_user_defaults(
|
||||
user_id: str, user_email: Optional[str] = None
|
||||
) -> dict:
|
||||
user_info = litellm.default_user_params or {}
|
||||
user_info = litellm.default_internal_user_params or {}
|
||||
|
||||
returned_dict: SSOUserDefinedValues = {
|
||||
"models": user_info.get("models", None),
|
||||
|
@ -277,7 +277,7 @@ def management_endpoint_wrapper(func):
|
|||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
start_time = datetime.now()
|
||||
|
||||
_http_request: Optional[Request] = None
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
end_time = datetime.now()
|
||||
|
@ -293,8 +293,7 @@ def management_endpoint_wrapper(func):
|
|||
user_api_key_dict=user_api_key_dict,
|
||||
function_name=func.__name__,
|
||||
)
|
||||
|
||||
_http_request: Request = kwargs.get("http_request")
|
||||
_http_request = kwargs.get("http_request", None)
|
||||
parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None)
|
||||
if parent_otel_span is not None:
|
||||
from litellm.proxy.proxy_server import open_telemetry_logger
|
||||
|
@ -315,7 +314,7 @@ def management_endpoint_wrapper(func):
|
|||
end_time=end_time,
|
||||
)
|
||||
|
||||
await open_telemetry_logger.async_management_endpoint_success_hook(
|
||||
await open_telemetry_logger.async_management_endpoint_success_hook( # type: ignore
|
||||
logging_payload=logging_payload,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
|
@ -344,7 +343,7 @@ def management_endpoint_wrapper(func):
|
|||
from litellm.proxy.proxy_server import open_telemetry_logger
|
||||
|
||||
if open_telemetry_logger is not None:
|
||||
_http_request: Request = kwargs.get("http_request")
|
||||
_http_request = kwargs.get("http_request")
|
||||
if _http_request:
|
||||
_route = _http_request.url.path
|
||||
_request_body: dict = await _read_request_body(
|
||||
|
@ -359,7 +358,7 @@ def management_endpoint_wrapper(func):
|
|||
exception=e,
|
||||
)
|
||||
|
||||
await open_telemetry_logger.async_management_endpoint_failure_hook(
|
||||
await open_telemetry_logger.async_management_endpoint_failure_hook( # type: ignore
|
||||
logging_payload=logging_payload,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
|
|
|
@ -23,13 +23,13 @@ import asyncio
|
|||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
||||
new_user,
|
||||
user_info,
|
||||
user_update,
|
||||
get_users,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
delete_key_fn,
|
||||
|
@ -322,3 +322,58 @@ async def test_regenerate_key_ui(prisma_client):
|
|||
),
|
||||
)
|
||||
print("response from regenerate_key_fn", new_key)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_users(prisma_client):
|
||||
"""
|
||||
Tests /users/list endpoint
|
||||
|
||||
Admin UI calls this endpoint to list all Internal Users
|
||||
"""
|
||||
litellm.set_verbose = True
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
# Create some test users
|
||||
test_users = [
|
||||
NewUserRequest(
|
||||
user_id=f"test_user_{i}",
|
||||
user_role=(
|
||||
LitellmUserRoles.INTERNAL_USER.value
|
||||
if i % 2 == 0
|
||||
else LitellmUserRoles.PROXY_ADMIN.value
|
||||
),
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
for user in test_users:
|
||||
await new_user(
|
||||
user,
|
||||
UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="admin",
|
||||
),
|
||||
)
|
||||
|
||||
# Test get_users without filters
|
||||
result = await get_users(
|
||||
role=None,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
print("get users result", result)
|
||||
assert "users" in result
|
||||
|
||||
for user in result["users"]:
|
||||
user = user.model_dump()
|
||||
assert "user_id" in user
|
||||
assert "spend" in user
|
||||
assert "user_email" in user
|
||||
assert "user_role" in user
|
||||
|
||||
# Clean up test users
|
||||
for user in test_users:
|
||||
await prisma_client.db.litellm_usertable.delete(where={"user_id": user.user_id})
|
||||
|
|
189
tests/proxy_admin_ui_tests/test_sso_sign_in.py
Normal file
189
tests/proxy_admin_ui_tests/test_sso_sign_in.py
Normal file
|
@ -0,0 +1,189 @@
|
|||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import Request, Header
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
from litellm.proxy.proxy_server import app
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
from litellm.proxy.management_endpoints.ui_sso import auth_callback
|
||||
from litellm.proxy._types import LitellmUserRoles
|
||||
import os
|
||||
import jwt
|
||||
import time
|
||||
from litellm.caching import DualCache
|
||||
|
||||
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env_vars(monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_CLIENT_ID", "mock_google_client_id")
|
||||
monkeypatch.setenv("GOOGLE_CLIENT_SECRET", "mock_google_client_secret")
|
||||
monkeypatch.setenv("PROXY_BASE_URL", "http://testserver")
|
||||
monkeypatch.setenv("LITELLM_MASTER_KEY", "mock_master_key")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prisma_client():
|
||||
from litellm.proxy.proxy_cli import append_query_params
|
||||
|
||||
### add connection pool + pool timeout args
|
||||
params = {"connection_limit": 100, "pool_timeout": 60}
|
||||
database_url = os.getenv("DATABASE_URL")
|
||||
modified_url = append_query_params(database_url, params)
|
||||
os.environ["DATABASE_URL"] = modified_url
|
||||
|
||||
# Assuming DBClient is a class that needs to be instantiated
|
||||
prisma_client = PrismaClient(
|
||||
database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
# Reset litellm.proxy.proxy_server.prisma_client to None
|
||||
litellm.proxy.proxy_server.custom_db_client = None
|
||||
litellm.proxy.proxy_server.litellm_proxy_budget_name = (
|
||||
f"litellm-proxy-budget-{time.time()}"
|
||||
)
|
||||
litellm.proxy.proxy_server.user_custom_key_generate = None
|
||||
|
||||
return prisma_client
|
||||
|
||||
|
||||
@patch("fastapi_sso.sso.google.GoogleSSO")
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_callback_new_user(mock_google_sso, mock_env_vars, prisma_client):
|
||||
"""
|
||||
Tests that a new SSO Sign In user is by default given an 'INTERNAL_USER_VIEW_ONLY' role
|
||||
"""
|
||||
import uuid
|
||||
|
||||
# Generate a unique user ID
|
||||
unique_user_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
# Set up the prisma client
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
# Set up the master key
|
||||
litellm.proxy.proxy_server.master_key = "mock_master_key"
|
||||
|
||||
# Mock the GoogleSSO verify_and_process method
|
||||
mock_sso_result = MagicMock()
|
||||
mock_sso_result.email = "newuser@example.com"
|
||||
mock_sso_result.id = unique_user_id
|
||||
mock_google_sso.return_value.verify_and_process = AsyncMock(
|
||||
return_value=mock_sso_result
|
||||
)
|
||||
|
||||
# Create a mock Request object
|
||||
mock_request = Request(
|
||||
scope={
|
||||
"type": "http",
|
||||
"method": "GET",
|
||||
"scheme": "http",
|
||||
"server": ("testserver", 80),
|
||||
"path": "/sso/callback",
|
||||
"query_string": b"",
|
||||
"headers": {},
|
||||
}
|
||||
)
|
||||
|
||||
# Call the auth_callback function directly
|
||||
response = await auth_callback(request=mock_request)
|
||||
|
||||
# Assert the response
|
||||
assert response.status_code == 303
|
||||
assert response.headers["location"].startswith(f"/ui/?userID={unique_user_id}")
|
||||
|
||||
# Verify that the user was added to the database
|
||||
user = await prisma_client.db.litellm_usertable.find_first(
|
||||
where={"user_id": unique_user_id}
|
||||
)
|
||||
print("inserted user from SSO", user)
|
||||
assert user is not None
|
||||
assert user.user_email == "newuser@example.com"
|
||||
assert user.user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||
|
||||
finally:
|
||||
# Clean up: Delete the user from the database
|
||||
await prisma_client.db.litellm_usertable.delete(
|
||||
where={"user_id": unique_user_id}
|
||||
)
|
||||
|
||||
|
||||
@patch("fastapi_sso.sso.google.GoogleSSO")
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_callback_new_user_with_sso_default(
|
||||
mock_google_sso, mock_env_vars, prisma_client
|
||||
):
|
||||
"""
|
||||
When litellm_settings.default_internal_user_params.user_role = 'INTERNAL_USER'
|
||||
|
||||
Tests that a new SSO Sign In user is by default given an 'INTERNAL_USER' role
|
||||
"""
|
||||
import uuid
|
||||
|
||||
# Generate a unique user ID
|
||||
unique_user_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
# Set up the prisma client
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
litellm.default_internal_user_params = {
|
||||
"user_role": LitellmUserRoles.INTERNAL_USER.value
|
||||
}
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
# Set up the master key
|
||||
litellm.proxy.proxy_server.master_key = "mock_master_key"
|
||||
|
||||
# Mock the GoogleSSO verify_and_process method
|
||||
mock_sso_result = MagicMock()
|
||||
mock_sso_result.email = "newuser@example.com"
|
||||
mock_sso_result.id = unique_user_id
|
||||
mock_google_sso.return_value.verify_and_process = AsyncMock(
|
||||
return_value=mock_sso_result
|
||||
)
|
||||
|
||||
# Create a mock Request object
|
||||
mock_request = Request(
|
||||
scope={
|
||||
"type": "http",
|
||||
"method": "GET",
|
||||
"scheme": "http",
|
||||
"server": ("testserver", 80),
|
||||
"path": "/sso/callback",
|
||||
"query_string": b"",
|
||||
"headers": {},
|
||||
}
|
||||
)
|
||||
|
||||
# Call the auth_callback function directly
|
||||
response = await auth_callback(request=mock_request)
|
||||
|
||||
# Assert the response
|
||||
assert response.status_code == 303
|
||||
assert response.headers["location"].startswith(f"/ui/?userID={unique_user_id}")
|
||||
|
||||
# Verify that the user was added to the database
|
||||
user = await prisma_client.db.litellm_usertable.find_first(
|
||||
where={"user_id": unique_user_id}
|
||||
)
|
||||
print("inserted user from SSO", user)
|
||||
assert user is not None
|
||||
assert user.user_email == "newuser@example.com"
|
||||
assert user.user_role == LitellmUserRoles.INTERNAL_USER
|
||||
|
||||
finally:
|
||||
# Clean up: Delete the user from the database
|
||||
await prisma_client.db.litellm_usertable.delete(
|
||||
where={"user_id": unique_user_id}
|
||||
)
|
||||
litellm.default_internal_user_params = None
|
|
@ -215,10 +215,13 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
|
|||
const fetchProxyAdminInfo = async () => {
|
||||
if (accessToken != null) {
|
||||
const combinedList: any[] = [];
|
||||
const proxyViewers = await userGetAllUsersCall(
|
||||
const response = await userGetAllUsersCall(
|
||||
accessToken,
|
||||
"proxy_admin_viewer"
|
||||
);
|
||||
console.log("proxy admin viewer response: ", response);
|
||||
const proxyViewers: User[] = response["users"];
|
||||
console.log(`proxy viewers response: ${proxyViewers}`);
|
||||
proxyViewers.forEach((viewer: User) => {
|
||||
combinedList.push({
|
||||
user_role: viewer.user_role,
|
||||
|
@ -229,11 +232,13 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
|
|||
|
||||
console.log(`proxy viewers: ${proxyViewers}`);
|
||||
|
||||
const proxyAdmins = await userGetAllUsersCall(
|
||||
const response2 = await userGetAllUsersCall(
|
||||
accessToken,
|
||||
"proxy_admin"
|
||||
);
|
||||
|
||||
const proxyAdmins: User[] = response2["users"];
|
||||
|
||||
proxyAdmins.forEach((admins: User) => {
|
||||
combinedList.push({
|
||||
user_role: admins.user_role,
|
||||
|
|
|
@ -560,24 +560,24 @@ export const userInfoCall = async (
|
|||
page_size: number | null
|
||||
) => {
|
||||
try {
|
||||
let url = proxyBaseUrl ? `${proxyBaseUrl}/user/info` : `/user/info`;
|
||||
if (userRole == "App Owner" && userID) {
|
||||
url = `${url}?user_id=${userID}`;
|
||||
let url: string;
|
||||
|
||||
if (viewAll) {
|
||||
// Use /user/list endpoint when viewAll is true
|
||||
url = proxyBaseUrl ? `${proxyBaseUrl}/user/list` : `/user/list`;
|
||||
const queryParams = new URLSearchParams();
|
||||
if (page != null) queryParams.append('page', page.toString());
|
||||
if (page_size != null) queryParams.append('page_size', page_size.toString());
|
||||
url += `?${queryParams.toString()}`;
|
||||
} else {
|
||||
// Use /user/info endpoint for individual user info
|
||||
url = proxyBaseUrl ? `${proxyBaseUrl}/user/info` : `/user/info`;
|
||||
if (userID) {
|
||||
url += `?user_id=${userID}`;
|
||||
}
|
||||
}
|
||||
if (userRole == "App User" && userID) {
|
||||
url = `${url}?user_id=${userID}`;
|
||||
}
|
||||
if (
|
||||
(userRole == "Internal User" || userRole == "Internal Viewer") &&
|
||||
userID
|
||||
) {
|
||||
url = `${url}?user_id=${userID}`;
|
||||
}
|
||||
console.log("in userInfoCall viewAll=", viewAll);
|
||||
if (viewAll && page_size && page != null && page != undefined) {
|
||||
url = `${url}?view_all=true&page=${page}&page_size=${page_size}`;
|
||||
}
|
||||
//message.info("Requesting user data");
|
||||
|
||||
console.log("Requesting user data from:", url);
|
||||
const response = await fetch(url, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
|
@ -594,11 +594,9 @@ export const userInfoCall = async (
|
|||
|
||||
const data = await response.json();
|
||||
console.log("API Response:", data);
|
||||
//message.info("Received user data");
|
||||
return data;
|
||||
// Handle success - you might want to update some state or UI based on the created key
|
||||
} catch (error) {
|
||||
console.error("Failed to create key:", error);
|
||||
console.error("Failed to fetch user data:", error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -69,7 +69,7 @@ const ViewUserDashboard: React.FC<ViewUserDashboardProps> = ({
|
|||
}) => {
|
||||
const [userData, setUserData] = useState<null | any[]>(null);
|
||||
const [endUsers, setEndUsers] = useState<null | any[]>(null);
|
||||
const [currentPage, setCurrentPage] = useState(0);
|
||||
const [currentPage, setCurrentPage] = useState(1);
|
||||
const [openDialogId, setOpenDialogId] = React.useState<null | number>(null);
|
||||
const [selectedItem, setSelectedItem] = useState<null | any>(null);
|
||||
const [editModalVisible, setEditModalVisible] = useState(false);
|
||||
|
@ -124,7 +124,7 @@ const ViewUserDashboard: React.FC<ViewUserDashboardProps> = ({
|
|||
defaultPageSize
|
||||
);
|
||||
console.log("user data response:", userDataResponse);
|
||||
setUserData(userDataResponse);
|
||||
setUserData(userDataResponse.users);
|
||||
|
||||
const availableUserRoles = await getPossibleUserRoles(accessToken);
|
||||
setPossibleUIRoles(availableUserRoles);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue