forked from phoenix/litellm-mirror
fix(management/utils.py): fix add_member to team when adding user_email
Fixes https://github.com/BerriAI/litellm/issues/5112
This commit is contained in:
parent
575afa8029
commit
a0a1feb7da
2 changed files with 52 additions and 3 deletions
|
@ -5,7 +5,7 @@ from datetime import datetime
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
|
@ -81,7 +81,7 @@ async def add_new_member(
|
||||||
)
|
)
|
||||||
## user email is not unique acc. to prisma schema -> future improvement
|
## user email is not unique acc. to prisma schema -> future improvement
|
||||||
### for now: check if it exists in db, if not - insert it
|
### for now: check if it exists in db, if not - insert it
|
||||||
existing_user_row = await prisma_client.get_data(
|
existing_user_row: Optional[list] = await prisma_client.get_data(
|
||||||
key_val={"user_email": new_member.user_email},
|
key_val={"user_email": new_member.user_email},
|
||||||
table_name="user",
|
table_name="user",
|
||||||
query_type="find_all",
|
query_type="find_all",
|
||||||
|
@ -89,8 +89,21 @@ async def add_new_member(
|
||||||
if existing_user_row is None or (
|
if existing_user_row is None or (
|
||||||
isinstance(existing_user_row, list) and len(existing_user_row) == 0
|
isinstance(existing_user_row, list) and len(existing_user_row) == 0
|
||||||
):
|
):
|
||||||
|
new_user_defaults["teams"] = [team_id]
|
||||||
await prisma_client.insert_data(data=new_user_defaults, table_name="user") # type: ignore
|
await prisma_client.insert_data(data=new_user_defaults, table_name="user") # type: ignore
|
||||||
|
elif len(existing_user_row) == 1:
|
||||||
|
user_info = existing_user_row[0]
|
||||||
|
await prisma_client.db.litellm_usertable.update(
|
||||||
|
where={"user_id": user_info.user_id},
|
||||||
|
data={"teams": {"push": [team_id]}},
|
||||||
|
)
|
||||||
|
elif len(existing_user_row) > 1:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": "Multiple users with this email found in db. Please use 'user_id' instead."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Check if trying to set a budget for team member
|
# Check if trying to set a budget for team member
|
||||||
if max_budget_in_team is not None and new_member.user_id is not None:
|
if max_budget_in_team is not None and new_member.user_id is not None:
|
||||||
|
|
|
@ -414,6 +414,42 @@ async def test_team_update_sc_2():
|
||||||
assert new_team_data["data"][k] == team_data[k]
|
assert new_team_data["data"][k] == team_data[k]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_team_member_add_email():
|
||||||
|
from test_users import get_user_info
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
## Create admin
|
||||||
|
admin_user = f"{uuid.uuid4()}"
|
||||||
|
await new_user(session=session, i=0, user_id=admin_user)
|
||||||
|
## Create team with 1 admin and 1 user
|
||||||
|
member_list = [
|
||||||
|
{"role": "admin", "user_id": admin_user},
|
||||||
|
]
|
||||||
|
team_data = await new_team(session=session, i=0, member_list=member_list)
|
||||||
|
## Add 1 user via email
|
||||||
|
user_email = "krrish{}@berri.ai".format(uuid.uuid4())
|
||||||
|
new_user_info = await new_user(session=session, i=0, user_email=user_email)
|
||||||
|
new_member = {"role": "user", "user_email": user_email}
|
||||||
|
await add_member(
|
||||||
|
session=session, i=0, team_id=team_data["team_id"], members=[new_member]
|
||||||
|
)
|
||||||
|
|
||||||
|
## check user info to confirm user is in team
|
||||||
|
updated_user_info = await get_user_info(
|
||||||
|
session=session, get_user=new_user_info["user_id"], call_user="sk-1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(updated_user_info)
|
||||||
|
|
||||||
|
## check if team in user table
|
||||||
|
is_team_in_list: bool = False
|
||||||
|
for team in updated_user_info["teams"]:
|
||||||
|
if team_data["team_id"] == team["team_id"]:
|
||||||
|
is_team_in_list = True
|
||||||
|
assert is_team_in_list
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_team_delete():
|
async def test_team_delete():
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue