forked from phoenix/litellm-mirror
fix(proxy_server.py): allow passing in a list of team members
allows batch adding members to a team by passing in a list. fixes concurrency issue caused by calling team/member_add in parallel
This commit is contained in:
parent
dddd4a73fe
commit
def648ed3f
6 changed files with 144 additions and 78 deletions
|
@ -757,9 +757,24 @@ class GlobalEndUsersSpend(LiteLLMBase):
|
||||||
|
|
||||||
class TeamMemberAddRequest(LiteLLMBase):
|
class TeamMemberAddRequest(LiteLLMBase):
|
||||||
team_id: str
|
team_id: str
|
||||||
member: Member
|
member: Union[List[Member], Member]
|
||||||
max_budget_in_team: Optional[float] = None # Users max budget within the team
|
max_budget_in_team: Optional[float] = None # Users max budget within the team
|
||||||
|
|
||||||
|
def __init__(self, **data):
|
||||||
|
member_data = data.get("member")
|
||||||
|
if isinstance(member_data, list):
|
||||||
|
# If member is a list of dictionaries, convert each dictionary to a Member object
|
||||||
|
members = [Member(**item) for item in member_data]
|
||||||
|
# Replace member_data with the list of Member objects
|
||||||
|
data["member"] = members
|
||||||
|
elif isinstance(member_data, dict):
|
||||||
|
# If member is a dictionary, convert it to a single Member object
|
||||||
|
member = Member(**member_data)
|
||||||
|
# Replace member_data with the single Member object
|
||||||
|
data["member"] = member
|
||||||
|
# Call the superclass __init__ method to initialize the object
|
||||||
|
super().__init__(**data)
|
||||||
|
|
||||||
|
|
||||||
class TeamMemberDeleteRequest(LiteLLMBase):
|
class TeamMemberDeleteRequest(LiteLLMBase):
|
||||||
team_id: str
|
team_id: str
|
||||||
|
|
63
litellm/proxy/management_helpers/utils.py
Normal file
63
litellm/proxy/management_helpers/utils.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
# What is this?
|
||||||
|
## Helper utils for the management endpoints (keys/users/teams)
|
||||||
|
|
||||||
|
from litellm.proxy._types import LiteLLM_TeamTable, Member, UserAPIKeyAuth
|
||||||
|
from litellm.proxy.utils import PrismaClient
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
async def add_new_member(
|
||||||
|
new_member: Member,
|
||||||
|
max_budget_in_team: Optional[float],
|
||||||
|
prisma_client: PrismaClient,
|
||||||
|
team_id: str,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
litellm_proxy_admin_name: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add a new member to a team
|
||||||
|
|
||||||
|
- add team id to user table
|
||||||
|
- add team member w/ budget to team member table
|
||||||
|
"""
|
||||||
|
## ADD TEAM ID, to USER TABLE IF NEW ##
|
||||||
|
if new_member.user_id is not None:
|
||||||
|
await prisma_client.db.litellm_usertable.update(
|
||||||
|
where={"user_id": new_member.user_id},
|
||||||
|
data={"teams": {"push": [team_id]}},
|
||||||
|
)
|
||||||
|
elif new_member.user_email is not None:
|
||||||
|
user_data = {"user_id": str(uuid.uuid4()), "user_email": new_member.user_email}
|
||||||
|
## user email is not unique acc. to prisma schema -> future improvement
|
||||||
|
### for now: check if it exists in db, if not - insert it
|
||||||
|
existing_user_row = await prisma_client.get_data(
|
||||||
|
key_val={"user_email": new_member.user_email},
|
||||||
|
table_name="user",
|
||||||
|
query_type="find_all",
|
||||||
|
)
|
||||||
|
if existing_user_row is None or (
|
||||||
|
isinstance(existing_user_row, list) and len(existing_user_row) == 0
|
||||||
|
):
|
||||||
|
|
||||||
|
await prisma_client.insert_data(data=user_data, table_name="user")
|
||||||
|
|
||||||
|
# Check if trying to set a budget for team member
|
||||||
|
if max_budget_in_team is not None and new_member.user_id is not None:
|
||||||
|
# create a new budget item for this member
|
||||||
|
response = await prisma_client.db.litellm_budgettable.create(
|
||||||
|
data={
|
||||||
|
"max_budget": max_budget_in_team,
|
||||||
|
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||||
|
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
_budget_id = response.budget_id
|
||||||
|
await prisma_client.db.litellm_teammembership.create(
|
||||||
|
data={
|
||||||
|
"team_id": team_id,
|
||||||
|
"user_id": new_member.user_id,
|
||||||
|
"budget_id": _budget_id,
|
||||||
|
}
|
||||||
|
)
|
|
@ -90,6 +90,7 @@ from litellm.types.llms.openai import (
|
||||||
HttpxBinaryResponseContent,
|
HttpxBinaryResponseContent,
|
||||||
)
|
)
|
||||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||||
|
from litellm.proxy.management_helpers.utils import add_new_member
|
||||||
from litellm.proxy.utils import (
|
from litellm.proxy.utils import (
|
||||||
PrismaClient,
|
PrismaClient,
|
||||||
DBClient,
|
DBClient,
|
||||||
|
@ -10159,10 +10160,12 @@ async def team_member_add(
|
||||||
raise HTTPException(status_code=400, detail={"error": "No team id passed in"})
|
raise HTTPException(status_code=400, detail={"error": "No team id passed in"})
|
||||||
|
|
||||||
if data.member is None:
|
if data.member is None:
|
||||||
raise HTTPException(status_code=400, detail={"error": "No member passed in"})
|
raise HTTPException(
|
||||||
|
status_code=400, detail={"error": "No member/members passed in"}
|
||||||
|
)
|
||||||
|
|
||||||
existing_team_row = await prisma_client.get_data( # type: ignore
|
existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||||
team_id=data.team_id, table_name="team", query_type="find_unique"
|
where={"team_id": data.team_id}
|
||||||
)
|
)
|
||||||
if existing_team_row is None:
|
if existing_team_row is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -10172,75 +10175,52 @@ async def team_member_add(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump())
|
||||||
|
|
||||||
|
if isinstance(data.member, Member):
|
||||||
|
# add to team db
|
||||||
new_member = data.member
|
new_member = data.member
|
||||||
|
|
||||||
existing_team_row.members_with_roles.append(new_member)
|
complete_team_data.members_with_roles.append(new_member)
|
||||||
|
|
||||||
complete_team_data = LiteLLM_TeamTable(
|
elif isinstance(data.member, List):
|
||||||
**_get_pydantic_json_dict(existing_team_row),
|
# add to team db
|
||||||
|
new_members = data.member
|
||||||
|
|
||||||
|
complete_team_data.members_with_roles.extend(new_members)
|
||||||
|
|
||||||
|
# ADD MEMBER TO TEAM
|
||||||
|
_db_team_members = [
|
||||||
|
m.model_dump() for m in complete_team_data.members_with_roles
|
||||||
|
]
|
||||||
|
updated_team = await prisma_client.db.litellm_teamtable.update(
|
||||||
|
where={"team_id": data.team_id},
|
||||||
|
data={"members_with_roles": json.dumps(_db_team_members)}, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
team_row = await prisma_client.update_data(
|
if isinstance(data.member, Member):
|
||||||
update_key_values=complete_team_data.json(exclude_none=True),
|
await add_new_member(
|
||||||
data=complete_team_data.json(exclude_none=True),
|
new_member=data.member,
|
||||||
table_name="team",
|
max_budget_in_team=data.max_budget_in_team,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
litellm_proxy_admin_name=litellm_proxy_admin_name,
|
||||||
team_id=data.team_id,
|
team_id=data.team_id,
|
||||||
)
|
)
|
||||||
|
elif isinstance(data.member, List):
|
||||||
## ADD USER, IF NEW ##
|
tasks: List = []
|
||||||
user_data = { # type: ignore
|
for m in data.member:
|
||||||
"teams": [team_row["team_id"]],
|
await add_new_member(
|
||||||
"models": team_row["data"].models,
|
new_member=m,
|
||||||
}
|
max_budget_in_team=data.max_budget_in_team,
|
||||||
if new_member.user_id is not None:
|
prisma_client=prisma_client,
|
||||||
user_data["user_id"] = new_member.user_id # type: ignore
|
user_api_key_dict=user_api_key_dict,
|
||||||
await prisma_client.update_data(
|
litellm_proxy_admin_name=litellm_proxy_admin_name,
|
||||||
user_id=new_member.user_id,
|
team_id=data.team_id,
|
||||||
data=user_data,
|
|
||||||
update_key_values_custom_query={
|
|
||||||
"teams": {
|
|
||||||
"push": [team_row["team_id"]],
|
|
||||||
}
|
|
||||||
},
|
|
||||||
table_name="user",
|
|
||||||
)
|
)
|
||||||
elif new_member.user_email is not None:
|
await asyncio.gather(*tasks)
|
||||||
user_data["user_id"] = str(uuid.uuid4())
|
|
||||||
user_data["user_email"] = new_member.user_email
|
|
||||||
## user email is not unique acc. to prisma schema -> future improvement
|
|
||||||
### for now: check if it exists in db, if not - insert it
|
|
||||||
existing_user_row = await prisma_client.get_data(
|
|
||||||
key_val={"user_email": new_member.user_email},
|
|
||||||
table_name="user",
|
|
||||||
query_type="find_all",
|
|
||||||
)
|
|
||||||
if existing_user_row is None or (
|
|
||||||
isinstance(existing_user_row, list) and len(existing_user_row) == 0
|
|
||||||
):
|
|
||||||
|
|
||||||
await prisma_client.insert_data(data=user_data, table_name="user")
|
return updated_team
|
||||||
|
|
||||||
# Check if trying to set a budget for team member
|
|
||||||
if data.max_budget_in_team is not None and new_member.user_id is not None:
|
|
||||||
# create a new budget item for this member
|
|
||||||
response = await prisma_client.db.litellm_budgettable.create(
|
|
||||||
data={
|
|
||||||
"max_budget": data.max_budget_in_team,
|
|
||||||
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
|
||||||
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
_budget_id = response.budget_id
|
|
||||||
await prisma_client.db.litellm_teammembership.create(
|
|
||||||
data={
|
|
||||||
"team_id": data.team_id,
|
|
||||||
"user_id": new_member.user_id,
|
|
||||||
"budget_id": _budget_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return team_row
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
|
|
@ -91,7 +91,7 @@ model LiteLLM_TeamTable {
|
||||||
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||||
model_spend Json @default("{}")
|
model_spend Json @default("{}")
|
||||||
model_max_budget Json @default("{}")
|
model_max_budget Json @default("{}")
|
||||||
model_id Int? @unique
|
model_id Int? @unique // id for LiteLLM_ModelTable -> stores team-level model aliases
|
||||||
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
|
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
|
||||||
litellm_model_table LiteLLM_ModelTable? @relation(fields: [model_id], references: [id])
|
litellm_model_table LiteLLM_ModelTable? @relation(fields: [model_id], references: [id])
|
||||||
}
|
}
|
||||||
|
|
|
@ -91,7 +91,7 @@ model LiteLLM_TeamTable {
|
||||||
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||||
model_spend Json @default("{}")
|
model_spend Json @default("{}")
|
||||||
model_max_budget Json @default("{}")
|
model_max_budget Json @default("{}")
|
||||||
model_id Int? @unique
|
model_id Int? @unique // id for LiteLLM_ModelTable -> stores team-level model aliases
|
||||||
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
|
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
|
||||||
litellm_model_table LiteLLM_ModelTable? @relation(fields: [model_id], references: [id])
|
litellm_model_table LiteLLM_ModelTable? @relation(fields: [model_id], references: [id])
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,7 +49,7 @@ async def new_user(
|
||||||
|
|
||||||
|
|
||||||
async def add_member(
|
async def add_member(
|
||||||
session, i, team_id, user_id=None, user_email=None, max_budget=None
|
session, i, team_id, user_id=None, user_email=None, max_budget=None, members=None
|
||||||
):
|
):
|
||||||
url = "http://0.0.0.0:4000/team/member_add"
|
url = "http://0.0.0.0:4000/team/member_add"
|
||||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||||
|
@ -58,10 +58,13 @@ async def add_member(
|
||||||
data["member"]["user_email"] = user_email
|
data["member"]["user_email"] = user_email
|
||||||
elif user_id is not None:
|
elif user_id is not None:
|
||||||
data["member"]["user_id"] = user_id
|
data["member"]["user_id"] = user_id
|
||||||
|
elif members is not None:
|
||||||
|
data["member"] = members
|
||||||
|
|
||||||
if max_budget is not None:
|
if max_budget is not None:
|
||||||
data["max_budget_in_team"] = max_budget
|
data["max_budget_in_team"] = max_budget
|
||||||
|
|
||||||
|
print("sent data: {}".format(data))
|
||||||
async with session.post(url, headers=headers, json=data) as response:
|
async with session.post(url, headers=headers, json=data) as response:
|
||||||
status = response.status
|
status = response.status
|
||||||
response_text = await response.text()
|
response_text = await response.text()
|
||||||
|
@ -339,7 +342,7 @@ async def test_team_info():
|
||||||
async def test_team_update_sc_2():
|
async def test_team_update_sc_2():
|
||||||
"""
|
"""
|
||||||
- Create team
|
- Create team
|
||||||
- Add 1 user (doesn't exist in db)
|
- Add 3 users (doesn't exist in db)
|
||||||
- Change team alias
|
- Change team alias
|
||||||
- Check if it works
|
- Check if it works
|
||||||
- Assert team object unchanged besides team alias
|
- Assert team object unchanged besides team alias
|
||||||
|
@ -353,15 +356,20 @@ async def test_team_update_sc_2():
|
||||||
{"role": "admin", "user_id": admin_user},
|
{"role": "admin", "user_id": admin_user},
|
||||||
]
|
]
|
||||||
team_data = await new_team(session=session, i=0, member_list=member_list)
|
team_data = await new_team(session=session, i=0, member_list=member_list)
|
||||||
## Create new normal user
|
## Create 10 normal users
|
||||||
new_normal_user = f"krrish_{uuid.uuid4()}@berri.ai"
|
members = [
|
||||||
|
{"role": "user", "user_id": f"krrish_{uuid.uuid4()}@berri.ai"}
|
||||||
|
for _ in range(10)
|
||||||
|
]
|
||||||
await add_member(
|
await add_member(
|
||||||
session=session,
|
session=session, i=0, team_id=team_data["team_id"], members=members
|
||||||
i=0,
|
|
||||||
team_id=team_data["team_id"],
|
|
||||||
user_id=None,
|
|
||||||
user_email=new_normal_user,
|
|
||||||
)
|
)
|
||||||
|
## ASSERT TEAM SIZE
|
||||||
|
team_info = await get_team_info(
|
||||||
|
session=session, get_team=team_data["team_id"], call_key="sk-1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(team_info["team_info"]["members_with_roles"]) == 12
|
||||||
|
|
||||||
## CHANGE TEAM ALIAS
|
## CHANGE TEAM ALIAS
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue