test_default_team_params

This commit is contained in:
Ishaan Jaff 2025-04-10 20:04:42 -07:00
parent 2162e092a5
commit 6d66e2ebf1
3 changed files with 49 additions and 12 deletions

View file

@ -277,7 +277,7 @@ default_key_generate_params: Optional[Dict] = None
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
key_generation_settings: Optional[StandardKeyGenerationConfig] = None key_generation_settings: Optional[StandardKeyGenerationConfig] = None
default_internal_user_params: Optional[Dict] = None default_internal_user_params: Optional[Dict] = None
default_team_params: Optional[NewTeamRequest] = None default_team_params: Optional[Union[NewTeamRequest, Dict]] = None
default_team_settings: Optional[List] = None default_team_settings: Optional[List] = None
max_user_budget: Optional[float] = None max_user_budget: Optional[float] = None
default_max_internal_user_budget: Optional[float] = None default_max_internal_user_budget: Optional[float] = None

View file

@ -939,13 +939,14 @@ class SSOAuthenticationHandler:
team_id=litellm_team_id, team_id=litellm_team_id,
team_alias=litellm_team_name, team_alias=litellm_team_name,
) )
if litellm.default_team_params and isinstance( if litellm.default_team_params:
litellm.default_team_params, dict team_request = SSOAuthenticationHandler._cast_and_deepcopy_litellm_default_team_params(
): default_team_params=litellm.default_team_params,
_team_request = deepcopy(litellm.default_team_params) litellm_team_id=litellm_team_id,
_team_request["team_id"] = litellm_team_id litellm_team_name=litellm_team_name,
_team_request["team_alias"] = litellm_team_name team_request=team_request,
team_request = NewTeamRequest(**_team_request) )
await new_team( await new_team(
data=team_request, data=team_request,
# params used for Audit Logging # params used for Audit Logging
@ -958,6 +959,35 @@ class SSOAuthenticationHandler:
except Exception as e: except Exception as e:
verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}") verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}")
@staticmethod
def _cast_and_deepcopy_litellm_default_team_params(
default_team_params: Union[NewTeamRequest, Dict],
team_request: NewTeamRequest,
litellm_team_id: str,
litellm_team_name: Optional[str] = None,
) -> NewTeamRequest:
"""
Casts and deepcopies the litellm.default_team_params to a NewTeamRequest object
- Ensures we create a new NewTeamRequest object
- Handle the case where litellm.default_team_params is a dict or a NewTeamRequest object
- Adds the litellm_team_id and litellm_team_name to the NewTeamRequest object
"""
if isinstance(default_team_params, dict):
_team_request = deepcopy(default_team_params)
_team_request["team_id"] = litellm_team_id
_team_request["team_alias"] = litellm_team_name
team_request = NewTeamRequest(**_team_request)
elif isinstance(litellm.default_team_params, NewTeamRequest):
team_request = litellm.default_team_params.model_copy(
deep=True,
update={
"team_id": litellm_team_id,
"team_alias": litellm_team_name,
},
)
return team_request
class MicrosoftSSOHandler: class MicrosoftSSOHandler:
""" """

View file

@ -386,14 +386,21 @@ def test_get_group_ids_from_graph_api_response():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_default_team_params(): @pytest.mark.parametrize(
"team_params",
[
# Test case 1: Using NewTeamRequest
NewTeamRequest(max_budget=10, budget_duration="1d", models=["special-gpt-5"]),
# Test case 2: Using Dict
{"max_budget": 10, "budget_duration": "1d", "models": ["special-gpt-5"]},
],
)
async def test_default_team_params(team_params):
""" """
When litellm.default_team_params is set, it should be used to create a new team When litellm.default_team_params is set, it should be used to create a new team
""" """
# Arrange # Arrange
litellm.default_team_params = NewTeamRequest( litellm.default_team_params = team_params
max_budget=10, budget_duration="1d", models=["special-gpt-5"]
)
def mock_jsonify_team_object(db_data): def mock_jsonify_team_object(db_data):
return db_data return db_data