[Feat SSO] - Allow admins to set default_team_params to have default params for when litellm SSO creates default teams (#9895)

* add default_team_params as a config.yaml setting

* create_litellm_team_from_sso_group

* test_default_team_params

* test_create_team_without_default_params

* docs default team settings
Ishaan Jaff 2025-04-10 16:58:28 -07:00 committed by GitHub
parent 7c007c0be3
commit 557a2ca102
4 changed files with 252 additions and 86 deletions

View file

@@ -207,9 +207,14 @@ This walks through setting up sso auto-add for **Microsoft Entra ID**

Follow along this video for a walkthrough of how to set this up with Microsoft Entra ID

<iframe width="840" height="500" src="https://www.loom.com/embed/ea711323aa9a496d84a01fd7b2a12f54?sid=c53e238c-5bfd-4135-b8fb-b5b1a08632cf" frameborder="0" webkitallowfullscreen mozallowfullscreen allowfullscreen></iframe>

<br />
<br />

**Next steps**

1. [Set default params for new teams auto-created from SSO](#set-default-params-for-new-teams)

### Debugging SSO JWT fields
@@ -279,6 +284,26 @@ This budget does not apply to keys created under non-default teams.

[**Go Here**](./team_budgets.md)

### Set default params for new teams

When you connect litellm to your SSO provider, litellm can auto-create teams. Use this setting to control the default `models`, `max_budget`, and `budget_duration` for these auto-created teams.

**How it works**

1. When litellm fetches `groups` from your SSO provider, it checks whether the corresponding group_id exists as a `team_id` in litellm.
2. If the team_id does not exist, litellm auto-creates a team with the default params you've set.
3. If the team_id already exists, litellm does not apply any settings to the team.

**Usage**

```yaml showLineNumbers title="Default Params for new teams"
litellm_settings:
  default_team_params:            # Default params applied when litellm auto-creates a team from an SSO IDP provider
    max_budget: 100               # (Optional[float]) $100 budget for the team
    budget_duration: 30d          # (Optional[str]) 30 day budget_duration for the team
    models: ["gpt-3.5-turbo"]     # (Optional[List[str]]) models the team can use
```
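The merge behavior above can be sketched with stdlib dataclasses. This is a simplified stand-in, not litellm's actual code: the proxy uses a pydantic `NewTeamRequest` and `model_copy(deep=True, update=...)`, but `dataclasses.replace` behaves analogously — the configured defaults are kept, and only the team's identity fields are overridden per SSO group.

```python
from dataclasses import dataclass, field, replace
from typing import List, Optional


@dataclass
class TeamRequest:
    """Simplified stand-in for litellm's NewTeamRequest."""
    team_id: Optional[str] = None
    team_alias: Optional[str] = None
    max_budget: Optional[float] = None
    budget_duration: Optional[str] = None
    models: List[str] = field(default_factory=list)


# Values as loaded from litellm_settings.default_team_params in config.yaml
default_team_params = TeamRequest(
    max_budget=100, budget_duration="30d", models=["gpt-3.5-turbo"]
)


def build_team_request(group_id: str, group_name: str) -> TeamRequest:
    # Defaults are preserved; only team_id / team_alias are overridden,
    # mirroring model_copy(deep=True, update={...}) in the proxy code.
    return replace(default_team_params, team_id=group_id, team_alias=group_name)


team = build_team_request("entra-group-123", "Engineering")
print(team.max_budget, team.budget_duration, team.models)  # 100 30d ['gpt-3.5-turbo']
```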
### Restrict Users from creating personal keys

@@ -290,7 +315,7 @@ This will also prevent users from using their session tokens on the test keys ch
## **All Settings for Self Serve / SSO Flow**

```yaml showLineNumbers title="All Settings for Self Serve / SSO Flow"
litellm_settings:
  max_internal_user_budget: 10          # max budget for internal users
  internal_user_budget_duration: "1mo"  # reset every month

@@ -300,6 +325,11 @@ litellm_settings:
    max_budget: 100               # (Optional[float]) $100 budget for a new SSO sign-in user
    budget_duration: 30d          # (Optional[str]) 30 day budget_duration for a new SSO sign-in user
    models: ["gpt-3.5-turbo"]     # (Optional[List[str]]) models a new SSO sign-in user can use
  default_team_params:            # Default params applied when litellm auto-creates a team from an SSO IDP provider
    max_budget: 100               # (Optional[float]) $100 budget for the team
    budget_duration: 30d          # (Optional[str]) 30 day budget_duration for the team
    models: ["gpt-3.5-turbo"]     # (Optional[List[str]]) models the team can use
  upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on

View file

@@ -65,6 +65,7 @@ from litellm.proxy._types import (
    KeyManagementSystem,
    KeyManagementSettings,
    LiteLLM_UpperboundKeyGenerateParams,
    NewTeamRequest,
)
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
from litellm.integrations.custom_logger import CustomLogger
@@ -126,19 +127,19 @@ prometheus_initialize_budget_metrics: Optional[bool] = False
require_auth_for_metrics_endpoint: Optional[bool] = False
argilla_batch_size: Optional[int] = None
datadog_use_v1: Optional[bool] = False  # if you want to use v1 datadog logged payload
gcs_pub_sub_use_v1: Optional[bool] = (
    False  # if you want to use v1 gcs pubsub logged payload
)
argilla_transformation_object: Optional[Dict[str, Any]] = None
_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
    []
)  # internal variable - async custom callbacks are routed here.
_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
    []
)  # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
    []
)  # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
turn_off_message_logging: Optional[bool] = False
@@ -146,18 +147,18 @@ log_raw_request_response: bool = False
redact_messages_in_exceptions: Optional[bool] = False
redact_user_api_key_info: Optional[bool] = False
filter_invalid_headers: Optional[bool] = False
add_user_information_to_llm_headers: Optional[bool] = (
    None  # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
)
store_audit_logs = False  # Enterprise feature, allow users to see audit logs
### end of callbacks #############
email: Optional[str] = (
    None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
token: Optional[str] = (
    None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
telemetry = True
max_tokens: int = DEFAULT_MAX_TOKENS  # OpenAI Defaults
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
@@ -233,20 +234,24 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
enable_caching_on_provider_specific_optional_params: bool = (
    False  # feature-flag for caching on optional params - e.g. 'top_k'
)
caching: bool = (
    False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
caching_with_models: bool = (
    False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
cache: Optional[Cache] = (
    None  # cache object <- use this - https://docs.litellm.ai/docs/caching
)
default_in_memory_ttl: Optional[float] = None
default_redis_ttl: Optional[float] = None
default_redis_batch_cache_expiry: Optional[float] = None
model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0  # set the max budget across all providers
budget_duration: Optional[str] = (
    None  # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
)
default_soft_budget: float = (
    DEFAULT_SOFT_BUDGET  # by default all litellm proxy keys have a soft budget of 50.0
)
@@ -255,11 +260,15 @@ forward_traceparent_to_llm_provider: bool = False
_current_cost = 0.0  # private variable, used if max budget is set
error_logs: Dict = {}
add_function_to_prompt: bool = (
    False  # if function calling not supported by api, append function call details to system prompt
)
client_session: Optional[httpx.Client] = None
aclient_session: Optional[httpx.AsyncClient] = None
model_fallbacks: Optional[List] = None  # Deprecated for 'litellm.fallbacks'
model_cost_map_url: str = (
    "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
)
suppress_debug_info = False
dynamodb_table_name: Optional[str] = None
s3_callback_params: Optional[Dict] = None
@@ -268,6 +277,7 @@ default_key_generate_params: Optional[Dict] = None
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
key_generation_settings: Optional[StandardKeyGenerationConfig] = None
default_internal_user_params: Optional[Dict] = None
default_team_params: Optional[NewTeamRequest] = None
default_team_settings: Optional[List] = None
max_user_budget: Optional[float] = None
default_max_internal_user_budget: Optional[float] = None
@@ -281,7 +291,9 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
custom_prometheus_metadata_labels: List[str] = []
#### REQUEST PRIORITIZATION ####
priority_reservation: Optional[Dict[str, float]] = None
force_ipv4: bool = (
    False  # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
)
module_level_aclient = AsyncHTTPHandler(
    timeout=request_timeout, client_alias="module level aclient"
)
@@ -295,13 +307,13 @@ fallbacks: Optional[List] = None
context_window_fallbacks: Optional[List] = None
content_policy_fallbacks: Optional[List] = None
allowed_fails: int = 3
num_retries_per_request: Optional[int] = (
    None  # for the request overall (incl. fallbacks + model retries)
)
####### SECRET MANAGERS #####################
secret_manager_client: Optional[Any] = (
    None  # list of instantiated key management clients - e.g. azure kv, infisical, etc.
)
_google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
@@ -1050,10 +1062,10 @@ from .types.llms.custom_llm import CustomLLMItem
from .types.utils import GenericStreamingChunk

custom_provider_map: List[CustomLLMItem] = []
_custom_providers: List[str] = (
    []
)  # internal helper util, used to track names of custom providers
disable_hf_tokenizer_download: Optional[bool] = (
    None  # disable huggingface tokenizer download. Defaults to openai clk100
)
global_disable_no_log_param: bool = False

View file

@@ -896,6 +896,68 @@ class SSOAuthenticationHandler:
            sso_teams = getattr(result, "team_ids", [])
            await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)

    @staticmethod
    async def create_litellm_team_from_sso_group(
        litellm_team_id: str,
        litellm_team_name: Optional[str] = None,
    ):
        """
        Creates a Litellm Team from an SSO Group ID.

        Your SSO provider might have groups that should be created on LiteLLM.
        Use this helper to create a Litellm Team from an SSO Group ID.

        Args:
            litellm_team_id (str): The ID of the Litellm Team
            litellm_team_name (Optional[str]): The name of the Litellm Team
        """
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise ProxyException(
                message="Prisma client not found. Set it in the proxy_server.py file",
                type=ProxyErrorTypes.auth_error,
                param="prisma_client",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        try:
            team_obj = await prisma_client.db.litellm_teamtable.find_first(
                where={"team_id": litellm_team_id}
            )
            verbose_proxy_logger.debug(f"Team object: {team_obj}")

            # only create a new team if it doesn't exist
            if team_obj:
                verbose_proxy_logger.debug(
                    f"Team already exists: {litellm_team_id} - {litellm_team_name}"
                )
                return

            team_request: NewTeamRequest = NewTeamRequest(
                team_id=litellm_team_id,
                team_alias=litellm_team_name,
            )
            if litellm.default_team_params:
                team_request = litellm.default_team_params.model_copy(
                    deep=True,
                    update={
                        "team_id": litellm_team_id,
                        "team_alias": litellm_team_name,
                    },
                )
            await new_team(
                data=team_request,
                # params used for Audit Logging
                http_request=Request(scope={"type": "http", "method": "POST"}),
                user_api_key_dict=UserAPIKeyAuth(
                    token="",
                    key_alias=f"litellm.{MicrosoftSSOHandler.__name__}",
                ),
            )
        except Exception as e:
            verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}")
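The helper above boils down to a find-then-create idempotency pattern: look the team up by id, and create it only if the lookup comes back empty. The sketch below illustrates that flow with an in-memory dict standing in for the Prisma team table; the names (`teams_db`, `create_team_if_missing`) are hypothetical and not part of litellm's API.

```python
import asyncio
from typing import Dict, Optional

# In-memory stand-in for the Prisma litellm_teamtable: team_id -> team record
teams_db: Dict[str, dict] = {}


async def create_team_if_missing(team_id: str, team_alias: Optional[str] = None) -> bool:
    """Returns True if a new team was created, False if it already existed."""
    # Mirrors prisma_client.db.litellm_teamtable.find_first(where={"team_id": ...})
    if team_id in teams_db:
        return False  # existing teams are left untouched
    teams_db[team_id] = {"team_id": team_id, "team_alias": team_alias}
    return True


async def main():
    created = await create_team_if_missing("sso-group-1", "Engineering")
    again = await create_team_if_missing("sso-group-1", "Engineering")
    print(created, again)  # True False


asyncio.run(main())
```

Calling the helper twice for the same group id is safe: the second call is a no-op, which is what makes re-running the SSO sync harmless.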
class MicrosoftSSOHandler:
    """

@@ -1176,15 +1238,6 @@ class MicrosoftSSOHandler:
        When a user sets a `SERVICE_PRINCIPAL_ID` in the env, litellm will fetch groups under that service principal and create Litellm Teams from them
        """
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise ProxyException(
                message="Prisma client not found. Set it in the proxy_server.py file",
                type=ProxyErrorTypes.auth_error,
                param="prisma_client",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        verbose_proxy_logger.debug(
            f"Creating Litellm Teams from Service Principal Teams: {service_principal_teams}"
        )
@@ -1199,36 +1252,10 @@
            )
            continue

            await SSOAuthenticationHandler.create_litellm_team_from_sso_group(
                litellm_team_id=litellm_team_id,
                litellm_team_name=litellm_team_name,
            )
            try:
                verbose_proxy_logger.debug(
                    f"Creating Litellm Team: {litellm_team_id} - {litellm_team_name}"
                )
                team_obj = await prisma_client.db.litellm_teamtable.find_first(
                    where={"team_id": litellm_team_id}
                )
                verbose_proxy_logger.debug(f"Team object: {team_obj}")

                # only create a new team if it doesn't exist
                if team_obj:
                    verbose_proxy_logger.debug(
                        f"Team already exists: {litellm_team_id} - {litellm_team_name}"
                    )
                    continue
                await new_team(
                    data=NewTeamRequest(
                        team_id=litellm_team_id,
                        team_alias=litellm_team_name,
                    ),
                    # params used for Audit Logging
                    http_request=Request(scope={"type": "http", "method": "POST"}),
                    user_api_key_dict=UserAPIKeyAuth(
                        token="",
                        key_alias=f"litellm.{MicrosoftSSOHandler.__name__}",
                    ),
                )
            except Exception as e:
                verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}")
class GoogleSSOHandler: class GoogleSSOHandler:

View file

@@ -2,8 +2,9 @@ import asyncio
import json
import os
import sys
import uuid
from typing import Optional, cast
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import Request
@@ -13,6 +14,8 @@ sys.path.insert(
    0, os.path.abspath("../../../")
)  # Adds the parent directory to the system path

import litellm
from litellm.proxy._types import NewTeamRequest
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.management_endpoints.types import CustomOpenID
from litellm.proxy.management_endpoints.ui_sso import (

@@ -22,6 +25,7 @@ from litellm.proxy.management_endpoints.ui_sso import (
from litellm.types.proxy.management_endpoints.ui_sso import (
    MicrosoftGraphAPIUserGroupDirectoryObject,
    MicrosoftGraphAPIUserGroupResponse,
    MicrosoftServicePrincipalTeam,
)
@@ -379,3 +383,96 @@ def test_get_group_ids_from_graph_api_response():
    assert len(result) == 2
    assert "group1" in result
    assert "group2" in result


@pytest.mark.asyncio
async def test_default_team_params():
    """
    When litellm.default_team_params is set, it should be used to create a new team
    """
    # Arrange
    litellm.default_team_params = NewTeamRequest(
        max_budget=10, budget_duration="1d", models=["special-gpt-5"]
    )

    def mock_jsonify_team_object(db_data):
        return db_data

    # Mock Prisma client
    mock_prisma = MagicMock()
    mock_prisma.db.litellm_teamtable.find_first = AsyncMock(return_value=None)
    mock_prisma.db.litellm_teamtable.create = AsyncMock()
    mock_prisma.get_data = AsyncMock(return_value=None)
    mock_prisma.jsonify_team_object = MagicMock(side_effect=mock_jsonify_team_object)

    with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
        # Act
        team_id = str(uuid.uuid4())
        await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids(
            service_principal_teams=[
                MicrosoftServicePrincipalTeam(
                    principalId=team_id,
                    principalDisplayName="Test Team",
                )
            ]
        )

        # Assert
        # Verify team was created with correct parameters
        mock_prisma.db.litellm_teamtable.create.assert_called_once()
        print(
            "mock_prisma.db.litellm_teamtable.create.call_args",
            mock_prisma.db.litellm_teamtable.create.call_args,
        )
        create_call_args = mock_prisma.db.litellm_teamtable.create.call_args.kwargs[
            "data"
        ]
        assert create_call_args["team_id"] == team_id
        assert create_call_args["team_alias"] == "Test Team"
        assert create_call_args["max_budget"] == 10
        assert create_call_args["budget_duration"] == "1d"
        assert create_call_args["models"] == ["special-gpt-5"]


@pytest.mark.asyncio
async def test_create_team_without_default_params():
    """
    Test team creation when litellm.default_team_params is None.
    Should create team with just the basic required fields.
    """
    # Arrange
    litellm.default_team_params = None

    def mock_jsonify_team_object(db_data):
        return db_data

    # Mock Prisma client
    mock_prisma = MagicMock()
    mock_prisma.db.litellm_teamtable.find_first = AsyncMock(return_value=None)
    mock_prisma.db.litellm_teamtable.create = AsyncMock()
    mock_prisma.get_data = AsyncMock(return_value=None)
    mock_prisma.jsonify_team_object = MagicMock(side_effect=mock_jsonify_team_object)

    with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
        # Act
        team_id = str(uuid.uuid4())
        await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids(
            service_principal_teams=[
                MicrosoftServicePrincipalTeam(
                    principalId=team_id,
                    principalDisplayName="Test Team",
                )
            ]
        )

        # Assert
        mock_prisma.db.litellm_teamtable.create.assert_called_once()
        create_call_args = mock_prisma.db.litellm_teamtable.create.call_args.kwargs[
            "data"
        ]
        assert create_call_args["team_id"] == team_id
        assert create_call_args["team_alias"] == "Test Team"
        # Should not have any of the optional fields
        assert "max_budget" not in create_call_args
        assert "budget_duration" not in create_call_args
        assert create_call_args["models"] == []