diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 0177c2190..d660e576d 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1814,3 +1814,15 @@ class CreatePassThroughEndpoint(LiteLLMBase): path: str target: str headers: dict + + +class LiteLLM_TeamMembership(LiteLLMBase): + user_id: str + team_id: str + budget_id: str + litellm_budget_table: Optional[LiteLLM_BudgetTable] + + +class TeamAddMemberResponse(LiteLLM_TeamTable): + updated_users: List[LiteLLM_UserTable] + updated_team_memberships: List[LiteLLM_TeamMembership] diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index 614f37f44..a1b1c618f 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -17,14 +17,17 @@ from litellm.proxy._types import ( DeleteTeamRequest, LiteLLM_AuditLogs, LiteLLM_ModelTable, + LiteLLM_TeamMembership, LiteLLM_TeamTable, LiteLLM_TeamTableCachedObj, + LiteLLM_UserTable, LitellmTableNames, LitellmUserRoles, Member, NewTeamRequest, ProxyErrorTypes, ProxyException, + TeamAddMemberResponse, TeamMemberAddRequest, TeamMemberDeleteRequest, UpdateTeamRequest, @@ -413,6 +416,7 @@ async def update_team( "/team/member_add", tags=["team management"], dependencies=[Depends(user_api_key_auth)], + response_model=TeamAddMemberResponse, ) @management_endpoint_wrapper async def team_member_add( @@ -514,29 +518,64 @@ async def team_member_add( data={"members_with_roles": json.dumps(_db_team_members)}, # type: ignore ) + updated_users: List[LiteLLM_UserTable] = [] + updated_team_memberships: List[LiteLLM_TeamMembership] = [] + if isinstance(data.member, Member): - await add_new_member( - new_member=data.member, - max_budget_in_team=data.max_budget_in_team, - prisma_client=prisma_client, - user_api_key_dict=user_api_key_dict, - litellm_proxy_admin_name=litellm_proxy_admin_name, - team_id=data.team_id, - ) - elif isinstance(data.member, List): - tasks: List = [] - for m in data.member: - await add_new_member( - new_member=m, + try: + updated_user, updated_tm = await add_new_member( + new_member=data.member, max_budget_in_team=data.max_budget_in_team, prisma_client=prisma_client, user_api_key_dict=user_api_key_dict, litellm_proxy_admin_name=litellm_proxy_admin_name, team_id=data.team_id, ) + except Exception as e: + raise HTTPException( + status_code=500, + detail={ + "error": "Unable to add user - {}, to team - {}, for reason - {}".format( + data.member, data.team_id, str(e) + ) + }, + ) + + updated_users.append(updated_user) + if updated_tm is not None: + updated_team_memberships.append(updated_tm) + elif isinstance(data.member, List): + tasks: List = [] + for m in data.member: + try: + updated_user, updated_tm = await add_new_member( + new_member=m, + max_budget_in_team=data.max_budget_in_team, + prisma_client=prisma_client, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + team_id=data.team_id, + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail={ + "error": "Unable to add user - {}, to team - {}, for reason - {}".format( + data.member, data.team_id, str(e) + ) + }, + ) + updated_users.append(updated_user) + if updated_tm is not None: + updated_team_memberships.append(updated_tm) + await asyncio.gather(*tasks) - return updated_team + return TeamAddMemberResponse( + **updated_team.model_dump(), + updated_users=updated_users, + updated_team_memberships=updated_team_memberships, + ) @router.post( diff --git a/litellm/proxy/management_helpers/utils.py b/litellm/proxy/management_helpers/utils.py index d8a067aa2..b7ba5278b 100644 --- a/litellm/proxy/management_helpers/utils.py +++ b/litellm/proxy/management_helpers/utils.py @@ -3,7 +3,7 @@ import uuid from datetime import datetime from functools import wraps -from typing import Optional +from typing import Optional, Tuple from fastapi import HTTPException, Request @@ -14,7 +14,9 @@ from litellm.proxy._types import ( # key request types; user request types; tea DeleteTeamRequest, DeleteUserRequest, KeyRequest, + LiteLLM_TeamMembership, LiteLLM_TeamTable, + LiteLLM_UserTable, ManagementEndpointLoggingPayload, Member, SSOUserDefinedValues, @@ -59,23 +61,28 @@ async def add_new_member( team_id: str, user_api_key_dict: UserAPIKeyAuth, litellm_proxy_admin_name: str, -): +) -> Tuple[LiteLLM_UserTable, Optional[LiteLLM_TeamMembership]]: """ Add a new member to a team - add team id to user table - add team member w/ budget to team member table + + Returns created/existing user + team membership w/ budget id """ + returned_user: Optional[LiteLLM_UserTable] = None + returned_team_membership: Optional[LiteLLM_TeamMembership] = None ## ADD TEAM ID, to USER TABLE IF NEW ## if new_member.user_id is not None: new_user_defaults = get_new_internal_user_defaults(user_id=new_member.user_id) - await prisma_client.db.litellm_usertable.upsert( + _returned_user = await prisma_client.db.litellm_usertable.upsert( where={"user_id": new_member.user_id}, data={ "update": {"teams": {"push": [team_id]}}, "create": {"teams": [team_id], **new_user_defaults}, # type: ignore }, ) + returned_user = LiteLLM_UserTable(**_returned_user.model_dump()) elif new_member.user_email is not None: new_user_defaults = get_new_internal_user_defaults( user_id=str(uuid.uuid4()), user_email=new_member.user_email @@ -91,13 +98,15 @@ async def add_new_member( isinstance(existing_user_row, list) and len(existing_user_row) == 0 ): new_user_defaults["teams"] = [team_id] - await prisma_client.insert_data(data=new_user_defaults, table_name="user") # type: ignore + _returned_user = await prisma_client.insert_data(data=new_user_defaults, table_name="user") # type: ignore + returned_user = LiteLLM_UserTable(**_returned_user.model_dump()) elif len(existing_user_row) == 1: user_info = existing_user_row[0] - await prisma_client.db.litellm_usertable.update( + _returned_user = await prisma_client.db.litellm_usertable.update( where={"user_id": user_info.user_id}, data={"teams": {"push": [team_id]}}, ) + returned_user = LiteLLM_UserTable(**_returned_user.model_dump()) elif len(existing_user_row) > 1: raise HTTPException( status_code=400, @@ -118,14 +127,26 @@ async def add_new_member( ) _budget_id = response.budget_id - await prisma_client.db.litellm_teammembership.create( - data={ - "team_id": team_id, - "user_id": new_member.user_id, - "budget_id": _budget_id, - } + _returned_team_membership = ( + await prisma_client.db.litellm_teammembership.create( + data={ + "team_id": team_id, + "user_id": new_member.user_id, + "budget_id": _budget_id, + }, + include={"litellm_budget_table": True}, + ) ) + returned_team_membership = LiteLLM_TeamMembership( + **_returned_team_membership.model_dump() + ) + + if returned_user is None: + raise Exception("Unable to update user table with membership information!") + + return returned_user, returned_team_membership + def _delete_user_id_from_cache(kwargs): from litellm.proxy.proxy_server import user_api_key_cache