fix(management_helpers/utils.py): use user_default max_budget, budget duration on new user upsert during team member add

Fixes https://github.com/BerriAI/litellm/issues/5106
This commit is contained in:
Krrish Dholakia 2024-08-08 19:14:43 -07:00
parent 6af9d9d2b3
commit 1d39c0fb7d
3 changed files with 109 additions and 8 deletions

View file

@ -7,6 +7,7 @@ from typing import Optional
from fastapi import Request
import litellm
from litellm._logging import verbose_logger
from litellm.proxy._types import ( # key request types; user request types; team request types; customer request types
DeleteCustomerRequest,
@ -16,6 +17,7 @@ from litellm.proxy._types import ( # key request types; user request types; tea
LiteLLM_TeamTable,
ManagementEndpointLoggingPayload,
Member,
SSOUserDefinedValues,
UpdateCustomerRequest,
UpdateKeyRequest,
UpdateTeamRequest,
@ -26,6 +28,25 @@ from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.utils import PrismaClient
def get_new_internal_user_defaults(
user_id: str, user_email: Optional[str] = None
) -> SSOUserDefinedValues:
user_info = litellm.default_user_params or {}
returned_dict: SSOUserDefinedValues = {
"models": user_info.get("models", None),
"max_budget": user_info.get("max_budget", litellm.max_internal_user_budget),
"budget_duration": user_info.get(
"budget_duration", litellm.internal_user_budget_duration
),
"user_email": user_email or user_info.get("user_email", None),
"user_id": user_id,
"user_role": "internal_user",
}
return returned_dict
async def add_new_member(
new_member: Member,
max_budget_in_team: Optional[float],
@ -42,15 +63,18 @@ async def add_new_member(
"""
## ADD TEAM ID, to USER TABLE IF NEW ##
if new_member.user_id is not None:
new_user_defaults = get_new_internal_user_defaults(user_id=new_member.user_id)
await prisma_client.db.litellm_usertable.upsert(
where={"user_id": new_member.user_id},
data={
"update": {"teams": {"push": [team_id]}},
"create": {"user_id": new_member.user_id, "teams": [team_id]},
"create": {"teams": [team_id], **new_user_defaults},
},
)
elif new_member.user_email is not None:
user_data = {"user_id": str(uuid.uuid4()), "user_email": new_member.user_email}
new_user_defaults = get_new_internal_user_defaults(
user_id=str(uuid.uuid4()), user_email=new_member.user_email
)
## user email is not unique acc. to prisma schema -> future improvement
### for now: check if it exists in db, if not - insert it
existing_user_row = await prisma_client.get_data(
@ -62,7 +86,7 @@ async def add_new_member(
isinstance(existing_user_row, list) and len(existing_user_row) == 0
):
await prisma_client.insert_data(data=user_data, table_name="user")
await prisma_client.insert_data(data=new_user_defaults, table_name="user") # type: ignore
# Check if trying to set a budget for team member
if max_budget_in_team is not None and new_member.user_id is not None: