diff --git a/docs/my-website/docs/proxy/self_serve.md b/docs/my-website/docs/proxy/self_serve.md
index d630c8e7f3..2fc17d952e 100644
--- a/docs/my-website/docs/proxy/self_serve.md
+++ b/docs/my-website/docs/proxy/self_serve.md
@@ -207,9 +207,14 @@ This walks through setting up sso auto-add for **Microsoft Entra ID**
 
 Follow along this video for a walkthrough of how to set this up with Microsoft Entra ID
 
-
+
+
+**Next steps**
+
+1. [Set default params for new teams auto-created from SSO](#set-default-params-for-new-teams)
 
 ### Debugging SSO JWT fields
@@ -279,6 +284,26 @@ This budget does not apply to keys created under non-default teams.
 
 [**Go Here**](./team_budgets.md)
 
+### Set default params for new teams
+
+When you connect litellm to your SSO provider, litellm can auto-create teams. Use this to set the default `models`, `max_budget`, and `budget_duration` for these auto-created teams.
+
+**How it works**
+
+1. When litellm fetches `groups` from your SSO provider, it checks if the corresponding group_id exists as a `team_id` in litellm.
+2. If the team_id does not exist, litellm auto-creates a team with the default params you've set.
+3. If the team_id already exists, litellm does not apply any settings to the team.
+
+**Usage**
+
+```yaml showLineNumbers title="Default Params for new teams"
+litellm_settings:
+  default_team_params: # Default params to apply when litellm auto-creates a team from the SSO IdP provider
+    max_budget: 100 # Optional[float], optional): $100 budget for the team
+    budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team
+    models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team
+```
+
 
 ### Restrict Users from creating personal keys
@@ -290,7 +315,7 @@ This will also prevent users from using their session tokens on the test keys ch
 
 ## **All Settings for Self Serve / SSO Flow**
 
-```yaml
+```yaml showLineNumbers title="All Settings for Self Serve / SSO Flow"
 litellm_settings:
   max_internal_user_budget: 10 # max budget for internal users
   internal_user_budget_duration: "1mo" # reset every month
@@ -300,6 +325,11 @@ litellm_settings:
   max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user
   budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user
   models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by a new SSO sign in user
+
+  default_team_params: # Default params to apply when litellm auto-creates a team from the SSO IdP provider
+    max_budget: 100 # Optional[float], optional): $100 budget for the team
+    budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team
+    models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team
 
   upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on
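Note: the "How it works" steps above reduce to a lookup-then-create rule per SSO group. Below is a minimal, self-contained sketch of that rule; the dict stands in for litellm's team table and all names are illustrative, not litellm's internals (the real implementation is `SSOAuthenticationHandler.create_litellm_team_from_sso_group` in the `ui_sso.py` diff further down).

```python
from typing import Optional

TEAMS: dict = {}  # stand-in for litellm's team table (team_id -> team row)

# mirrors litellm_settings.default_team_params from the docs above
DEFAULT_TEAM_PARAMS = {
    "max_budget": 100.0,
    "budget_duration": "30d",
    "models": ["gpt-3.5-turbo"],
}


def ensure_team_for_sso_group(group_id: str, group_name: Optional[str] = None) -> None:
    # step 1: does the SSO group_id already exist as a team_id?
    if group_id in TEAMS:
        return  # step 3: team exists -> leave its settings untouched
    # step 2: auto-create the team with the configured defaults;
    # id and alias always come from the SSO group itself
    TEAMS[group_id] = {
        **DEFAULT_TEAM_PARAMS,
        "team_id": group_id,
        "team_alias": group_name,
    }


ensure_team_for_sso_group("entra-group-123", "Data Science")
assert TEAMS["entra-group-123"]["max_budget"] == 100.0
```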
diff --git a/litellm/__init__.py b/litellm/__init__.py
index e061643398..a3b37da2b4 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -65,6 +65,7 @@ from litellm.proxy._types import (
     KeyManagementSystem,
     KeyManagementSettings,
     LiteLLM_UpperboundKeyGenerateParams,
+    NewTeamRequest,
 )
 from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
 from litellm.integrations.custom_logger import CustomLogger
@@ -126,19 +127,19 @@ prometheus_initialize_budget_metrics: Optional[bool] = False
 require_auth_for_metrics_endpoint: Optional[bool] = False
 argilla_batch_size: Optional[int] = None
 datadog_use_v1: Optional[bool] = False  # if you want to use v1 datadog logged payload
-gcs_pub_sub_use_v1: Optional[
-    bool
-] = False  # if you want to use v1 gcs pubsub logged payload
+gcs_pub_sub_use_v1: Optional[bool] = (
+    False  # if you want to use v1 gcs pubsub logged payload
+)
 argilla_transformation_object: Optional[Dict[str, Any]] = None
-_async_input_callback: List[
-    Union[str, Callable, CustomLogger]
-] = []  # internal variable - async custom callbacks are routed here.
-_async_success_callback: List[
-    Union[str, Callable, CustomLogger]
-] = []  # internal variable - async custom callbacks are routed here.
-_async_failure_callback: List[
-    Union[str, Callable, CustomLogger]
-] = []  # internal variable - async custom callbacks are routed here.
+_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
+    []
+)  # internal variable - async custom callbacks are routed here.
+_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
+    []
+)  # internal variable - async custom callbacks are routed here.
+_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
+    []
+)  # internal variable - async custom callbacks are routed here.
 pre_call_rules: List[Callable] = []
 post_call_rules: List[Callable] = []
 turn_off_message_logging: Optional[bool] = False
@@ -146,18 +147,18 @@ log_raw_request_response: bool = False
 redact_messages_in_exceptions: Optional[bool] = False
 redact_user_api_key_info: Optional[bool] = False
 filter_invalid_headers: Optional[bool] = False
-add_user_information_to_llm_headers: Optional[
-    bool
-] = None  # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
+add_user_information_to_llm_headers: Optional[bool] = (
+    None  # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
+)
 store_audit_logs = False  # Enterprise feature, allow users to see audit logs
 ### end of callbacks #############
 
-email: Optional[
-    str
-] = None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-token: Optional[
-    str
-] = None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+email: Optional[str] = (
+    None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+)
+token: Optional[str] = (
+    None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+)
 telemetry = True
 max_tokens: int = DEFAULT_MAX_TOKENS  # OpenAI Defaults
 drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
@@ -233,20 +234,24 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
 enable_caching_on_provider_specific_optional_params: bool = (
     False  # feature-flag for caching on optional params - e.g. 'top_k'
 )
-caching: bool = False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-caching_with_models: bool = False  # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-cache: Optional[
-    Cache
-] = None  # cache object <- use this - https://docs.litellm.ai/docs/caching
+caching: bool = (
+    False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+)
+caching_with_models: bool = (
+    False  # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+)
+cache: Optional[Cache] = (
+    None  # cache object <- use this - https://docs.litellm.ai/docs/caching
+)
 default_in_memory_ttl: Optional[float] = None
 default_redis_ttl: Optional[float] = None
 default_redis_batch_cache_expiry: Optional[float] = None
 model_alias_map: Dict[str, str] = {}
 model_group_alias_map: Dict[str, str] = {}
 max_budget: float = 0.0  # set the max budget across all providers
-budget_duration: Optional[
-    str
-] = None  # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
+budget_duration: Optional[str] = (
+    None  # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
+)
 default_soft_budget: float = (
     DEFAULT_SOFT_BUDGET  # by default all litellm proxy keys have a soft budget of 50.0
 )
@@ -255,11 +260,15 @@ forward_traceparent_to_llm_provider: bool = False
 
 _current_cost = 0.0  # private variable, used if max budget is set
 error_logs: Dict = {}
-add_function_to_prompt: bool = False  # if function calling not supported by api, append function call details to system prompt
+add_function_to_prompt: bool = (
+    False  # if function calling not supported by api, append function call details to system prompt
+)
 client_session: Optional[httpx.Client] = None
 aclient_session: Optional[httpx.AsyncClient] = None
 model_fallbacks: Optional[List] = None  # Deprecated for 'litellm.fallbacks'
-model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
+model_cost_map_url: str = (
+    "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
+)
 suppress_debug_info = False
 dynamodb_table_name: Optional[str] = None
 s3_callback_params: Optional[Dict] = None
@@ -268,6 +277,7 @@ default_key_generate_params: Optional[Dict] = None
 upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
 key_generation_settings: Optional[StandardKeyGenerationConfig] = None
 default_internal_user_params: Optional[Dict] = None
+default_team_params: Optional[NewTeamRequest] = None
 default_team_settings: Optional[List] = None
 max_user_budget: Optional[float] = None
 default_max_internal_user_budget: Optional[float] = None
@@ -281,7 +291,9 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
 custom_prometheus_metadata_labels: List[str] = []
 #### REQUEST PRIORITIZATION ####
 priority_reservation: Optional[Dict[str, float]] = None
-force_ipv4: bool = False  # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
+force_ipv4: bool = (
+    False  # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
+)
 module_level_aclient = AsyncHTTPHandler(
     timeout=request_timeout, client_alias="module level aclient"
 )
@@ -295,13 +307,13 @@ fallbacks: Optional[List] = None
 context_window_fallbacks: Optional[List] = None
 content_policy_fallbacks: Optional[List] = None
 allowed_fails: int = 3
-num_retries_per_request: Optional[
-    int
-] = None  # for the request overall (incl. fallbacks + model retries)
+num_retries_per_request: Optional[int] = (
+    None  # for the request overall (incl. fallbacks + model retries)
+)
 ####### SECRET MANAGERS #####################
-secret_manager_client: Optional[
-    Any
-] = None  # list of instantiated key management clients - e.g. azure kv, infisical, etc.
+secret_manager_client: Optional[Any] = (
+    None  # list of instantiated key management clients - e.g. azure kv, infisical, etc.
+)
 _google_kms_resource_name: Optional[str] = None
 _key_management_system: Optional[KeyManagementSystem] = None
 _key_management_settings: KeyManagementSettings = KeyManagementSettings()
@@ -1050,10 +1062,10 @@ from .types.llms.custom_llm import CustomLLMItem
 from .types.utils import GenericStreamingChunk
 
 custom_provider_map: List[CustomLLMItem] = []
-_custom_providers: List[
-    str
-] = []  # internal helper util, used to track names of custom providers
-disable_hf_tokenizer_download: Optional[
-    bool
-] = None  # disable huggingface tokenizer download. Defaults to openai clk100
+_custom_providers: List[str] = (
+    []
+)  # internal helper util, used to track names of custom providers
+disable_hf_tokenizer_download: Optional[bool] = (
+    None  # disable huggingface tokenizer download. Defaults to openai clk100
+)
 global_disable_no_log_param: bool = False
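Note: aside from the Black-style reformatting, the functional change to `litellm/__init__.py` is the new `default_team_params` module global. As a usage sketch (mirroring what the new tests at the bottom of this diff do), the same defaults can be set programmatically instead of via the proxy yaml:

```python
import litellm
from litellm.proxy._types import NewTeamRequest

# Python-side equivalent of litellm_settings.default_team_params in config.yaml
litellm.default_team_params = NewTeamRequest(
    max_budget=100.0,          # $100 budget for each auto-created team
    budget_duration="30d",     # budget resets every 30 days
    models=["gpt-3.5-turbo"],  # models auto-created teams may use
)
```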
diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py
index 0cd3600220..1e10aebedb 100644
--- a/litellm/proxy/management_endpoints/ui_sso.py
+++ b/litellm/proxy/management_endpoints/ui_sso.py
@@ -896,6 +896,68 @@
         sso_teams = getattr(result, "team_ids", [])
         await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
 
+    @staticmethod
+    async def create_litellm_team_from_sso_group(
+        litellm_team_id: str,
+        litellm_team_name: Optional[str] = None,
+    ):
+        """
+        Creates a Litellm Team from an SSO Group ID
+
+        Your SSO provider might have groups that should be created on LiteLLM
+
+        Use this helper to create a Litellm Team from an SSO Group ID
+
+        Args:
+            litellm_team_id (str): The ID of the Litellm Team
+            litellm_team_name (Optional[str]): The name of the Litellm Team
+        """
+        from litellm.proxy.proxy_server import prisma_client
+
+        if prisma_client is None:
+            raise ProxyException(
+                message="Prisma client not found. Set it in the proxy_server.py file",
+                type=ProxyErrorTypes.auth_error,
+                param="prisma_client",
+                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            )
+        try:
+            team_obj = await prisma_client.db.litellm_teamtable.find_first(
+                where={"team_id": litellm_team_id}
+            )
+            verbose_proxy_logger.debug(f"Team object: {team_obj}")
+
+            # only create a new team if it doesn't exist
+            if team_obj:
+                verbose_proxy_logger.debug(
+                    f"Team already exists: {litellm_team_id} - {litellm_team_name}"
+                )
+                return
+
+            team_request: NewTeamRequest = NewTeamRequest(
+                team_id=litellm_team_id,
+                team_alias=litellm_team_name,
+            )
+            if litellm.default_team_params:
+                team_request = litellm.default_team_params.model_copy(
+                    deep=True,
+                    update={
+                        "team_id": litellm_team_id,
+                        "team_alias": litellm_team_name,
+                    },
+                )
+            await new_team(
+                data=team_request,
+                # params used for Audit Logging
+                http_request=Request(scope={"type": "http", "method": "POST"}),
+                user_api_key_dict=UserAPIKeyAuth(
+                    token="",
+                    key_alias=f"litellm.{MicrosoftSSOHandler.__name__}",
+                ),
+            )
+        except Exception as e:
+            verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}")
+
 
 class MicrosoftSSOHandler:
     """
@@ -1176,15 +1238,6 @@
         When a user sets a `SERVICE_PRINCIPAL_ID` in the env, litellm will fetch groups under that service principal and create Litellm Teams from them
         """
-        from litellm.proxy.proxy_server import prisma_client
-
-        if prisma_client is None:
-            raise ProxyException(
-                message="Prisma client not found. Set it in the proxy_server.py file",
-                type=ProxyErrorTypes.auth_error,
-                param="prisma_client",
-                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            )
         verbose_proxy_logger.debug(
             f"Creating Litellm Teams from Service Principal Teams: {service_principal_teams}"
         )
@@ -1199,36 +1252,10 @@
                 )
                 continue
 
-            try:
-                verbose_proxy_logger.debug(
-                    f"Creating Litellm Team: {litellm_team_id} - {litellm_team_name}"
-                )
-
-                team_obj = await prisma_client.db.litellm_teamtable.find_first(
-                    where={"team_id": litellm_team_id}
-                )
-                verbose_proxy_logger.debug(f"Team object: {team_obj}")
-
-                # only create a new team if it doesn't exist
-                if team_obj:
-                    verbose_proxy_logger.debug(
-                        f"Team already exists: {litellm_team_id} - {litellm_team_name}"
-                    )
-                    continue
-
-                await new_team(
-                    data=NewTeamRequest(
-                        team_id=litellm_team_id,
-                        team_alias=litellm_team_name,
-                    ),
-                    # params used for Audit Logging
-                    http_request=Request(scope={"type": "http", "method": "POST"}),
-                    user_api_key_dict=UserAPIKeyAuth(
-                        token="",
-                        key_alias=f"litellm.{MicrosoftSSOHandler.__name__}",
-                    ),
-                )
-            except Exception as e:
-                verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}")
+            await SSOAuthenticationHandler.create_litellm_team_from_sso_group(
+                litellm_team_id=litellm_team_id,
+                litellm_team_name=litellm_team_name,
+            )
 
 
 class GoogleSSOHandler:
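Note: `model_copy(deep=True, update=...)` in the new helper is pydantic v2's copy-with-overrides, so the SSO group's id and alias always win over whatever is stored in `default_team_params`, while fields not listed in `update` (budget, models) are inherited from the defaults. A small demonstration with illustrative values:

```python
from litellm.proxy._types import NewTeamRequest

defaults = NewTeamRequest(max_budget=100.0, models=["gpt-3.5-turbo"])

team_request = defaults.model_copy(
    deep=True,  # deep-copy mutable fields like `models` so teams don't share state
    update={"team_id": "entra-group-123", "team_alias": "Data Science"},
)

assert team_request.max_budget == 100.0           # inherited from the defaults
assert team_request.team_id == "entra-group-123"  # overridden per SSO group
```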
diff --git a/tests/litellm/proxy/management_endpoints/test_ui_sso.py b/tests/litellm/proxy/management_endpoints/test_ui_sso.py
index 606f3833be..ff9700393f 100644
--- a/tests/litellm/proxy/management_endpoints/test_ui_sso.py
+++ b/tests/litellm/proxy/management_endpoints/test_ui_sso.py
@@ -2,8 +2,9 @@ import asyncio
 import json
 import os
 import sys
+import uuid
 from typing import Optional, cast
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 from fastapi import Request
@@ -13,6 +14,8 @@ sys.path.insert(
     0, os.path.abspath("../../../")
 )  # Adds the parent directory to the system path
 
+import litellm
+from litellm.proxy._types import NewTeamRequest
 from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.management_endpoints.types import CustomOpenID
 from litellm.proxy.management_endpoints.ui_sso import (
@@ -22,6 +25,7 @@ from litellm.proxy.management_endpoints.ui_sso import (
 from litellm.types.proxy.management_endpoints.ui_sso import (
     MicrosoftGraphAPIUserGroupDirectoryObject,
     MicrosoftGraphAPIUserGroupResponse,
+    MicrosoftServicePrincipalTeam,
 )
 
 
@@ -379,3 +383,96 @@ def test_get_group_ids_from_graph_api_response():
     assert len(result) == 2
     assert "group1" in result
     assert "group2" in result
+
+
+@pytest.mark.asyncio
+async def test_default_team_params():
+    """
+    When litellm.default_team_params is set, it should be used to create a new team
+    """
+    # Arrange
+    litellm.default_team_params = NewTeamRequest(
+        max_budget=10, budget_duration="1d", models=["special-gpt-5"]
+    )
+
+    def mock_jsonify_team_object(db_data):
+        return db_data
+
+    # Mock Prisma client
+    mock_prisma = MagicMock()
+    mock_prisma.db.litellm_teamtable.find_first = AsyncMock(return_value=None)
+    mock_prisma.db.litellm_teamtable.create = AsyncMock()
+    mock_prisma.get_data = AsyncMock(return_value=None)
+    mock_prisma.jsonify_team_object = MagicMock(side_effect=mock_jsonify_team_object)
+
+    with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
+        # Act
+        team_id = str(uuid.uuid4())
+        await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids(
+            service_principal_teams=[
+                MicrosoftServicePrincipalTeam(
+                    principalId=team_id,
+                    principalDisplayName="Test Team",
+                )
+            ]
+        )
+
+        # Assert
+        # Verify team was created with correct parameters
+        mock_prisma.db.litellm_teamtable.create.assert_called_once()
+        print(
+            "mock_prisma.db.litellm_teamtable.create.call_args",
+            mock_prisma.db.litellm_teamtable.create.call_args,
+        )
+        create_call_args = mock_prisma.db.litellm_teamtable.create.call_args.kwargs[
+            "data"
+        ]
+        assert create_call_args["team_id"] == team_id
+        assert create_call_args["team_alias"] == "Test Team"
+        assert create_call_args["max_budget"] == 10
+        assert create_call_args["budget_duration"] == "1d"
+        assert create_call_args["models"] == ["special-gpt-5"]
+
+
+@pytest.mark.asyncio
+async def test_create_team_without_default_params():
+    """
+    Test team creation when litellm.default_team_params is None
+    Should create team with just the basic required fields
+    """
+    # Arrange
+    litellm.default_team_params = None
+
+    def mock_jsonify_team_object(db_data):
+        return db_data
+
+    # Mock Prisma client
+    mock_prisma = MagicMock()
+    mock_prisma.db.litellm_teamtable.find_first = AsyncMock(return_value=None)
+    mock_prisma.db.litellm_teamtable.create = AsyncMock()
+    mock_prisma.get_data = AsyncMock(return_value=None)
+    mock_prisma.jsonify_team_object = MagicMock(side_effect=mock_jsonify_team_object)
+
+    with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
+        # Act
+        team_id = str(uuid.uuid4())
+        await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids(
+            service_principal_teams=[
+                MicrosoftServicePrincipalTeam(
+                    principalId=team_id,
+                    principalDisplayName="Test Team",
+                )
+            ]
+        )
+
+        # Assert
+        mock_prisma.db.litellm_teamtable.create.assert_called_once()
+        create_call_args = mock_prisma.db.litellm_teamtable.create.call_args.kwargs[
+            "data"
+        ]
+        assert create_call_args["team_id"] == team_id
+        assert create_call_args["team_alias"] == "Test Team"
+        # Should not have any of the optional fields
+        assert "max_budget" not in create_call_args
+        assert "budget_duration" not in create_call_args
+        assert create_call_args["models"] == []