fix(proxy_server.py): enable /team/update endpoint for adding / deleting users from team

This commit is contained in:
Krrish Dholakia 2024-02-21 14:47:52 -08:00
parent 846757e343
commit 55a02c1a31
3 changed files with 98 additions and 7 deletions

View file

@ -238,6 +238,15 @@ class NewTeamRequest(LiteLLMBase):
metadata: Optional[dict] = None metadata: Optional[dict] = None
class UpdateTeamRequest(LiteLLMBase):
team_id: str # required
team_alias: Optional[str] = None
admins: Optional[list] = None
members: Optional[list] = None
members_with_roles: Optional[List[Member]] = None
metadata: Optional[dict] = None
class LiteLLM_TeamTable(NewTeamRequest): class LiteLLM_TeamTable(NewTeamRequest):
max_budget: Optional[float] = None max_budget: Optional[float] = None
spend: Optional[float] = None spend: Optional[float] = None

View file

@ -4363,11 +4363,81 @@ async def new_team(
@router.post( @router.post(
"/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)] "/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
) )
async def update_team(): async def update_team(
data: UpdateTeamRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
""" """
update team and members add new members to the team
""" """
pass global prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
if data.team_id is None:
raise HTTPException(status_code=400, detail={"error": "No team id passed in"})
existing_team_row = await prisma_client.get_data(
team_id=data.team_id, table_name="team", query_type="find_unique"
)
updated_kv = data.json(exclude_none=True)
team_row = await prisma_client.update_data(
update_key_values=updated_kv,
data=updated_kv,
table_name="team",
team_id=data.team_id,
)
## ADD NEW USERS ##
existing_user_id_list = []
## Get new users
for user in existing_team_row.members_with_roles:
existing_user_id_list.append(user["user_id"])
## Update new user rows with team id (info used by /user/info to show all teams, user is a part of)
if data.members_with_roles is not None:
for user in data.members_with_roles:
if user.user_id not in existing_user_id_list:
await prisma_client.update_data(
user_id=user.user_id,
data={"user_id": user.user_id, "teams": [team_row["team_id"]]},
update_key_values={
"teams": {
"push": [team_row["team_id"]],
}
},
)
## REMOVE DELETED USERS ##
### Get list of deleted users (old list - new list)
deleted_user_id_list = []
existing_user_id_list = []
## Get old user list
for user in existing_team_row.members_with_roles:
existing_user_id_list.append(user["user_id"])
## Get diff
if data.members_with_roles is not None:
for user in data.members_with_roles:
if user.user_id not in existing_user_id_list:
deleted_user_id_list.append(user.user_id)
## SET UPDATED LIST
if len(deleted_user_id_list) > 0:
# get the deleted users
existing_user_rows = await prisma_client.get_data(
user_id_list=deleted_user_id_list, table_name="user", query_type="find_all"
)
for user in existing_user_rows:
if data.team_id in user["teams"]:
user["teams"].remove(data.team_id)
await prisma_client.update_data(
user_id=user["user_id"],
data=user,
update_key_values={"user_id": user["user_id"], "teams": user["teams"]},
)
return team_row
@router.post( @router.post(

View file

@ -642,13 +642,12 @@ class PrismaClient:
} }
) )
elif query_type == "find_all" and user_id_list is not None: elif query_type == "find_all" and user_id_list is not None:
user_id_values = str(tuple(user_id_list)) user_id_values = ", ".join(f"'{item}'" for item in user_id_list)
sql_query = f""" sql_query = f"""
SELECT * SELECT *
FROM "LiteLLM_UserTable" FROM "LiteLLM_UserTable"
WHERE "user_id" IN {user_id_values} WHERE "user_id" IN ({user_id_values})
""" """
# Execute the raw query # Execute the raw query
# The asterisk before `user_id_list` unpacks the list into separate arguments # The asterisk before `user_id_list` unpacks the list into separate arguments
response = await self.db.query_raw(sql_query) response = await self.db.query_raw(sql_query)
@ -928,6 +927,19 @@ class PrismaClient:
update_key_values = db_data update_key_values = db_data
if "team_id" not in db_data and team_id is not None: if "team_id" not in db_data and team_id is not None:
db_data["team_id"] = team_id db_data["team_id"] = team_id
if "members_with_roles" in db_data and isinstance(
db_data["members_with_roles"], list
):
db_data["members_with_roles"] = json.dumps(
db_data["members_with_roles"]
)
if "members_with_roles" in update_key_values and isinstance(
update_key_values["members_with_roles"], list
):
update_key_values["members_with_roles"] = json.dumps(
update_key_values["members_with_roles"]
)
update_team_row = await self.db.litellm_teamtable.upsert( update_team_row = await self.db.litellm_teamtable.upsert(
where={"team_id": team_id}, # type: ignore where={"team_id": team_id}, # type: ignore
data={ data={
@ -942,7 +954,7 @@ class PrismaClient:
+ f"DB Team Table - update succeeded {update_team_row}" + f"DB Team Table - update succeeded {update_team_row}"
+ "\033[0m" + "\033[0m"
) )
return {"team_id": team_id, "data": db_data} return {"team_id": team_id, "data": update_team_row}
elif ( elif (
table_name is not None table_name is not None
and table_name == "key" and table_name == "key"