mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
[Feat SSO] - Allow admins to set default_team_params
to have default params for when litellm SSO creates default teams (#9895)
* add default_team_params as a config.yaml setting * create_litellm_team_from_sso_group * test_default_team_params * test_create_team_without_default_params * docs default team settings
This commit is contained in:
parent
7c007c0be3
commit
557a2ca102
4 changed files with 252 additions and 86 deletions
|
@ -207,9 +207,14 @@ This walks through setting up sso auto-add for **Microsoft Entra ID**
|
||||||
|
|
||||||
Follow along this video for a walkthrough of how to set this up with Microsoft Entra ID
|
Follow along this video for a walkthrough of how to set this up with Microsoft Entra ID
|
||||||
|
|
||||||
|
|
||||||
<iframe width="840" height="500" src="https://www.loom.com/embed/ea711323aa9a496d84a01fd7b2a12f54?sid=c53e238c-5bfd-4135-b8fb-b5b1a08632cf" frameborder="0" webkitallowfullscreen mozallowfullscreen allowfullscreen></iframe>
|
<iframe width="840" height="500" src="https://www.loom.com/embed/ea711323aa9a496d84a01fd7b2a12f54?sid=c53e238c-5bfd-4135-b8fb-b5b1a08632cf" frameborder="0" webkitallowfullscreen mozallowfullscreen allowfullscreen></iframe>
|
||||||
|
|
||||||
|
<br />
|
||||||
|
<br />
|
||||||
|
|
||||||
|
**Next steps**
|
||||||
|
|
||||||
|
1. [Set default params for new teams auto-created from SSO](#set-default-params-for-new-teams)
|
||||||
|
|
||||||
### Debugging SSO JWT fields
|
### Debugging SSO JWT fields
|
||||||
|
|
||||||
|
@ -279,6 +284,26 @@ This budget does not apply to keys created under non-default teams.
|
||||||
|
|
||||||
[**Go Here**](./team_budgets.md)
|
[**Go Here**](./team_budgets.md)
|
||||||
|
|
||||||
|
### Set default params for new teams
|
||||||
|
|
||||||
|
When you connect litellm to your SSO provider, litellm can auto-create teams. Use this to set the default `models`, `max_budget`, `budget_duration` for these auto-created teams.
|
||||||
|
|
||||||
|
**How it works**
|
||||||
|
|
||||||
|
1. When litellm fetches `groups` from your SSO provider, it will check if the corresponding group_id exists as a `team_id` in litellm.
|
||||||
|
2. If the team_id does not exist, litellm will auto-create a team with the default params you've set.
|
||||||
|
3. If the team_id already exist, litellm will not apply any settings on the team.
|
||||||
|
|
||||||
|
**Usage**
|
||||||
|
|
||||||
|
```yaml showLineNumbers title="Default Params for new teams"
|
||||||
|
litellm_settings:
|
||||||
|
default_team_params: # Default Params to apply when litellm auto creates a team from SSO IDP provider
|
||||||
|
max_budget: 100 # Optional[float], optional): $100 budget for the team
|
||||||
|
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team
|
||||||
|
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
### Restrict Users from creating personal keys
|
### Restrict Users from creating personal keys
|
||||||
|
|
||||||
|
@ -290,7 +315,7 @@ This will also prevent users from using their session tokens on the test keys ch
|
||||||
|
|
||||||
## **All Settings for Self Serve / SSO Flow**
|
## **All Settings for Self Serve / SSO Flow**
|
||||||
|
|
||||||
```yaml
|
```yaml showLineNumbers title="All Settings for Self Serve / SSO Flow"
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
max_internal_user_budget: 10 # max budget for internal users
|
max_internal_user_budget: 10 # max budget for internal users
|
||||||
internal_user_budget_duration: "1mo" # reset every month
|
internal_user_budget_duration: "1mo" # reset every month
|
||||||
|
@ -300,6 +325,11 @@ litellm_settings:
|
||||||
max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user
|
max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user
|
||||||
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user
|
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user
|
||||||
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by a new SSO sign in user
|
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by a new SSO sign in user
|
||||||
|
|
||||||
|
default_team_params: # Default Params to apply when litellm auto creates a team from SSO IDP provider
|
||||||
|
max_budget: 100 # Optional[float], optional): $100 budget for the team
|
||||||
|
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team
|
||||||
|
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team
|
||||||
|
|
||||||
|
|
||||||
upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on
|
upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on
|
||||||
|
|
|
@ -65,6 +65,7 @@ from litellm.proxy._types import (
|
||||||
KeyManagementSystem,
|
KeyManagementSystem,
|
||||||
KeyManagementSettings,
|
KeyManagementSettings,
|
||||||
LiteLLM_UpperboundKeyGenerateParams,
|
LiteLLM_UpperboundKeyGenerateParams,
|
||||||
|
NewTeamRequest,
|
||||||
)
|
)
|
||||||
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
|
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
@ -126,19 +127,19 @@ prometheus_initialize_budget_metrics: Optional[bool] = False
|
||||||
require_auth_for_metrics_endpoint: Optional[bool] = False
|
require_auth_for_metrics_endpoint: Optional[bool] = False
|
||||||
argilla_batch_size: Optional[int] = None
|
argilla_batch_size: Optional[int] = None
|
||||||
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
|
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
|
||||||
gcs_pub_sub_use_v1: Optional[
|
gcs_pub_sub_use_v1: Optional[bool] = (
|
||||||
bool
|
False # if you want to use v1 gcs pubsub logged payload
|
||||||
] = False # if you want to use v1 gcs pubsub logged payload
|
)
|
||||||
argilla_transformation_object: Optional[Dict[str, Any]] = None
|
argilla_transformation_object: Optional[Dict[str, Any]] = None
|
||||||
_async_input_callback: List[
|
_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
|
||||||
Union[str, Callable, CustomLogger]
|
[]
|
||||||
] = [] # internal variable - async custom callbacks are routed here.
|
) # internal variable - async custom callbacks are routed here.
|
||||||
_async_success_callback: List[
|
_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
|
||||||
Union[str, Callable, CustomLogger]
|
[]
|
||||||
] = [] # internal variable - async custom callbacks are routed here.
|
) # internal variable - async custom callbacks are routed here.
|
||||||
_async_failure_callback: List[
|
_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
|
||||||
Union[str, Callable, CustomLogger]
|
[]
|
||||||
] = [] # internal variable - async custom callbacks are routed here.
|
) # internal variable - async custom callbacks are routed here.
|
||||||
pre_call_rules: List[Callable] = []
|
pre_call_rules: List[Callable] = []
|
||||||
post_call_rules: List[Callable] = []
|
post_call_rules: List[Callable] = []
|
||||||
turn_off_message_logging: Optional[bool] = False
|
turn_off_message_logging: Optional[bool] = False
|
||||||
|
@ -146,18 +147,18 @@ log_raw_request_response: bool = False
|
||||||
redact_messages_in_exceptions: Optional[bool] = False
|
redact_messages_in_exceptions: Optional[bool] = False
|
||||||
redact_user_api_key_info: Optional[bool] = False
|
redact_user_api_key_info: Optional[bool] = False
|
||||||
filter_invalid_headers: Optional[bool] = False
|
filter_invalid_headers: Optional[bool] = False
|
||||||
add_user_information_to_llm_headers: Optional[
|
add_user_information_to_llm_headers: Optional[bool] = (
|
||||||
bool
|
None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
|
||||||
] = None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
|
)
|
||||||
store_audit_logs = False # Enterprise feature, allow users to see audit logs
|
store_audit_logs = False # Enterprise feature, allow users to see audit logs
|
||||||
### end of callbacks #############
|
### end of callbacks #############
|
||||||
|
|
||||||
email: Optional[
|
email: Optional[str] = (
|
||||||
str
|
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||||
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
)
|
||||||
token: Optional[
|
token: Optional[str] = (
|
||||||
str
|
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||||
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
)
|
||||||
telemetry = True
|
telemetry = True
|
||||||
max_tokens: int = DEFAULT_MAX_TOKENS # OpenAI Defaults
|
max_tokens: int = DEFAULT_MAX_TOKENS # OpenAI Defaults
|
||||||
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
|
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
|
||||||
|
@ -233,20 +234,24 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
|
||||||
enable_caching_on_provider_specific_optional_params: bool = (
|
enable_caching_on_provider_specific_optional_params: bool = (
|
||||||
False # feature-flag for caching on optional params - e.g. 'top_k'
|
False # feature-flag for caching on optional params - e.g. 'top_k'
|
||||||
)
|
)
|
||||||
caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
caching: bool = (
|
||||||
caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||||
cache: Optional[
|
)
|
||||||
Cache
|
caching_with_models: bool = (
|
||||||
] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
|
False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||||
|
)
|
||||||
|
cache: Optional[Cache] = (
|
||||||
|
None # cache object <- use this - https://docs.litellm.ai/docs/caching
|
||||||
|
)
|
||||||
default_in_memory_ttl: Optional[float] = None
|
default_in_memory_ttl: Optional[float] = None
|
||||||
default_redis_ttl: Optional[float] = None
|
default_redis_ttl: Optional[float] = None
|
||||||
default_redis_batch_cache_expiry: Optional[float] = None
|
default_redis_batch_cache_expiry: Optional[float] = None
|
||||||
model_alias_map: Dict[str, str] = {}
|
model_alias_map: Dict[str, str] = {}
|
||||||
model_group_alias_map: Dict[str, str] = {}
|
model_group_alias_map: Dict[str, str] = {}
|
||||||
max_budget: float = 0.0 # set the max budget across all providers
|
max_budget: float = 0.0 # set the max budget across all providers
|
||||||
budget_duration: Optional[
|
budget_duration: Optional[str] = (
|
||||||
str
|
None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
|
||||||
] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
|
)
|
||||||
default_soft_budget: float = (
|
default_soft_budget: float = (
|
||||||
DEFAULT_SOFT_BUDGET # by default all litellm proxy keys have a soft budget of 50.0
|
DEFAULT_SOFT_BUDGET # by default all litellm proxy keys have a soft budget of 50.0
|
||||||
)
|
)
|
||||||
|
@ -255,11 +260,15 @@ forward_traceparent_to_llm_provider: bool = False
|
||||||
|
|
||||||
_current_cost = 0.0 # private variable, used if max budget is set
|
_current_cost = 0.0 # private variable, used if max budget is set
|
||||||
error_logs: Dict = {}
|
error_logs: Dict = {}
|
||||||
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
|
add_function_to_prompt: bool = (
|
||||||
|
False # if function calling not supported by api, append function call details to system prompt
|
||||||
|
)
|
||||||
client_session: Optional[httpx.Client] = None
|
client_session: Optional[httpx.Client] = None
|
||||||
aclient_session: Optional[httpx.AsyncClient] = None
|
aclient_session: Optional[httpx.AsyncClient] = None
|
||||||
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
|
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
|
||||||
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
model_cost_map_url: str = (
|
||||||
|
"https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||||
|
)
|
||||||
suppress_debug_info = False
|
suppress_debug_info = False
|
||||||
dynamodb_table_name: Optional[str] = None
|
dynamodb_table_name: Optional[str] = None
|
||||||
s3_callback_params: Optional[Dict] = None
|
s3_callback_params: Optional[Dict] = None
|
||||||
|
@ -268,6 +277,7 @@ default_key_generate_params: Optional[Dict] = None
|
||||||
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
|
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
|
||||||
key_generation_settings: Optional[StandardKeyGenerationConfig] = None
|
key_generation_settings: Optional[StandardKeyGenerationConfig] = None
|
||||||
default_internal_user_params: Optional[Dict] = None
|
default_internal_user_params: Optional[Dict] = None
|
||||||
|
default_team_params: Optional[NewTeamRequest] = None
|
||||||
default_team_settings: Optional[List] = None
|
default_team_settings: Optional[List] = None
|
||||||
max_user_budget: Optional[float] = None
|
max_user_budget: Optional[float] = None
|
||||||
default_max_internal_user_budget: Optional[float] = None
|
default_max_internal_user_budget: Optional[float] = None
|
||||||
|
@ -281,7 +291,9 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
|
||||||
custom_prometheus_metadata_labels: List[str] = []
|
custom_prometheus_metadata_labels: List[str] = []
|
||||||
#### REQUEST PRIORITIZATION ####
|
#### REQUEST PRIORITIZATION ####
|
||||||
priority_reservation: Optional[Dict[str, float]] = None
|
priority_reservation: Optional[Dict[str, float]] = None
|
||||||
force_ipv4: bool = False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
|
force_ipv4: bool = (
|
||||||
|
False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
|
||||||
|
)
|
||||||
module_level_aclient = AsyncHTTPHandler(
|
module_level_aclient = AsyncHTTPHandler(
|
||||||
timeout=request_timeout, client_alias="module level aclient"
|
timeout=request_timeout, client_alias="module level aclient"
|
||||||
)
|
)
|
||||||
|
@ -295,13 +307,13 @@ fallbacks: Optional[List] = None
|
||||||
context_window_fallbacks: Optional[List] = None
|
context_window_fallbacks: Optional[List] = None
|
||||||
content_policy_fallbacks: Optional[List] = None
|
content_policy_fallbacks: Optional[List] = None
|
||||||
allowed_fails: int = 3
|
allowed_fails: int = 3
|
||||||
num_retries_per_request: Optional[
|
num_retries_per_request: Optional[int] = (
|
||||||
int
|
None # for the request overall (incl. fallbacks + model retries)
|
||||||
] = None # for the request overall (incl. fallbacks + model retries)
|
)
|
||||||
####### SECRET MANAGERS #####################
|
####### SECRET MANAGERS #####################
|
||||||
secret_manager_client: Optional[
|
secret_manager_client: Optional[Any] = (
|
||||||
Any
|
None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
|
||||||
] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
|
)
|
||||||
_google_kms_resource_name: Optional[str] = None
|
_google_kms_resource_name: Optional[str] = None
|
||||||
_key_management_system: Optional[KeyManagementSystem] = None
|
_key_management_system: Optional[KeyManagementSystem] = None
|
||||||
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
|
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
|
||||||
|
@ -1050,10 +1062,10 @@ from .types.llms.custom_llm import CustomLLMItem
|
||||||
from .types.utils import GenericStreamingChunk
|
from .types.utils import GenericStreamingChunk
|
||||||
|
|
||||||
custom_provider_map: List[CustomLLMItem] = []
|
custom_provider_map: List[CustomLLMItem] = []
|
||||||
_custom_providers: List[
|
_custom_providers: List[str] = (
|
||||||
str
|
[]
|
||||||
] = [] # internal helper util, used to track names of custom providers
|
) # internal helper util, used to track names of custom providers
|
||||||
disable_hf_tokenizer_download: Optional[
|
disable_hf_tokenizer_download: Optional[bool] = (
|
||||||
bool
|
None # disable huggingface tokenizer download. Defaults to openai clk100
|
||||||
] = None # disable huggingface tokenizer download. Defaults to openai clk100
|
)
|
||||||
global_disable_no_log_param: bool = False
|
global_disable_no_log_param: bool = False
|
||||||
|
|
|
@ -896,6 +896,68 @@ class SSOAuthenticationHandler:
|
||||||
sso_teams = getattr(result, "team_ids", [])
|
sso_teams = getattr(result, "team_ids", [])
|
||||||
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
|
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def create_litellm_team_from_sso_group(
|
||||||
|
litellm_team_id: str,
|
||||||
|
litellm_team_name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Creates a Litellm Team from a SSO Group ID
|
||||||
|
|
||||||
|
Your SSO provider might have groups that should be created on LiteLLM
|
||||||
|
|
||||||
|
Use this helper to create a Litellm Team from a SSO Group ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
litellm_team_id (str): The ID of the Litellm Team
|
||||||
|
litellm_team_name (Optional[str]): The name of the Litellm Team
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import prisma_client
|
||||||
|
|
||||||
|
if prisma_client is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="Prisma client not found. Set it in the proxy_server.py file",
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param="prisma_client",
|
||||||
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
team_obj = await prisma_client.db.litellm_teamtable.find_first(
|
||||||
|
where={"team_id": litellm_team_id}
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(f"Team object: {team_obj}")
|
||||||
|
|
||||||
|
# only create a new team if it doesn't exist
|
||||||
|
if team_obj:
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"Team already exists: {litellm_team_id} - {litellm_team_name}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
team_request: NewTeamRequest = NewTeamRequest(
|
||||||
|
team_id=litellm_team_id,
|
||||||
|
team_alias=litellm_team_name,
|
||||||
|
)
|
||||||
|
if litellm.default_team_params:
|
||||||
|
team_request = litellm.default_team_params.model_copy(
|
||||||
|
deep=True,
|
||||||
|
update={
|
||||||
|
"team_id": litellm_team_id,
|
||||||
|
"team_alias": litellm_team_name,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await new_team(
|
||||||
|
data=team_request,
|
||||||
|
# params used for Audit Logging
|
||||||
|
http_request=Request(scope={"type": "http", "method": "POST"}),
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
token="",
|
||||||
|
key_alias=f"litellm.{MicrosoftSSOHandler.__name__}",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}")
|
||||||
|
|
||||||
|
|
||||||
class MicrosoftSSOHandler:
|
class MicrosoftSSOHandler:
|
||||||
"""
|
"""
|
||||||
|
@ -1176,15 +1238,6 @@ class MicrosoftSSOHandler:
|
||||||
|
|
||||||
When a user sets a `SERVICE_PRINCIPAL_ID` in the env, litellm will fetch groups under that service principal and create Litellm Teams from them
|
When a user sets a `SERVICE_PRINCIPAL_ID` in the env, litellm will fetch groups under that service principal and create Litellm Teams from them
|
||||||
"""
|
"""
|
||||||
from litellm.proxy.proxy_server import prisma_client
|
|
||||||
|
|
||||||
if prisma_client is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="Prisma client not found. Set it in the proxy_server.py file",
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param="prisma_client",
|
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"Creating Litellm Teams from Service Principal Teams: {service_principal_teams}"
|
f"Creating Litellm Teams from Service Principal Teams: {service_principal_teams}"
|
||||||
)
|
)
|
||||||
|
@ -1199,36 +1252,10 @@ class MicrosoftSSOHandler:
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
await SSOAuthenticationHandler.create_litellm_team_from_sso_group(
|
||||||
verbose_proxy_logger.debug(
|
litellm_team_id=litellm_team_id,
|
||||||
f"Creating Litellm Team: {litellm_team_id} - {litellm_team_name}"
|
litellm_team_name=litellm_team_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
team_obj = await prisma_client.db.litellm_teamtable.find_first(
|
|
||||||
where={"team_id": litellm_team_id}
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.debug(f"Team object: {team_obj}")
|
|
||||||
|
|
||||||
# only create a new team if it doesn't exist
|
|
||||||
if team_obj:
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"Team already exists: {litellm_team_id} - {litellm_team_name}"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
await new_team(
|
|
||||||
data=NewTeamRequest(
|
|
||||||
team_id=litellm_team_id,
|
|
||||||
team_alias=litellm_team_name,
|
|
||||||
),
|
|
||||||
# params used for Audit Logging
|
|
||||||
http_request=Request(scope={"type": "http", "method": "POST"}),
|
|
||||||
user_api_key_dict=UserAPIKeyAuth(
|
|
||||||
token="",
|
|
||||||
key_alias=f"litellm.{MicrosoftSSOHandler.__name__}",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleSSOHandler:
|
class GoogleSSOHandler:
|
||||||
|
|
|
@ -2,8 +2,9 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import uuid
|
||||||
from typing import Optional, cast
|
from typing import Optional, cast
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
@ -13,6 +14,8 @@ sys.path.insert(
|
||||||
0, os.path.abspath("../../../")
|
0, os.path.abspath("../../../")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.proxy._types import NewTeamRequest
|
||||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||||
from litellm.proxy.management_endpoints.types import CustomOpenID
|
from litellm.proxy.management_endpoints.types import CustomOpenID
|
||||||
from litellm.proxy.management_endpoints.ui_sso import (
|
from litellm.proxy.management_endpoints.ui_sso import (
|
||||||
|
@ -22,6 +25,7 @@ from litellm.proxy.management_endpoints.ui_sso import (
|
||||||
from litellm.types.proxy.management_endpoints.ui_sso import (
|
from litellm.types.proxy.management_endpoints.ui_sso import (
|
||||||
MicrosoftGraphAPIUserGroupDirectoryObject,
|
MicrosoftGraphAPIUserGroupDirectoryObject,
|
||||||
MicrosoftGraphAPIUserGroupResponse,
|
MicrosoftGraphAPIUserGroupResponse,
|
||||||
|
MicrosoftServicePrincipalTeam,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -379,3 +383,96 @@ def test_get_group_ids_from_graph_api_response():
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert "group1" in result
|
assert "group1" in result
|
||||||
assert "group2" in result
|
assert "group2" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_default_team_params():
|
||||||
|
"""
|
||||||
|
When litellm.default_team_params is set, it should be used to create a new team
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
litellm.default_team_params = NewTeamRequest(
|
||||||
|
max_budget=10, budget_duration="1d", models=["special-gpt-5"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_jsonify_team_object(db_data):
|
||||||
|
return db_data
|
||||||
|
|
||||||
|
# Mock Prisma client
|
||||||
|
mock_prisma = MagicMock()
|
||||||
|
mock_prisma.db.litellm_teamtable.find_first = AsyncMock(return_value=None)
|
||||||
|
mock_prisma.db.litellm_teamtable.create = AsyncMock()
|
||||||
|
mock_prisma.get_data = AsyncMock(return_value=None)
|
||||||
|
mock_prisma.jsonify_team_object = MagicMock(side_effect=mock_jsonify_team_object)
|
||||||
|
|
||||||
|
with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
|
||||||
|
# Act
|
||||||
|
team_id = str(uuid.uuid4())
|
||||||
|
await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids(
|
||||||
|
service_principal_teams=[
|
||||||
|
MicrosoftServicePrincipalTeam(
|
||||||
|
principalId=team_id,
|
||||||
|
principalDisplayName="Test Team",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# Verify team was created with correct parameters
|
||||||
|
mock_prisma.db.litellm_teamtable.create.assert_called_once()
|
||||||
|
print(
|
||||||
|
"mock_prisma.db.litellm_teamtable.create.call_args",
|
||||||
|
mock_prisma.db.litellm_teamtable.create.call_args,
|
||||||
|
)
|
||||||
|
create_call_args = mock_prisma.db.litellm_teamtable.create.call_args.kwargs[
|
||||||
|
"data"
|
||||||
|
]
|
||||||
|
assert create_call_args["team_id"] == team_id
|
||||||
|
assert create_call_args["team_alias"] == "Test Team"
|
||||||
|
assert create_call_args["max_budget"] == 10
|
||||||
|
assert create_call_args["budget_duration"] == "1d"
|
||||||
|
assert create_call_args["models"] == ["special-gpt-5"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_team_without_default_params():
|
||||||
|
"""
|
||||||
|
Test team creation when litellm.default_team_params is None
|
||||||
|
Should create team with just the basic required fields
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
litellm.default_team_params = None
|
||||||
|
|
||||||
|
def mock_jsonify_team_object(db_data):
|
||||||
|
return db_data
|
||||||
|
|
||||||
|
# Mock Prisma client
|
||||||
|
mock_prisma = MagicMock()
|
||||||
|
mock_prisma.db.litellm_teamtable.find_first = AsyncMock(return_value=None)
|
||||||
|
mock_prisma.db.litellm_teamtable.create = AsyncMock()
|
||||||
|
mock_prisma.get_data = AsyncMock(return_value=None)
|
||||||
|
mock_prisma.jsonify_team_object = MagicMock(side_effect=mock_jsonify_team_object)
|
||||||
|
|
||||||
|
with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
|
||||||
|
# Act
|
||||||
|
team_id = str(uuid.uuid4())
|
||||||
|
await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids(
|
||||||
|
service_principal_teams=[
|
||||||
|
MicrosoftServicePrincipalTeam(
|
||||||
|
principalId=team_id,
|
||||||
|
principalDisplayName="Test Team",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_prisma.db.litellm_teamtable.create.assert_called_once()
|
||||||
|
create_call_args = mock_prisma.db.litellm_teamtable.create.call_args.kwargs[
|
||||||
|
"data"
|
||||||
|
]
|
||||||
|
assert create_call_args["team_id"] == team_id
|
||||||
|
assert create_call_args["team_alias"] == "Test Team"
|
||||||
|
# Should not have any of the optional fields
|
||||||
|
assert "max_budget" not in create_call_args
|
||||||
|
assert "budget_duration" not in create_call_args
|
||||||
|
assert create_call_args["models"] == []
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue