mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
Litellm dev 01 08 2025 p1 (#7640)
* feat(ui_sso.py): support reading team ids from sso token * feat(ui_sso.py): working upsert sso user teams membership in litellm - if team exists Adds user to relevant teams, if user is part of teams and team exists on litellm * fix(ui_sso.py): safely handle add team member task * build(ui/): support setting team id when creating team on UI * build(ui/): teams.tsx allow setting team id on ui * build(circle_ci/requirements.txt): add fastapi-sso to ci/cd testing * fix: fix linting errors
This commit is contained in:
parent
6d8cfeaf14
commit
b77832a793
10 changed files with 269 additions and 120 deletions
|
@ -9,3 +9,4 @@ anthropic
|
||||||
orjson==3.9.15
|
orjson==3.9.15
|
||||||
pydantic==2.7.1
|
pydantic==2.7.1
|
||||||
google-cloud-aiplatform==1.43.0
|
google-cloud-aiplatform==1.43.0
|
||||||
|
fastapi-sso==0.10.0
|
||||||
|
|
|
@ -24,4 +24,4 @@ model_list:
|
||||||
custom_tokenizer:
|
custom_tokenizer:
|
||||||
identifier: deepseek-ai/DeepSeek-V3-Base
|
identifier: deepseek-ai/DeepSeek-V3-Base
|
||||||
revision: main
|
revision: main
|
||||||
auth_token: os.environ/HUGGINGFACE_API_KEY
|
auth_token: os.environ/HUGGINGFACE_API_KEY
|
|
@ -420,6 +420,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
||||||
"info_routes",
|
"info_routes",
|
||||||
]
|
]
|
||||||
team_id_jwt_field: Optional[str] = None
|
team_id_jwt_field: Optional[str] = None
|
||||||
|
team_ids_jwt_field: Optional[str] = None
|
||||||
|
upsert_sso_user_to_team: bool = False
|
||||||
team_allowed_routes: List[
|
team_allowed_routes: List[
|
||||||
Literal["openai_routes", "info_routes", "management_routes"]
|
Literal["openai_routes", "info_routes", "management_routes"]
|
||||||
] = ["openai_routes", "info_routes"]
|
] = ["openai_routes", "info_routes"]
|
||||||
|
|
|
@ -8,7 +8,7 @@ JWT token must have 'litellm_proxy_admin' in scope.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Optional, cast
|
from typing import List, Optional, cast
|
||||||
|
|
||||||
from cryptography import x509
|
from cryptography import x509
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
@ -59,6 +59,11 @@ class JWTHandler:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def get_team_ids_from_jwt(self, token: dict) -> List[str]:
|
||||||
|
if self.litellm_jwtauth.team_ids_jwt_field is not None:
|
||||||
|
return token[self.litellm_jwtauth.team_ids_jwt_field]
|
||||||
|
return []
|
||||||
|
|
||||||
def get_end_user_id(
|
def get_end_user_id(
|
||||||
self, token: dict, default_value: Optional[str]
|
self, token: dict, default_value: Optional[str]
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
|
|
|
@ -457,7 +457,7 @@ async def update_team(
|
||||||
|
|
||||||
if existing_team_row is None:
|
if existing_team_row is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=400,
|
||||||
detail={"error": f"Team not found, passed team_id={data.team_id}"},
|
detail={"error": f"Team not found, passed team_id={data.team_id}"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
13
litellm/proxy/management_endpoints/types.py
Normal file
13
litellm/proxy/management_endpoints/types.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
"""
|
||||||
|
Types for the management endpoints
|
||||||
|
|
||||||
|
Might include fastapi/proxy requirements.txt related imports
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from fastapi_sso.sso.base import OpenID
|
||||||
|
|
||||||
|
|
||||||
|
class CustomOpenID(OpenID):
|
||||||
|
team_ids: List[str]
|
|
@ -8,7 +8,7 @@ Has all /sso/* routes
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
|
@ -18,13 +18,17 @@ from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY
|
from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY
|
||||||
from litellm.proxy._types import (
|
from litellm.proxy._types import (
|
||||||
LitellmUserRoles,
|
LitellmUserRoles,
|
||||||
|
Member,
|
||||||
NewUserRequest,
|
NewUserRequest,
|
||||||
|
NewUserResponse,
|
||||||
ProxyErrorTypes,
|
ProxyErrorTypes,
|
||||||
ProxyException,
|
ProxyException,
|
||||||
SSOUserDefinedValues,
|
SSOUserDefinedValues,
|
||||||
|
TeamMemberAddRequest,
|
||||||
UserAPIKeyAuth,
|
UserAPIKeyAuth,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.auth_utils import _has_user_setup_sso
|
from litellm.proxy.auth.auth_utils import _has_user_setup_sso
|
||||||
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
from litellm.proxy.common_utils.admin_ui_utils import (
|
from litellm.proxy.common_utils.admin_ui_utils import (
|
||||||
admin_ui_disabled,
|
admin_ui_disabled,
|
||||||
|
@ -36,6 +40,8 @@ from litellm.proxy.management_endpoints.sso_helper_utils import (
|
||||||
check_is_admin_only_access,
|
check_is_admin_only_access,
|
||||||
has_admin_ui_access,
|
has_admin_ui_access,
|
||||||
)
|
)
|
||||||
|
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
|
||||||
|
from litellm.proxy.management_endpoints.types import CustomOpenID
|
||||||
from litellm.secret_managers.main import str_to_bool
|
from litellm.secret_managers.main import str_to_bool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -221,6 +227,170 @@ async def google_login(request: Request): # noqa: PLR0915
|
||||||
return HTMLResponse(content=html_form, status_code=200)
|
return HTMLResponse(content=html_form, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
def generic_response_convertor(response, jwt_handler: JWTHandler):
|
||||||
|
generic_user_id_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_ID_ATTRIBUTE", "preferred_username"
|
||||||
|
)
|
||||||
|
generic_user_display_name_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub"
|
||||||
|
)
|
||||||
|
generic_user_email_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
|
||||||
|
)
|
||||||
|
|
||||||
|
generic_user_first_name_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name"
|
||||||
|
)
|
||||||
|
generic_user_last_name_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name"
|
||||||
|
)
|
||||||
|
|
||||||
|
generic_provider_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_PROVIDER_ATTRIBUTE", "provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return CustomOpenID(
|
||||||
|
id=response.get(generic_user_id_attribute_name),
|
||||||
|
display_name=response.get(generic_user_display_name_attribute_name),
|
||||||
|
email=response.get(generic_user_email_attribute_name),
|
||||||
|
first_name=response.get(generic_user_first_name_attribute_name),
|
||||||
|
last_name=response.get(generic_user_last_name_attribute_name),
|
||||||
|
provider=response.get(generic_provider_attribute_name),
|
||||||
|
team_ids=jwt_handler.get_team_ids_from_jwt(cast(dict, response)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_generic_sso_response(
|
||||||
|
request: Request,
|
||||||
|
jwt_handler: JWTHandler,
|
||||||
|
generic_client_id: str,
|
||||||
|
redirect_url: str,
|
||||||
|
) -> Optional[OpenID]:
|
||||||
|
# make generic sso provider
|
||||||
|
from fastapi_sso.sso.base import DiscoveryDocument
|
||||||
|
from fastapi_sso.sso.generic import create_provider
|
||||||
|
|
||||||
|
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||||
|
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||||
|
generic_authorization_endpoint = os.getenv("GENERIC_AUTHORIZATION_ENDPOINT", None)
|
||||||
|
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||||
|
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||||
|
generic_include_client_id = (
|
||||||
|
os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
|
||||||
|
)
|
||||||
|
if generic_client_secret is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GENERIC_CLIENT_SECRET",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_authorization_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_token_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GENERIC_TOKEN_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
if generic_userinfo_endpoint is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="GENERIC_USERINFO_ENDPOINT",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
discovery = DiscoveryDocument(
|
||||||
|
authorization_endpoint=generic_authorization_endpoint,
|
||||||
|
token_endpoint=generic_token_endpoint,
|
||||||
|
userinfo_endpoint=generic_userinfo_endpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
def response_convertor(response, client):
|
||||||
|
return generic_response_convertor(
|
||||||
|
response=response,
|
||||||
|
jwt_handler=jwt_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
SSOProvider = create_provider(
|
||||||
|
name="oidc",
|
||||||
|
discovery_document=discovery,
|
||||||
|
response_convertor=response_convertor,
|
||||||
|
)
|
||||||
|
generic_sso = SSOProvider(
|
||||||
|
client_id=generic_client_id,
|
||||||
|
client_secret=generic_client_secret,
|
||||||
|
redirect_uri=redirect_url,
|
||||||
|
allow_insecure_http=True,
|
||||||
|
scope=generic_scope,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug("calling generic_sso.verify_and_process")
|
||||||
|
result = await generic_sso.verify_and_process(
|
||||||
|
request, params={"include_client_id": generic_include_client_id}
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug("generic result: %s", result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def create_team_member_add_task(team_id, user_info):
|
||||||
|
"""Create a task for adding a member to a team."""
|
||||||
|
try:
|
||||||
|
member = Member(user_id=user_info.user_id, role="user")
|
||||||
|
team_member_add_request = TeamMemberAddRequest(
|
||||||
|
member=member,
|
||||||
|
team_id=team_id,
|
||||||
|
)
|
||||||
|
return await team_member_add(
|
||||||
|
data=team_member_add_request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||||
|
http_request=Request(scope={"type": "http", "path": "/sso/callback"}),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"[Non-Blocking] Error trying to add sso user to db: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def add_missing_team_member(user_info: NewUserResponse, sso_teams: List[str]):
|
||||||
|
"""
|
||||||
|
- Get missing teams (diff b/w user_info.team_ids and sso_teams)
|
||||||
|
- Add missing user to missing teams
|
||||||
|
"""
|
||||||
|
if user_info.teams is None:
|
||||||
|
return
|
||||||
|
missing_teams = set(sso_teams) - set(user_info.teams)
|
||||||
|
missing_teams_list = list(missing_teams)
|
||||||
|
tasks = []
|
||||||
|
tasks = [
|
||||||
|
create_team_member_add_task(team_id, user_info)
|
||||||
|
for team_id in missing_teams_list
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"[Non-Blocking] Error trying to add sso user to db: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/sso/callback", tags=["experimental"], include_in_schema=False)
|
@router.get("/sso/callback", tags=["experimental"], include_in_schema=False)
|
||||||
async def auth_callback(request: Request): # noqa: PLR0915
|
async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
"""Verify login"""
|
"""Verify login"""
|
||||||
|
@ -229,6 +399,7 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
)
|
)
|
||||||
from litellm.proxy.proxy_server import (
|
from litellm.proxy.proxy_server import (
|
||||||
general_settings,
|
general_settings,
|
||||||
|
jwt_handler,
|
||||||
master_key,
|
master_key,
|
||||||
premium_user,
|
premium_user,
|
||||||
prisma_client,
|
prisma_client,
|
||||||
|
@ -299,116 +470,12 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
)
|
)
|
||||||
result = await microsoft_sso.verify_and_process(request)
|
result = await microsoft_sso.verify_and_process(request)
|
||||||
elif generic_client_id is not None:
|
elif generic_client_id is not None:
|
||||||
# make generic sso provider
|
result = await get_generic_sso_response(
|
||||||
from fastapi_sso.sso.base import DiscoveryDocument, OpenID
|
request=request,
|
||||||
from fastapi_sso.sso.generic import create_provider
|
jwt_handler=jwt_handler,
|
||||||
|
generic_client_id=generic_client_id,
|
||||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
redirect_url=redirect_url,
|
||||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
|
||||||
generic_authorization_endpoint = os.getenv(
|
|
||||||
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
|
||||||
)
|
)
|
||||||
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
|
||||||
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
|
||||||
generic_include_client_id = (
|
|
||||||
os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
|
|
||||||
)
|
|
||||||
if generic_client_secret is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GENERIC_CLIENT_SECRET",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_authorization_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_token_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GENERIC_TOKEN_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
if generic_userinfo_endpoint is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="GENERIC_USERINFO_ENDPOINT",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
generic_user_id_attribute_name = os.getenv(
|
|
||||||
"GENERIC_USER_ID_ATTRIBUTE", "preferred_username"
|
|
||||||
)
|
|
||||||
generic_user_display_name_attribute_name = os.getenv(
|
|
||||||
"GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub"
|
|
||||||
)
|
|
||||||
generic_user_email_attribute_name = os.getenv(
|
|
||||||
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
|
|
||||||
)
|
|
||||||
generic_user_role_attribute_name = os.getenv(
|
|
||||||
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
|
|
||||||
)
|
|
||||||
generic_user_first_name_attribute_name = os.getenv(
|
|
||||||
"GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name"
|
|
||||||
)
|
|
||||||
generic_user_last_name_attribute_name = os.getenv(
|
|
||||||
"GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name"
|
|
||||||
)
|
|
||||||
|
|
||||||
generic_provider_attribute_name = os.getenv(
|
|
||||||
"GENERIC_USER_PROVIDER_ATTRIBUTE", "provider"
|
|
||||||
)
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
discovery = DiscoveryDocument(
|
|
||||||
authorization_endpoint=generic_authorization_endpoint,
|
|
||||||
token_endpoint=generic_token_endpoint,
|
|
||||||
userinfo_endpoint=generic_userinfo_endpoint,
|
|
||||||
)
|
|
||||||
|
|
||||||
def response_convertor(response, client):
|
|
||||||
return OpenID(
|
|
||||||
id=response.get(generic_user_id_attribute_name),
|
|
||||||
display_name=response.get(generic_user_display_name_attribute_name),
|
|
||||||
email=response.get(generic_user_email_attribute_name),
|
|
||||||
first_name=response.get(generic_user_first_name_attribute_name),
|
|
||||||
last_name=response.get(generic_user_last_name_attribute_name),
|
|
||||||
provider=response.get(generic_provider_attribute_name),
|
|
||||||
)
|
|
||||||
|
|
||||||
SSOProvider = create_provider(
|
|
||||||
name="oidc",
|
|
||||||
discovery_document=discovery,
|
|
||||||
response_convertor=response_convertor,
|
|
||||||
)
|
|
||||||
generic_sso = SSOProvider(
|
|
||||||
client_id=generic_client_id,
|
|
||||||
client_secret=generic_client_secret,
|
|
||||||
redirect_uri=redirect_url,
|
|
||||||
allow_insecure_http=True,
|
|
||||||
scope=generic_scope,
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.debug("calling generic_sso.verify_and_process")
|
|
||||||
result = await generic_sso.verify_and_process(
|
|
||||||
request, params={"include_client_id": generic_include_client_id}
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.debug("generic result: %s", result)
|
|
||||||
|
|
||||||
# User is Authe'd in - generate key for the UI to access Proxy
|
# User is Authe'd in - generate key for the UI to access Proxy
|
||||||
user_email: Optional[str] = getattr(result, "email", None)
|
user_email: Optional[str] = getattr(result, "email", None)
|
||||||
user_id: Optional[str] = getattr(result, "id", None) if result is not None else None
|
user_id: Optional[str] = getattr(result, "id", None) if result is not None else None
|
||||||
|
@ -428,6 +495,9 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
|
|
||||||
# generic client id
|
# generic client id
|
||||||
if generic_client_id is not None and result is not None:
|
if generic_client_id is not None and result is not None:
|
||||||
|
generic_user_role_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
|
||||||
|
)
|
||||||
user_id = getattr(result, "id", None)
|
user_id = getattr(result, "id", None)
|
||||||
user_email = getattr(result, "email", None)
|
user_email = getattr(result, "email", None)
|
||||||
user_role = getattr(result, generic_user_role_attribute_name, None) # type: ignore
|
user_role = getattr(result, generic_user_role_attribute_name, None) # type: ignore
|
||||||
|
@ -508,12 +578,21 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# user not in DB, insert User into LiteLLM DB
|
# user not in DB, insert User into LiteLLM DB
|
||||||
user_role = await insert_sso_user(
|
user_info = await insert_sso_user(
|
||||||
result_openid=result,
|
result_openid=result,
|
||||||
user_defined_values=user_defined_values,
|
user_defined_values=user_defined_values,
|
||||||
)
|
)
|
||||||
except Exception:
|
|
||||||
pass
|
user_role = (
|
||||||
|
user_info.user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||||
|
)
|
||||||
|
sso_teams = getattr(result, "team_ids", [])
|
||||||
|
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"[Non-Blocking] Error trying to add sso user to db: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
if user_defined_values is None:
|
if user_defined_values is None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -588,13 +667,16 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||||
async def insert_sso_user(
|
async def insert_sso_user(
|
||||||
result_openid: Optional[OpenID],
|
result_openid: Optional[OpenID],
|
||||||
user_defined_values: Optional[SSOUserDefinedValues] = None,
|
user_defined_values: Optional[SSOUserDefinedValues] = None,
|
||||||
) -> str:
|
) -> NewUserResponse:
|
||||||
"""
|
"""
|
||||||
Helper function to create a New User in LiteLLM DB after a successful SSO login
|
Helper function to create a New User in LiteLLM DB after a successful SSO login
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
result_openid (OpenID): User information in OpenID format if the login was successful.
|
result_openid (OpenID): User information in OpenID format if the login was successful.
|
||||||
user_defined_values (Optional[SSOUserDefinedValues], optional): LiteLLM SSOValues / fields that were read
|
user_defined_values (Optional[SSOUserDefinedValues], optional): LiteLLM SSOValues / fields that were read
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[str, str]: User ID and User Role
|
||||||
"""
|
"""
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"Inserting SSO user into DB. User values: {user_defined_values}"
|
f"Inserting SSO user into DB. User values: {user_defined_values}"
|
||||||
|
@ -629,9 +711,9 @@ async def insert_sso_user(
|
||||||
if result_openid:
|
if result_openid:
|
||||||
new_user_request.metadata = {"auth_provider": result_openid.provider}
|
new_user_request.metadata = {"auth_provider": result_openid.provider}
|
||||||
|
|
||||||
await new_user(data=new_user_request, user_api_key_dict=UserAPIKeyAuth())
|
response = await new_user(data=new_user_request, user_api_key_dict=UserAPIKeyAuth())
|
||||||
|
|
||||||
return user_defined_values["user_role"] or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
return response
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
|
|
|
@ -1327,3 +1327,35 @@ async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data):
|
||||||
# Verify the result structure
|
# Verify the result structure
|
||||||
assert isinstance(result, UserInfoResponse)
|
assert isinstance(result, UserInfoResponse)
|
||||||
assert len(result.keys) == 2
|
assert len(result.keys) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_openid_response():
|
||||||
|
from litellm.proxy.management_endpoints.ui_sso import generic_response_convertor
|
||||||
|
from litellm.proxy.management_endpoints.ui_sso import JWTHandler
|
||||||
|
from litellm.proxy._types import LiteLLM_JWTAuth
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
|
||||||
|
jwt_handler = JWTHandler()
|
||||||
|
jwt_handler.update_environment(
|
||||||
|
prisma_client={},
|
||||||
|
user_api_key_cache=DualCache(),
|
||||||
|
litellm_jwtauth=LiteLLM_JWTAuth(
|
||||||
|
team_ids_jwt_field="department",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
response = {
|
||||||
|
"sub": "3f196e06-7484-451e-be5a-ea6c6bb86c5b",
|
||||||
|
"email_verified": True,
|
||||||
|
"name": "Krish Dholakia",
|
||||||
|
"preferred_username": "krrishd",
|
||||||
|
"given_name": "Krish",
|
||||||
|
"department": ["/test-group"],
|
||||||
|
"family_name": "Dholakia",
|
||||||
|
"email": "krrishdholakia@gmail.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = generic_response_convertor(
|
||||||
|
response=response,
|
||||||
|
jwt_handler=jwt_handler,
|
||||||
|
)
|
||||||
|
assert resp.team_ids == ["/test-group"]
|
||||||
|
|
|
@ -33,7 +33,7 @@ const menuItems: MenuItem[] = [
|
||||||
{ key: "2", page: "models", label: "Models", roles: all_admin_roles },
|
{ key: "2", page: "models", label: "Models", roles: all_admin_roles },
|
||||||
{ key: "4", page: "usage", label: "Usage"}, // all roles
|
{ key: "4", page: "usage", label: "Usage"}, // all roles
|
||||||
{ key: "6", page: "teams", label: "Teams" },
|
{ key: "6", page: "teams", label: "Teams" },
|
||||||
{ key: "17", page: "organizations", label: "Organizations" },
|
{ key: "17", page: "organizations", label: "Organizations", roles: all_admin_roles },
|
||||||
{ key: "5", page: "users", label: "Internal Users", roles: all_admin_roles },
|
{ key: "5", page: "users", label: "Internal Users", roles: all_admin_roles },
|
||||||
{ key: "8", page: "settings", label: "Logging & Alerts", roles: all_admin_roles },
|
{ key: "8", page: "settings", label: "Logging & Alerts", roles: all_admin_roles },
|
||||||
{ key: "9", page: "caching", label: "Caching", roles: all_admin_roles },
|
{ key: "9", page: "caching", label: "Caching", roles: all_admin_roles },
|
||||||
|
|
|
@ -20,6 +20,7 @@ import {
|
||||||
Tooltip
|
Tooltip
|
||||||
} from "antd";
|
} from "antd";
|
||||||
import { Select, SelectItem } from "@tremor/react";
|
import { Select, SelectItem } from "@tremor/react";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
Table,
|
Table,
|
||||||
TableBody,
|
TableBody,
|
||||||
|
@ -69,6 +70,7 @@ import {
|
||||||
teamListCall
|
teamListCall
|
||||||
} from "./networking";
|
} from "./networking";
|
||||||
|
|
||||||
|
|
||||||
const Team: React.FC<TeamProps> = ({
|
const Team: React.FC<TeamProps> = ({
|
||||||
teams,
|
teams,
|
||||||
searchParams,
|
searchParams,
|
||||||
|
@ -365,6 +367,7 @@ const Team: React.FC<TeamProps> = ({
|
||||||
|
|
||||||
const handleCreate = async (formValues: Record<string, any>) => {
|
const handleCreate = async (formValues: Record<string, any>) => {
|
||||||
try {
|
try {
|
||||||
|
console.log(`formValues: ${JSON.stringify(formValues)}`);
|
||||||
if (accessToken != null) {
|
if (accessToken != null) {
|
||||||
const newTeamAlias = formValues?.team_alias;
|
const newTeamAlias = formValues?.team_alias;
|
||||||
const existingTeamAliases = teams?.map((t) => t.team_alias) ?? [];
|
const existingTeamAliases = teams?.map((t) => t.team_alias) ?? [];
|
||||||
|
@ -746,6 +749,17 @@ const Team: React.FC<TeamProps> = ({
|
||||||
<b>Additional Settings</b>
|
<b>Additional Settings</b>
|
||||||
</AccordionHeader>
|
</AccordionHeader>
|
||||||
<AccordionBody>
|
<AccordionBody>
|
||||||
|
<Form.Item
|
||||||
|
label="Team ID"
|
||||||
|
name="team_id"
|
||||||
|
help="ID of the team you want to create. If not provided, it will be generated automatically."
|
||||||
|
>
|
||||||
|
<TextInput
|
||||||
|
onChange={(e) => {
|
||||||
|
e.target.value = e.target.value.trim();
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
<Form.Item
|
<Form.Item
|
||||||
label="Organization ID"
|
label="Organization ID"
|
||||||
name="organization_id"
|
name="organization_id"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue