Litellm dev 01 08 2025 p1 (#7640)

* feat(ui_sso.py): support reading team ids from sso token

* feat(ui_sso.py): working upsert sso user teams membership in litellm - if team exists

Adds user to relevant teams, if user is part of teams and team exists on litellm

* fix(ui_sso.py): safely handle add team member task

* build(ui/): support setting team id when creating team on UI

* build(ui/): teams.tsx

allow setting team id on ui

* build(circle_ci/requirements.txt): add fastapi-sso to ci/cd testing

* fix: fix linting errors
This commit is contained in:
Krish Dholakia 2025-01-08 22:08:20 -08:00 committed by GitHub
parent 6d8cfeaf14
commit b77832a793
10 changed files with 269 additions and 120 deletions

View file

@ -9,3 +9,4 @@ anthropic
orjson==3.9.15
pydantic==2.7.1
google-cloud-aiplatform==1.43.0
fastapi-sso==0.10.0

View file

@ -24,4 +24,4 @@ model_list:
custom_tokenizer:
identifier: deepseek-ai/DeepSeek-V3-Base
revision: main
auth_token: os.environ/HUGGINGFACE_API_KEY
auth_token: os.environ/HUGGINGFACE_API_KEY

View file

@ -420,6 +420,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
"info_routes",
]
team_id_jwt_field: Optional[str] = None
team_ids_jwt_field: Optional[str] = None
upsert_sso_user_to_team: bool = False
team_allowed_routes: List[
Literal["openai_routes", "info_routes", "management_routes"]
] = ["openai_routes", "info_routes"]

View file

@ -8,7 +8,7 @@ JWT token must have 'litellm_proxy_admin' in scope.
import json
import os
from typing import Optional, cast
from typing import List, Optional, cast
from cryptography import x509
from cryptography.hazmat.backends import default_backend
@ -59,6 +59,11 @@ class JWTHandler:
return True
return False
def get_team_ids_from_jwt(self, token: dict) -> List[str]:
if self.litellm_jwtauth.team_ids_jwt_field is not None:
return token[self.litellm_jwtauth.team_ids_jwt_field]
return []
def get_end_user_id(
self, token: dict, default_value: Optional[str]
) -> Optional[str]:

View file

@ -457,7 +457,7 @@ async def update_team(
if existing_team_row is None:
raise HTTPException(
status_code=404,
status_code=400,
detail={"error": f"Team not found, passed team_id={data.team_id}"},
)

View file

@ -0,0 +1,13 @@
"""
Types for the management endpoints
Might include fastapi/proxy requirements.txt related imports
"""
from typing import List
from fastapi_sso.sso.base import OpenID
class CustomOpenID(OpenID):
team_ids: List[str]

View file

@ -8,7 +8,7 @@ Has all /sso/* routes
import asyncio
import os
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
@ -18,13 +18,17 @@ from litellm._logging import verbose_proxy_logger
from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY
from litellm.proxy._types import (
LitellmUserRoles,
Member,
NewUserRequest,
NewUserResponse,
ProxyErrorTypes,
ProxyException,
SSOUserDefinedValues,
TeamMemberAddRequest,
UserAPIKeyAuth,
)
from litellm.proxy.auth.auth_utils import _has_user_setup_sso
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.admin_ui_utils import (
admin_ui_disabled,
@ -36,6 +40,8 @@ from litellm.proxy.management_endpoints.sso_helper_utils import (
check_is_admin_only_access,
has_admin_ui_access,
)
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
from litellm.proxy.management_endpoints.types import CustomOpenID
from litellm.secret_managers.main import str_to_bool
if TYPE_CHECKING:
@ -221,6 +227,170 @@ async def google_login(request: Request): # noqa: PLR0915
return HTMLResponse(content=html_form, status_code=200)
def generic_response_convertor(response, jwt_handler: JWTHandler):
generic_user_id_attribute_name = os.getenv(
"GENERIC_USER_ID_ATTRIBUTE", "preferred_username"
)
generic_user_display_name_attribute_name = os.getenv(
"GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub"
)
generic_user_email_attribute_name = os.getenv(
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
)
generic_user_first_name_attribute_name = os.getenv(
"GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name"
)
generic_user_last_name_attribute_name = os.getenv(
"GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name"
)
generic_provider_attribute_name = os.getenv(
"GENERIC_USER_PROVIDER_ATTRIBUTE", "provider"
)
verbose_proxy_logger.debug(
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}"
)
return CustomOpenID(
id=response.get(generic_user_id_attribute_name),
display_name=response.get(generic_user_display_name_attribute_name),
email=response.get(generic_user_email_attribute_name),
first_name=response.get(generic_user_first_name_attribute_name),
last_name=response.get(generic_user_last_name_attribute_name),
provider=response.get(generic_provider_attribute_name),
team_ids=jwt_handler.get_team_ids_from_jwt(cast(dict, response)),
)
async def get_generic_sso_response(
request: Request,
jwt_handler: JWTHandler,
generic_client_id: str,
redirect_url: str,
) -> Optional[OpenID]:
# make generic sso provider
from fastapi_sso.sso.base import DiscoveryDocument
from fastapi_sso.sso.generic import create_provider
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
generic_authorization_endpoint = os.getenv("GENERIC_AUTHORIZATION_ENDPOINT", None)
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
generic_include_client_id = (
os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
)
if generic_client_secret is None:
raise ProxyException(
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_authorization_endpoint is None:
raise ProxyException(
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_AUTHORIZATION_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_token_endpoint is None:
raise ProxyException(
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_TOKEN_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_userinfo_endpoint is None:
raise ProxyException(
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_USERINFO_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
verbose_proxy_logger.debug(
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
)
verbose_proxy_logger.debug(
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
)
discovery = DiscoveryDocument(
authorization_endpoint=generic_authorization_endpoint,
token_endpoint=generic_token_endpoint,
userinfo_endpoint=generic_userinfo_endpoint,
)
def response_convertor(response, client):
return generic_response_convertor(
response=response,
jwt_handler=jwt_handler,
)
SSOProvider = create_provider(
name="oidc",
discovery_document=discovery,
response_convertor=response_convertor,
)
generic_sso = SSOProvider(
client_id=generic_client_id,
client_secret=generic_client_secret,
redirect_uri=redirect_url,
allow_insecure_http=True,
scope=generic_scope,
)
verbose_proxy_logger.debug("calling generic_sso.verify_and_process")
result = await generic_sso.verify_and_process(
request, params={"include_client_id": generic_include_client_id}
)
verbose_proxy_logger.debug("generic result: %s", result)
return result
async def create_team_member_add_task(team_id, user_info):
"""Create a task for adding a member to a team."""
try:
member = Member(user_id=user_info.user_id, role="user")
team_member_add_request = TeamMemberAddRequest(
member=member,
team_id=team_id,
)
return await team_member_add(
data=team_member_add_request,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
http_request=Request(scope={"type": "http", "path": "/sso/callback"}),
)
except Exception as e:
verbose_proxy_logger.debug(
f"[Non-Blocking] Error trying to add sso user to db: {e}"
)
async def add_missing_team_member(user_info: NewUserResponse, sso_teams: List[str]):
"""
- Get missing teams (diff b/w user_info.team_ids and sso_teams)
- Add missing user to missing teams
"""
if user_info.teams is None:
return
missing_teams = set(sso_teams) - set(user_info.teams)
missing_teams_list = list(missing_teams)
tasks = []
tasks = [
create_team_member_add_task(team_id, user_info)
for team_id in missing_teams_list
]
try:
await asyncio.gather(*tasks)
except Exception as e:
verbose_proxy_logger.debug(
f"[Non-Blocking] Error trying to add sso user to db: {e}"
)
@router.get("/sso/callback", tags=["experimental"], include_in_schema=False)
async def auth_callback(request: Request): # noqa: PLR0915
"""Verify login"""
@ -229,6 +399,7 @@ async def auth_callback(request: Request): # noqa: PLR0915
)
from litellm.proxy.proxy_server import (
general_settings,
jwt_handler,
master_key,
premium_user,
prisma_client,
@ -299,116 +470,12 @@ async def auth_callback(request: Request): # noqa: PLR0915
)
result = await microsoft_sso.verify_and_process(request)
elif generic_client_id is not None:
# make generic sso provider
from fastapi_sso.sso.base import DiscoveryDocument, OpenID
from fastapi_sso.sso.generic import create_provider
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
generic_authorization_endpoint = os.getenv(
"GENERIC_AUTHORIZATION_ENDPOINT", None
result = await get_generic_sso_response(
request=request,
jwt_handler=jwt_handler,
generic_client_id=generic_client_id,
redirect_url=redirect_url,
)
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
generic_include_client_id = (
os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
)
if generic_client_secret is None:
raise ProxyException(
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_authorization_endpoint is None:
raise ProxyException(
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_AUTHORIZATION_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_token_endpoint is None:
raise ProxyException(
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_TOKEN_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_userinfo_endpoint is None:
raise ProxyException(
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
type=ProxyErrorTypes.auth_error,
param="GENERIC_USERINFO_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
verbose_proxy_logger.debug(
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
)
verbose_proxy_logger.debug(
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
)
generic_user_id_attribute_name = os.getenv(
"GENERIC_USER_ID_ATTRIBUTE", "preferred_username"
)
generic_user_display_name_attribute_name = os.getenv(
"GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub"
)
generic_user_email_attribute_name = os.getenv(
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
)
generic_user_role_attribute_name = os.getenv(
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
)
generic_user_first_name_attribute_name = os.getenv(
"GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name"
)
generic_user_last_name_attribute_name = os.getenv(
"GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name"
)
generic_provider_attribute_name = os.getenv(
"GENERIC_USER_PROVIDER_ATTRIBUTE", "provider"
)
verbose_proxy_logger.debug(
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
)
discovery = DiscoveryDocument(
authorization_endpoint=generic_authorization_endpoint,
token_endpoint=generic_token_endpoint,
userinfo_endpoint=generic_userinfo_endpoint,
)
def response_convertor(response, client):
return OpenID(
id=response.get(generic_user_id_attribute_name),
display_name=response.get(generic_user_display_name_attribute_name),
email=response.get(generic_user_email_attribute_name),
first_name=response.get(generic_user_first_name_attribute_name),
last_name=response.get(generic_user_last_name_attribute_name),
provider=response.get(generic_provider_attribute_name),
)
SSOProvider = create_provider(
name="oidc",
discovery_document=discovery,
response_convertor=response_convertor,
)
generic_sso = SSOProvider(
client_id=generic_client_id,
client_secret=generic_client_secret,
redirect_uri=redirect_url,
allow_insecure_http=True,
scope=generic_scope,
)
verbose_proxy_logger.debug("calling generic_sso.verify_and_process")
result = await generic_sso.verify_and_process(
request, params={"include_client_id": generic_include_client_id}
)
verbose_proxy_logger.debug("generic result: %s", result)
# User is Authe'd in - generate key for the UI to access Proxy
user_email: Optional[str] = getattr(result, "email", None)
user_id: Optional[str] = getattr(result, "id", None) if result is not None else None
@ -428,6 +495,9 @@ async def auth_callback(request: Request): # noqa: PLR0915
# generic client id
if generic_client_id is not None and result is not None:
generic_user_role_attribute_name = os.getenv(
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
)
user_id = getattr(result, "id", None)
user_email = getattr(result, "email", None)
user_role = getattr(result, generic_user_role_attribute_name, None) # type: ignore
@ -508,12 +578,21 @@ async def auth_callback(request: Request): # noqa: PLR0915
)
else:
# user not in DB, insert User into LiteLLM DB
user_role = await insert_sso_user(
user_info = await insert_sso_user(
result_openid=result,
user_defined_values=user_defined_values,
)
except Exception:
pass
user_role = (
user_info.user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
)
sso_teams = getattr(result, "team_ids", [])
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
except Exception as e:
verbose_proxy_logger.debug(
f"[Non-Blocking] Error trying to add sso user to db: {e}"
)
if user_defined_values is None:
raise Exception(
@ -588,13 +667,16 @@ async def auth_callback(request: Request): # noqa: PLR0915
async def insert_sso_user(
result_openid: Optional[OpenID],
user_defined_values: Optional[SSOUserDefinedValues] = None,
) -> str:
) -> NewUserResponse:
"""
Helper function to create a New User in LiteLLM DB after a successful SSO login
Args:
result_openid (OpenID): User information in OpenID format if the login was successful.
user_defined_values (Optional[SSOUserDefinedValues], optional): LiteLLM SSOValues / fields that were read
Returns:
Tuple[str, str]: User ID and User Role
"""
verbose_proxy_logger.debug(
f"Inserting SSO user into DB. User values: {user_defined_values}"
@ -629,9 +711,9 @@ async def insert_sso_user(
if result_openid:
new_user_request.metadata = {"auth_provider": result_openid.provider}
await new_user(data=new_user_request, user_api_key_dict=UserAPIKeyAuth())
response = await new_user(data=new_user_request, user_api_key_dict=UserAPIKeyAuth())
return user_defined_values["user_role"] or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
return response
@router.get(

View file

@ -1327,3 +1327,35 @@ async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data):
# Verify the result structure
assert isinstance(result, UserInfoResponse)
assert len(result.keys) == 2
def test_custom_openid_response():
from litellm.proxy.management_endpoints.ui_sso import generic_response_convertor
from litellm.proxy.management_endpoints.ui_sso import JWTHandler
from litellm.proxy._types import LiteLLM_JWTAuth
from litellm.caching import DualCache
jwt_handler = JWTHandler()
jwt_handler.update_environment(
prisma_client={},
user_api_key_cache=DualCache(),
litellm_jwtauth=LiteLLM_JWTAuth(
team_ids_jwt_field="department",
),
)
response = {
"sub": "3f196e06-7484-451e-be5a-ea6c6bb86c5b",
"email_verified": True,
"name": "Krish Dholakia",
"preferred_username": "krrishd",
"given_name": "Krish",
"department": ["/test-group"],
"family_name": "Dholakia",
"email": "krrishdholakia@gmail.com",
}
resp = generic_response_convertor(
response=response,
jwt_handler=jwt_handler,
)
assert resp.team_ids == ["/test-group"]

View file

@ -33,7 +33,7 @@ const menuItems: MenuItem[] = [
{ key: "2", page: "models", label: "Models", roles: all_admin_roles },
{ key: "4", page: "usage", label: "Usage"}, // all roles
{ key: "6", page: "teams", label: "Teams" },
{ key: "17", page: "organizations", label: "Organizations" },
{ key: "17", page: "organizations", label: "Organizations", roles: all_admin_roles },
{ key: "5", page: "users", label: "Internal Users", roles: all_admin_roles },
{ key: "8", page: "settings", label: "Logging & Alerts", roles: all_admin_roles },
{ key: "9", page: "caching", label: "Caching", roles: all_admin_roles },

View file

@ -20,6 +20,7 @@ import {
Tooltip
} from "antd";
import { Select, SelectItem } from "@tremor/react";
import {
Table,
TableBody,
@ -69,6 +70,7 @@ import {
teamListCall
} from "./networking";
const Team: React.FC<TeamProps> = ({
teams,
searchParams,
@ -365,6 +367,7 @@ const Team: React.FC<TeamProps> = ({
const handleCreate = async (formValues: Record<string, any>) => {
try {
console.log(`formValues: ${JSON.stringify(formValues)}`);
if (accessToken != null) {
const newTeamAlias = formValues?.team_alias;
const existingTeamAliases = teams?.map((t) => t.team_alias) ?? [];
@ -746,6 +749,17 @@ const Team: React.FC<TeamProps> = ({
<b>Additional Settings</b>
</AccordionHeader>
<AccordionBody>
<Form.Item
label="Team ID"
name="team_id"
help="ID of the team you want to create. If not provided, it will be generated automatically."
>
<TextInput
onChange={(e) => {
e.target.value = e.target.value.trim();
}}
/>
</Form.Item>
<Form.Item
label="Organization ID"
name="organization_id"