diff --git a/docs/my-website/docs/proxy/self_serve.md b/docs/my-website/docs/proxy/self_serve.md
index d630c8e7f3..2fc17d952e 100644
--- a/docs/my-website/docs/proxy/self_serve.md
+++ b/docs/my-website/docs/proxy/self_serve.md
@@ -207,9 +207,14 @@ This walks through setting up sso auto-add for **Microsoft Entra ID**
Follow along this video for a walkthrough of how to set this up with Microsoft Entra ID
-
+
+
+
+**Next steps**
+
+1. [Set default params for new teams auto-created from SSO](#set-default-params-for-new-teams)
### Debugging SSO JWT fields
@@ -279,6 +284,26 @@ This budget does not apply to keys created under non-default teams.
[**Go Here**](./team_budgets.md)
+### Set default params for new teams
+
+When you connect litellm to your SSO provider, litellm can auto-create teams. Use this setting to control the default `models`, `max_budget`, and `budget_duration` for these auto-created teams.
+
+**How it works**
+
+1. When litellm fetches `groups` from your SSO provider, it checks whether each group's ID exists as a `team_id` in litellm.
+2. If the `team_id` does not exist, litellm auto-creates a team with the default params you've set.
+3. If the `team_id` already exists, litellm leaves the existing team unchanged.
+
+**Usage**
+
+```yaml showLineNumbers title="Default Params for new teams"
+litellm_settings:
+  default_team_params: # Default params to apply when litellm auto-creates a team from an SSO IdP group
+    max_budget: 100 # (Optional[float]) $100 budget for the team
+    budget_duration: 30d # (Optional[str]) reset the team budget every 30 days
+    models: ["gpt-3.5-turbo"] # (Optional[List[str]]) models the team can use
+```
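+
+Under the hood, the proxy builds the team-creation request roughly as in the sketch below, a simplified view of the `create_litellm_team_from_sso_group` helper (the `build_team_request` wrapper is illustrative; the real helper also checks the database for an existing team and handles errors):
+
+```python showLineNumbers title="How default_team_params is applied (simplified)"
+import litellm
+from litellm.proxy._types import NewTeamRequest
+
+
+def build_team_request(sso_group_id: str, sso_group_name: str) -> NewTeamRequest:
+    # Base request: each SSO group maps 1:1 to a litellm team
+    team_request = NewTeamRequest(team_id=sso_group_id, team_alias=sso_group_name)
+    if litellm.default_team_params:
+        # Clone the admin-configured defaults, then overwrite the identity
+        # fields so each SSO group still gets its own team_id / team_alias
+        team_request = litellm.default_team_params.model_copy(
+            deep=True,
+            update={"team_id": sso_group_id, "team_alias": sso_group_name},
+        )
+    return team_request
+```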
+
### Restrict Users from creating personal keys
@@ -290,7 +315,7 @@ This will also prevent users from using their session tokens on the test keys ch
## **All Settings for Self Serve / SSO Flow**
-```yaml
+```yaml showLineNumbers title="All Settings for Self Serve / SSO Flow"
litellm_settings:
max_internal_user_budget: 10 # max budget for internal users
internal_user_budget_duration: "1mo" # reset every month
@@ -300,6 +325,11 @@ litellm_settings:
max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by a new SSO sign in user
+
+  default_team_params: # Default params to apply when litellm auto-creates a team from an SSO IdP group
+    max_budget: 100 # (Optional[float]) $100 budget for the team
+    budget_duration: 30d # (Optional[str]) reset the team budget every 30 days
+    models: ["gpt-3.5-turbo"] # (Optional[List[str]]) models the team can use
upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on
diff --git a/litellm/__init__.py b/litellm/__init__.py
index e061643398..a3b37da2b4 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -65,6 +65,7 @@ from litellm.proxy._types import (
KeyManagementSystem,
KeyManagementSettings,
LiteLLM_UpperboundKeyGenerateParams,
+ NewTeamRequest,
)
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
from litellm.integrations.custom_logger import CustomLogger
@@ -126,19 +127,19 @@ prometheus_initialize_budget_metrics: Optional[bool] = False
require_auth_for_metrics_endpoint: Optional[bool] = False
argilla_batch_size: Optional[int] = None
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
-gcs_pub_sub_use_v1: Optional[
- bool
-] = False # if you want to use v1 gcs pubsub logged payload
+gcs_pub_sub_use_v1: Optional[bool] = (
+ False # if you want to use v1 gcs pubsub logged payload
+)
argilla_transformation_object: Optional[Dict[str, Any]] = None
-_async_input_callback: List[
- Union[str, Callable, CustomLogger]
-] = [] # internal variable - async custom callbacks are routed here.
-_async_success_callback: List[
- Union[str, Callable, CustomLogger]
-] = [] # internal variable - async custom callbacks are routed here.
-_async_failure_callback: List[
- Union[str, Callable, CustomLogger]
-] = [] # internal variable - async custom callbacks are routed here.
+_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
+ []
+) # internal variable - async custom callbacks are routed here.
+_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
+ []
+) # internal variable - async custom callbacks are routed here.
+_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
+ []
+) # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
turn_off_message_logging: Optional[bool] = False
@@ -146,18 +147,18 @@ log_raw_request_response: bool = False
redact_messages_in_exceptions: Optional[bool] = False
redact_user_api_key_info: Optional[bool] = False
filter_invalid_headers: Optional[bool] = False
-add_user_information_to_llm_headers: Optional[
- bool
-] = None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
+add_user_information_to_llm_headers: Optional[bool] = (
+ None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
+)
store_audit_logs = False # Enterprise feature, allow users to see audit logs
### end of callbacks #############
-email: Optional[
- str
-] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-token: Optional[
- str
-] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+email: Optional[str] = (
+ None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+)
+token: Optional[str] = (
+ None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+)
telemetry = True
max_tokens: int = DEFAULT_MAX_TOKENS # OpenAI Defaults
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
@@ -233,20 +234,24 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
enable_caching_on_provider_specific_optional_params: bool = (
False # feature-flag for caching on optional params - e.g. 'top_k'
)
-caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-cache: Optional[
- Cache
-] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
+caching: bool = (
+ False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+)
+caching_with_models: bool = (
+    False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+)
+cache: Optional[Cache] = (
+ None # cache object <- use this - https://docs.litellm.ai/docs/caching
+)
default_in_memory_ttl: Optional[float] = None
default_redis_ttl: Optional[float] = None
default_redis_batch_cache_expiry: Optional[float] = None
model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0 # set the max budget across all providers
-budget_duration: Optional[
- str
-] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
+budget_duration: Optional[str] = (
+ None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
+)
default_soft_budget: float = (
DEFAULT_SOFT_BUDGET # by default all litellm proxy keys have a soft budget of 50.0
)
@@ -255,11 +260,15 @@ forward_traceparent_to_llm_provider: bool = False
_current_cost = 0.0 # private variable, used if max budget is set
error_logs: Dict = {}
-add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
+add_function_to_prompt: bool = (
+ False # if function calling not supported by api, append function call details to system prompt
+)
client_session: Optional[httpx.Client] = None
aclient_session: Optional[httpx.AsyncClient] = None
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
-model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
+model_cost_map_url: str = (
+ "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
+)
suppress_debug_info = False
dynamodb_table_name: Optional[str] = None
s3_callback_params: Optional[Dict] = None
@@ -268,6 +277,7 @@ default_key_generate_params: Optional[Dict] = None
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
key_generation_settings: Optional[StandardKeyGenerationConfig] = None
default_internal_user_params: Optional[Dict] = None
+default_team_params: Optional[NewTeamRequest] = None  # default params for teams auto-created from SSO groups
default_team_settings: Optional[List] = None
max_user_budget: Optional[float] = None
default_max_internal_user_budget: Optional[float] = None
@@ -281,7 +291,9 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
custom_prometheus_metadata_labels: List[str] = []
#### REQUEST PRIORITIZATION ####
priority_reservation: Optional[Dict[str, float]] = None
-force_ipv4: bool = False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
+force_ipv4: bool = (
+ False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
+)
module_level_aclient = AsyncHTTPHandler(
timeout=request_timeout, client_alias="module level aclient"
)
@@ -295,13 +307,13 @@ fallbacks: Optional[List] = None
context_window_fallbacks: Optional[List] = None
content_policy_fallbacks: Optional[List] = None
allowed_fails: int = 3
-num_retries_per_request: Optional[
- int
-] = None # for the request overall (incl. fallbacks + model retries)
+num_retries_per_request: Optional[int] = (
+ None # for the request overall (incl. fallbacks + model retries)
+)
####### SECRET MANAGERS #####################
-secret_manager_client: Optional[
- Any
-] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
+secret_manager_client: Optional[Any] = (
+ None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
+)
_google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
@@ -1050,10 +1062,10 @@ from .types.llms.custom_llm import CustomLLMItem
from .types.utils import GenericStreamingChunk
custom_provider_map: List[CustomLLMItem] = []
-_custom_providers: List[
- str
-] = [] # internal helper util, used to track names of custom providers
-disable_hf_tokenizer_download: Optional[
- bool
-] = None # disable huggingface tokenizer download. Defaults to openai clk100
+_custom_providers: List[str] = (
+ []
+) # internal helper util, used to track names of custom providers
+disable_hf_tokenizer_download: Optional[bool] = (
+    None  # disable huggingface tokenizer download. Defaults to openai cl100k
+)
global_disable_no_log_param: bool = False
diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py
index 0cd3600220..1e10aebedb 100644
--- a/litellm/proxy/management_endpoints/ui_sso.py
+++ b/litellm/proxy/management_endpoints/ui_sso.py
@@ -896,6 +896,68 @@ class SSOAuthenticationHandler:
sso_teams = getattr(result, "team_ids", [])
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
+ @staticmethod
+ async def create_litellm_team_from_sso_group(
+ litellm_team_id: str,
+ litellm_team_name: Optional[str] = None,
+ ):
+ """
+        Creates a LiteLLM team from an SSO group ID.
+
+        Your SSO provider might have groups that should exist as teams on LiteLLM.
+        Use this helper to create a LiteLLM team for a given SSO group ID,
+        applying `litellm.default_team_params` when set.
+
+ Args:
+ litellm_team_id (str): The ID of the Litellm Team
+ litellm_team_name (Optional[str]): The name of the Litellm Team
+ """
+ from litellm.proxy.proxy_server import prisma_client
+
+ if prisma_client is None:
+ raise ProxyException(
+ message="Prisma client not found. Set it in the proxy_server.py file",
+ type=ProxyErrorTypes.auth_error,
+ param="prisma_client",
+ code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ )
+ try:
+ team_obj = await prisma_client.db.litellm_teamtable.find_first(
+ where={"team_id": litellm_team_id}
+ )
+ verbose_proxy_logger.debug(f"Team object: {team_obj}")
+
+ # only create a new team if it doesn't exist
+ if team_obj:
+ verbose_proxy_logger.debug(
+ f"Team already exists: {litellm_team_id} - {litellm_team_name}"
+ )
+ return
+
+ team_request: NewTeamRequest = NewTeamRequest(
+ team_id=litellm_team_id,
+ team_alias=litellm_team_name,
+ )
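+            # If the admin configured litellm.default_team_params, use it as the
+            # base for the new team, overriding only the identity fields so each
+            # SSO group still maps to its own team_id / team_alias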
+ if litellm.default_team_params:
+ team_request = litellm.default_team_params.model_copy(
+ deep=True,
+ update={
+ "team_id": litellm_team_id,
+ "team_alias": litellm_team_name,
+ },
+ )
+ await new_team(
+ data=team_request,
+ # params used for Audit Logging
+ http_request=Request(scope={"type": "http", "method": "POST"}),
+ user_api_key_dict=UserAPIKeyAuth(
+ token="",
+ key_alias=f"litellm.{MicrosoftSSOHandler.__name__}",
+ ),
+ )
+ except Exception as e:
+ verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}")
+
class MicrosoftSSOHandler:
"""
@@ -1176,15 +1238,6 @@ class MicrosoftSSOHandler:
When a user sets a `SERVICE_PRINCIPAL_ID` in the env, litellm will fetch groups under that service principal and create Litellm Teams from them
"""
- from litellm.proxy.proxy_server import prisma_client
-
- if prisma_client is None:
- raise ProxyException(
- message="Prisma client not found. Set it in the proxy_server.py file",
- type=ProxyErrorTypes.auth_error,
- param="prisma_client",
- code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- )
verbose_proxy_logger.debug(
f"Creating Litellm Teams from Service Principal Teams: {service_principal_teams}"
)
@@ -1199,36 +1252,10 @@ class MicrosoftSSOHandler:
)
continue
- try:
- verbose_proxy_logger.debug(
- f"Creating Litellm Team: {litellm_team_id} - {litellm_team_name}"
- )
-
- team_obj = await prisma_client.db.litellm_teamtable.find_first(
- where={"team_id": litellm_team_id}
- )
- verbose_proxy_logger.debug(f"Team object: {team_obj}")
-
- # only create a new team if it doesn't exist
- if team_obj:
- verbose_proxy_logger.debug(
- f"Team already exists: {litellm_team_id} - {litellm_team_name}"
- )
- continue
- await new_team(
- data=NewTeamRequest(
- team_id=litellm_team_id,
- team_alias=litellm_team_name,
- ),
- # params used for Audit Logging
- http_request=Request(scope={"type": "http", "method": "POST"}),
- user_api_key_dict=UserAPIKeyAuth(
- token="",
- key_alias=f"litellm.{MicrosoftSSOHandler.__name__}",
- ),
- )
- except Exception as e:
- verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}")
+ await SSOAuthenticationHandler.create_litellm_team_from_sso_group(
+ litellm_team_id=litellm_team_id,
+ litellm_team_name=litellm_team_name,
+ )
class GoogleSSOHandler:
diff --git a/tests/litellm/proxy/management_endpoints/test_ui_sso.py b/tests/litellm/proxy/management_endpoints/test_ui_sso.py
index 606f3833be..ff9700393f 100644
--- a/tests/litellm/proxy/management_endpoints/test_ui_sso.py
+++ b/tests/litellm/proxy/management_endpoints/test_ui_sso.py
@@ -2,8 +2,9 @@ import asyncio
import json
import os
import sys
+import uuid
from typing import Optional, cast
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import Request
@@ -13,6 +14,8 @@ sys.path.insert(
0, os.path.abspath("../../../")
) # Adds the parent directory to the system path
+import litellm
+from litellm.proxy._types import NewTeamRequest
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.management_endpoints.types import CustomOpenID
from litellm.proxy.management_endpoints.ui_sso import (
@@ -22,6 +25,7 @@ from litellm.proxy.management_endpoints.ui_sso import (
from litellm.types.proxy.management_endpoints.ui_sso import (
MicrosoftGraphAPIUserGroupDirectoryObject,
MicrosoftGraphAPIUserGroupResponse,
+ MicrosoftServicePrincipalTeam,
)
@@ -379,3 +383,96 @@ def test_get_group_ids_from_graph_api_response():
assert len(result) == 2
assert "group1" in result
assert "group2" in result
+
+
+@pytest.mark.asyncio
+async def test_default_team_params():
+ """
+ When litellm.default_team_params is set, it should be used to create a new team
+ """
+ # Arrange
+ litellm.default_team_params = NewTeamRequest(
+ max_budget=10, budget_duration="1d", models=["special-gpt-5"]
+ )
+
+ def mock_jsonify_team_object(db_data):
+ return db_data
+
+ # Mock Prisma client
+ mock_prisma = MagicMock()
+ mock_prisma.db.litellm_teamtable.find_first = AsyncMock(return_value=None)
+ mock_prisma.db.litellm_teamtable.create = AsyncMock()
+ mock_prisma.get_data = AsyncMock(return_value=None)
+ mock_prisma.jsonify_team_object = MagicMock(side_effect=mock_jsonify_team_object)
+
+ with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
+ # Act
+ team_id = str(uuid.uuid4())
+ await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids(
+ service_principal_teams=[
+ MicrosoftServicePrincipalTeam(
+ principalId=team_id,
+ principalDisplayName="Test Team",
+ )
+ ]
+ )
+
+ # Assert
+ # Verify team was created with correct parameters
+ mock_prisma.db.litellm_teamtable.create.assert_called_once()
+ print(
+ "mock_prisma.db.litellm_teamtable.create.call_args",
+ mock_prisma.db.litellm_teamtable.create.call_args,
+ )
+ create_call_args = mock_prisma.db.litellm_teamtable.create.call_args.kwargs[
+ "data"
+ ]
+ assert create_call_args["team_id"] == team_id
+ assert create_call_args["team_alias"] == "Test Team"
+ assert create_call_args["max_budget"] == 10
+ assert create_call_args["budget_duration"] == "1d"
+ assert create_call_args["models"] == ["special-gpt-5"]
+
+
+@pytest.mark.asyncio
+async def test_create_team_without_default_params():
+ """
+ Test team creation when litellm.default_team_params is None
+ Should create team with just the basic required fields
+ """
+ # Arrange
+ litellm.default_team_params = None
+
+ def mock_jsonify_team_object(db_data):
+ return db_data
+
+ # Mock Prisma client
+ mock_prisma = MagicMock()
+ mock_prisma.db.litellm_teamtable.find_first = AsyncMock(return_value=None)
+ mock_prisma.db.litellm_teamtable.create = AsyncMock()
+ mock_prisma.get_data = AsyncMock(return_value=None)
+ mock_prisma.jsonify_team_object = MagicMock(side_effect=mock_jsonify_team_object)
+
+ with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
+ # Act
+ team_id = str(uuid.uuid4())
+ await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids(
+ service_principal_teams=[
+ MicrosoftServicePrincipalTeam(
+ principalId=team_id,
+ principalDisplayName="Test Team",
+ )
+ ]
+ )
+
+ # Assert
+ mock_prisma.db.litellm_teamtable.create.assert_called_once()
+ create_call_args = mock_prisma.db.litellm_teamtable.create.call_args.kwargs[
+ "data"
+ ]
+ assert create_call_args["team_id"] == team_id
+ assert create_call_args["team_alias"] == "Test Team"
+        # Budget fields should be absent; models defaults to an empty list
+ assert "max_budget" not in create_call_args
+ assert "budget_duration" not in create_call_args
+ assert create_call_args["models"] == []