diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index 2c19bc25b..43e458465 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -420,7 +420,6 @@ async def update_team( @management_endpoint_wrapper async def team_member_add( data: TeamMemberAddRequest, - http_request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ @@ -442,8 +441,11 @@ async def team_member_add( from litellm.proxy.proxy_server import ( _duration_in_seconds, create_audit_log_for_update, + get_team_object, litellm_proxy_admin_name, prisma_client, + proxy_logging_obj, + user_api_key_cache, ) if prisma_client is None: @@ -457,8 +459,13 @@ async def team_member_add( status_code=400, detail={"error": "No member/members passed in"} ) - existing_team_row = await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": data.team_id} + existing_team_row = await get_team_object( + team_id=data.team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=proxy_logging_obj, + check_cache_only=False, ) if existing_team_row is None: raise HTTPException( diff --git a/litellm/proxy/management_helpers/utils.py b/litellm/proxy/management_helpers/utils.py index 5c91364de..a897dfd5f 100644 --- a/litellm/proxy/management_helpers/utils.py +++ b/litellm/proxy/management_helpers/utils.py @@ -7,6 +7,7 @@ from typing import Optional from fastapi import Request +import litellm from litellm._logging import verbose_logger from litellm.proxy._types import ( # key request types; user request types; team request types; customer request types DeleteCustomerRequest, @@ -16,6 +17,7 @@ from litellm.proxy._types import ( # key request types; user request types; tea LiteLLM_TeamTable, ManagementEndpointLoggingPayload, Member, + SSOUserDefinedValues, UpdateCustomerRequest, UpdateKeyRequest, UpdateTeamRequest, @@ -26,6 +28,25 @@ from litellm.proxy.common_utils.http_parsing_utils import _read_request_body from litellm.proxy.utils import PrismaClient +def get_new_internal_user_defaults( + user_id: str, user_email: Optional[str] = None +) -> SSOUserDefinedValues: + user_info = litellm.default_user_params or {} + + returned_dict: SSOUserDefinedValues = { + "models": user_info.get("models", None), + "max_budget": user_info.get("max_budget", litellm.max_internal_user_budget), + "budget_duration": user_info.get( + "budget_duration", litellm.internal_user_budget_duration + ), + "user_email": user_email or user_info.get("user_email", None), + "user_id": user_id, + "user_role": "internal_user", + } + + return returned_dict + + async def add_new_member( new_member: Member, max_budget_in_team: Optional[float], @@ -42,15 +63,18 @@ async def add_new_member( """ ## ADD TEAM ID, to USER TABLE IF NEW ## if new_member.user_id is not None: + new_user_defaults = get_new_internal_user_defaults(user_id=new_member.user_id) await prisma_client.db.litellm_usertable.upsert( where={"user_id": new_member.user_id}, data={ "update": {"teams": {"push": [team_id]}}, - "create": {"user_id": new_member.user_id, "teams": [team_id]}, + "create": {"teams": [team_id], **new_user_defaults}, }, ) elif new_member.user_email is not None: - user_data = {"user_id": str(uuid.uuid4()), "user_email": new_member.user_email} + new_user_defaults = get_new_internal_user_defaults( + user_id=str(uuid.uuid4()), user_email=new_member.user_email + ) ## user email is not unique acc. to prisma schema -> future improvement ### for now: check if it exists in db, if not - insert it existing_user_row = await prisma_client.get_data( @@ -62,7 +86,7 @@ async def add_new_member( isinstance(existing_user_row, list) and len(existing_user_row) == 0 ): - await prisma_client.insert_data(data=user_data, table_name="user") + await prisma_client.insert_data(data=new_user_defaults, table_name="user") # type: ignore # Check if trying to set a budget for team member if max_budget_in_team is not None and new_member.user_id is not None: diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index c910e786c..ffc2600ba 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -804,10 +804,16 @@ async def test_get_team_redis(client_no_auth): import random import uuid -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch -from litellm.proxy._types import LitellmUserRoles, NewUserRequest, UserAPIKeyAuth +from litellm.proxy._types import ( + LitellmUserRoles, + NewUserRequest, + TeamMemberAddRequest, + UserAPIKeyAuth, +) from litellm.proxy.management_endpoints.internal_user_endpoints import new_user +from litellm.proxy.management_endpoints.team_endpoints import team_member_add from litellm.tests.test_key_generate_prisma import prisma_client @@ -852,3 +858,67 @@ async def test_create_user_default_budget(prisma_client, user_role): else: assert mock_client.call_args.kwargs["data"]["max_budget"] is None assert mock_client.call_args.kwargs["data"]["budget_duration"] is None + + +@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"]) +@pytest.mark.asyncio +async def test_create_team_member_add(prisma_client, new_member_method): + import time + + from litellm.proxy._types import LiteLLM_TeamTableCachedObj + from litellm.proxy.proxy_server import hash_token, user_api_key_cache + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr(litellm, "max_internal_user_budget", 10) + setattr(litellm, "internal_user_budget_duration", "5m") + await litellm.proxy.proxy_server.prisma_client.connect() + user = f"ishaan {uuid.uuid4().hex}" + _team_id = "litellm-test-client-id-new" + team_obj = LiteLLM_TeamTableCachedObj( + team_id=_team_id, + blocked=False, + last_refreshed_at=time.time(), + metadata={"guardrails": {"modify_guardrails": False}}, + ) + # user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token) + user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj) + + setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache) + if new_member_method == "user_id": + data = { + "team_id": _team_id, + "member": [{"role": "user", "user_id": user}], + } + elif new_member_method == "user_email": + data = { + "team_id": _team_id, + "member": [{"role": "user", "user_email": user}], + } + team_member_add_request = TeamMemberAddRequest(**data) + + with patch( + "litellm.proxy.proxy_server.prisma_client.db.litellm_usertable", + new_callable=AsyncMock, + ) as mock_litellm_usertable: + mock_client = AsyncMock() + mock_litellm_usertable.upsert = mock_client + mock_litellm_usertable.find_many = AsyncMock(return_value=None) + + await team_member_add( + data=team_member_add_request, user_api_key_dict=UserAPIKeyAuth() + ) + + mock_client.assert_called() + + print(f"mock_client.call_args: {mock_client.call_args}") + print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs)) + + assert ( + mock_client.call_args.kwargs["data"]["create"]["max_budget"] + == litellm.max_internal_user_budget + ) + assert ( + mock_client.call_args.kwargs["data"]["create"]["budget_duration"] + == litellm.internal_user_budget_duration + )