fix(management/utils.py): fix add_member to team when adding user_email

Fixes https://github.com/BerriAI/litellm/issues/5112
This commit is contained in:
Krrish Dholakia 2024-08-10 17:12:09 -07:00
parent 575afa8029
commit a0a1feb7da
2 changed files with 52 additions and 3 deletions

View file

@ -5,7 +5,7 @@ from datetime import datetime
from functools import wraps from functools import wraps
from typing import Optional from typing import Optional
from fastapi import Request from fastapi import HTTPException, Request
import litellm import litellm
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
@ -81,7 +81,7 @@ async def add_new_member(
) )
## user email is not unique acc. to prisma schema -> future improvement ## user email is not unique acc. to prisma schema -> future improvement
### for now: check if it exists in db, if not - insert it ### for now: check if it exists in db, if not - insert it
existing_user_row = await prisma_client.get_data( existing_user_row: Optional[list] = await prisma_client.get_data(
key_val={"user_email": new_member.user_email}, key_val={"user_email": new_member.user_email},
table_name="user", table_name="user",
query_type="find_all", query_type="find_all",
@ -89,8 +89,21 @@ async def add_new_member(
if existing_user_row is None or ( if existing_user_row is None or (
isinstance(existing_user_row, list) and len(existing_user_row) == 0 isinstance(existing_user_row, list) and len(existing_user_row) == 0
): ):
new_user_defaults["teams"] = [team_id]
await prisma_client.insert_data(data=new_user_defaults, table_name="user") # type: ignore await prisma_client.insert_data(data=new_user_defaults, table_name="user") # type: ignore
elif len(existing_user_row) == 1:
user_info = existing_user_row[0]
await prisma_client.db.litellm_usertable.update(
where={"user_id": user_info.user_id},
data={"teams": {"push": [team_id]}},
)
elif len(existing_user_row) > 1:
raise HTTPException(
status_code=400,
detail={
"error": "Multiple users with this email found in db. Please use 'user_id' instead."
},
)
# Check if trying to set a budget for team member # Check if trying to set a budget for team member
if max_budget_in_team is not None and new_member.user_id is not None: if max_budget_in_team is not None and new_member.user_id is not None:

View file

@ -414,6 +414,42 @@ async def test_team_update_sc_2():
assert new_team_data["data"][k] == team_data[k] assert new_team_data["data"][k] == team_data[k]
@pytest.mark.asyncio
async def test_team_member_add_email():
from test_users import get_user_info
async with aiohttp.ClientSession() as session:
## Create admin
admin_user = f"{uuid.uuid4()}"
await new_user(session=session, i=0, user_id=admin_user)
## Create team with 1 admin and 1 user
member_list = [
{"role": "admin", "user_id": admin_user},
]
team_data = await new_team(session=session, i=0, member_list=member_list)
## Add 1 user via email
user_email = "krrish{}@berri.ai".format(uuid.uuid4())
new_user_info = await new_user(session=session, i=0, user_email=user_email)
new_member = {"role": "user", "user_email": user_email}
await add_member(
session=session, i=0, team_id=team_data["team_id"], members=[new_member]
)
## check user info to confirm user is in team
updated_user_info = await get_user_info(
session=session, get_user=new_user_info["user_id"], call_user="sk-1234"
)
print(updated_user_info)
## check if team in user table
is_team_in_list: bool = False
for team in updated_user_info["teams"]:
if team_data["team_id"] == team["team_id"]:
is_team_in_list = True
assert is_team_in_list
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_team_delete(): async def test_team_delete():
""" """