Merge pull request #2119 from BerriAI/litellm_updated_team_endpoints

Enable `/team/update`, `/team/delete` endpoints + create teams with user defined roles
This commit is contained in:
Krish Dholakia 2024-02-21 17:24:58 -08:00 committed by GitHub
commit 0733bf1e7a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 464 additions and 35 deletions

View file

@ -224,14 +224,33 @@ class UpdateUserRequest(GenerateRequestBase):
max_budget: Optional[float] = None max_budget: Optional[float] = None
class Member(LiteLLMBase):
role: Literal["admin", "user"]
user_id: str
class NewTeamRequest(LiteLLMBase): class NewTeamRequest(LiteLLMBase):
team_alias: Optional[str] = None team_alias: Optional[str] = None
team_id: Optional[str] = None team_id: Optional[str] = None
admins: list = [] admins: list = []
members: list = [] members: list = []
members_with_roles: List[Member] = []
metadata: Optional[dict] = None metadata: Optional[dict] = None
class UpdateTeamRequest(LiteLLMBase):
team_id: str # required
team_alias: Optional[str] = None
admins: Optional[list] = None
members: Optional[list] = None
members_with_roles: Optional[List[Member]] = None
metadata: Optional[dict] = None
class DeleteTeamRequest(LiteLLMBase):
team_ids: List[str] # required
class LiteLLM_TeamTable(NewTeamRequest): class LiteLLM_TeamTable(NewTeamRequest):
max_budget: Optional[float] = None max_budget: Optional[float] = None
spend: Optional[float] = None spend: Optional[float] = None

View file

@ -4060,9 +4060,30 @@ async def user_info(
else: else:
user_info = None user_info = None
## GET ALL TEAMS ## ## GET ALL TEAMS ##
teams = await prisma_client.get_data( team_list = []
team_id_list = []
# _DEPRECATED_ check if user in 'member' field
teams_1 = await prisma_client.get_data(
user_id=user_id, table_name="team", query_type="find_all" user_id=user_id, table_name="team", query_type="find_all"
) )
if teams_1 is not None and isinstance(teams_1, list):
team_list = teams_1
for team in teams_1:
team_id_list.append(team.team_id)
if user_info is not None:
# *NEW* get all teams in user 'teams' field
teams_2 = await prisma_client.get_data(
team_id_list=user_info.teams, table_name="team", query_type="find_all"
)
if teams_2 is not None and isinstance(teams_2, list):
for team in teams_2:
if team.team_id not in team_id_list:
team_list.append(team)
team_id_list.append(team.team_id)
## GET ALL KEYS ## ## GET ALL KEYS ##
keys = await prisma_client.get_data( keys = await prisma_client.get_data(
user_id=user_id, user_id=user_id,
@ -4090,9 +4111,10 @@ async def user_info(
"user_id": user_id, "user_id": user_id,
"user_info": user_info, "user_info": user_info,
"keys": keys, "keys": keys,
"teams": teams, "teams": team_list,
} }
except Exception as e: except Exception as e:
traceback.print_exc()
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"), message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@ -4274,12 +4296,31 @@ async def new_team(
Parameters: Parameters:
- team_alias: Optional[str] - User defined team alias - team_alias: Optional[str] - User defined team alias
- team_id: Optional[str] - The team id of the user. If none passed, we'll generate it. - team_id: Optional[str] - The team id of the user. If none passed, we'll generate it.
- admins: list - A list of user IDs that will be owning the team - members_with_roles: list - A list of dictionaries, mapping user_id to role in team (either 'admin' or 'user')
- members: list - A list of user IDs that will be members of the team
- metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }
Returns: Returns:
- team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id. - team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id.
_deprecated_params:
- admins: list - A list of user_id's for the admin role
- users: list - A list of user_id's for the user role
Example Request:
```
curl --location 'http://0.0.0.0:8000/team/new' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"team_alias": "my-new-team_2",
"members_with_roles": [{"role": "admin", "user_id": "user-1234"},
{"role": "user", "user_id": "user-2434"}]
}'
```
""" """
global prisma_client global prisma_client
@ -4303,27 +4344,124 @@ async def new_team(
team_row = await prisma_client.insert_data( team_row = await prisma_client.insert_data(
data=complete_team_data.json(exclude_none=True), table_name="team" data=complete_team_data.json(exclude_none=True), table_name="team"
) )
## ADD TEAM ID TO USER TABLE ##
for user in complete_team_data.members_with_roles:
## add team id to user row ##
await prisma_client.update_data(
user_id=user.user_id,
data={"user_id": user.user_id, "teams": [team_row.team_id]},
update_key_values={
"teams": {
"push ": [team_row.team_id],
}
},
)
return team_row return team_row
@router.post( @router.post(
"/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)] "/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
) )
async def update_team(): async def update_team(
data: UpdateTeamRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
""" """
update team and members add new members to the team
""" """
pass global prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
if data.team_id is None:
raise HTTPException(status_code=400, detail={"error": "No team id passed in"})
existing_team_row = await prisma_client.get_data(
team_id=data.team_id, table_name="team", query_type="find_unique"
)
updated_kv = data.json(exclude_none=True)
team_row = await prisma_client.update_data(
update_key_values=updated_kv,
data=updated_kv,
table_name="team",
team_id=data.team_id,
)
## ADD NEW USERS ##
existing_user_id_list = []
## Get new users
for user in existing_team_row.members_with_roles:
existing_user_id_list.append(user["user_id"])
## Update new user rows with team id (info used by /user/info to show all teams, user is a part of)
if data.members_with_roles is not None:
for user in data.members_with_roles:
if user.user_id not in existing_user_id_list:
await prisma_client.update_data(
user_id=user.user_id,
data={"user_id": user.user_id, "teams": [team_row["team_id"]]},
update_key_values={
"teams": {
"push": [team_row["team_id"]],
}
},
)
## REMOVE DELETED USERS ##
### Get list of deleted users (old list - new list)
deleted_user_id_list = []
existing_user_id_list = []
## Get old user list
for user in existing_team_row.members_with_roles:
existing_user_id_list.append(user["user_id"])
## Get diff
if data.members_with_roles is not None:
for user in data.members_with_roles:
if user.user_id not in existing_user_id_list:
deleted_user_id_list.append(user.user_id)
## SET UPDATED LIST
if len(deleted_user_id_list) > 0:
# get the deleted users
existing_user_rows = await prisma_client.get_data(
user_id_list=deleted_user_id_list, table_name="user", query_type="find_all"
)
for user in existing_user_rows:
if data.team_id in user["teams"]:
user["teams"].remove(data.team_id)
await prisma_client.update_data(
user_id=user["user_id"],
data=user,
update_key_values={"user_id": user["user_id"], "teams": user["teams"]},
)
return team_row
@router.post( @router.post(
"/team/delete", tags=["team management"], dependencies=[Depends(user_api_key_auth)] "/team/delete", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
) )
async def delete_team(): async def delete_team(
data: DeleteTeamRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
""" """
delete team and team keys delete team and associated team keys
""" """
pass global prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
if data.team_ids is None:
raise HTTPException(status_code=400, detail={"error": "No team id passed in"})
## DELETE ASSOCIATED KEYS
await prisma_client.delete_data(team_id_list=data.team_ids, table_name="key")
## DELETE TEAMS
await prisma_client.delete_data(team_id_list=data.team_ids, table_name="team")
@router.get( @router.get(

View file

@ -13,6 +13,7 @@ model LiteLLM_TeamTable {
team_alias String? team_alias String?
admins String[] admins String[]
members String[] members String[]
members_with_roles Json @default("{}")
metadata Json @default("{}") metadata Json @default("{}")
max_budget Float? max_budget Float?
spend Float @default(0.0) spend Float @default(0.0)
@ -32,6 +33,7 @@ model LiteLLM_TeamTable {
model LiteLLM_UserTable { model LiteLLM_UserTable {
user_id String @unique user_id String @unique
team_id String? team_id String?
teams String[] @default([])
user_role String? user_role String?
max_budget Float? max_budget Float?
spend Float @default(0.0) spend Float @default(0.0)
@ -103,5 +105,5 @@ model LiteLLM_UserNotifications {
user_id String user_id String
models String[] models String[]
justification String justification String
status String // approved, disapproved, pending status String // approved, disapproved, pending
} }

View file

@ -532,6 +532,7 @@ class PrismaClient:
user_id: Optional[str] = None, user_id: Optional[str] = None,
user_id_list: Optional[list] = None, user_id_list: Optional[list] = None,
team_id: Optional[str] = None, team_id: Optional[str] = None,
team_id_list: Optional[list] = None,
key_val: Optional[dict] = None, key_val: Optional[dict] = None,
table_name: Optional[ table_name: Optional[
Literal["user", "key", "config", "spend", "team", "user_notification"] Literal["user", "key", "config", "spend", "team", "user_notification"]
@ -641,13 +642,12 @@ class PrismaClient:
} }
) )
elif query_type == "find_all" and user_id_list is not None: elif query_type == "find_all" and user_id_list is not None:
user_id_values = str(tuple(user_id_list)) user_id_values = ", ".join(f"'{item}'" for item in user_id_list)
sql_query = f""" sql_query = f"""
SELECT * SELECT *
FROM "LiteLLM_UserTable" FROM "LiteLLM_UserTable"
WHERE "user_id" IN {user_id_values} WHERE "user_id" IN ({user_id_values})
""" """
# Execute the raw query # Execute the raw query
# The asterisk before `user_id_list` unpacks the list into separate arguments # The asterisk before `user_id_list` unpacks the list into separate arguments
response = await self.db.query_raw(sql_query) response = await self.db.query_raw(sql_query)
@ -697,7 +697,13 @@ class PrismaClient:
) )
elif query_type == "find_all" and user_id is not None: elif query_type == "find_all" and user_id is not None:
response = await self.db.litellm_teamtable.find_many( response = await self.db.litellm_teamtable.find_many(
where={"members": {"has": user_id}} where={
"members": {"has": user_id},
},
)
elif query_type == "find_all" and team_id_list is not None:
response = await self.db.litellm_teamtable.find_many(
where={"team_id": {"in": team_id_list}}
) )
return response return response
elif table_name == "user_notification": elif table_name == "user_notification":
@ -769,6 +775,12 @@ class PrismaClient:
return new_user_row return new_user_row
elif table_name == "team": elif table_name == "team":
db_data = self.jsonify_object(data=data) db_data = self.jsonify_object(data=data)
if db_data.get("members_with_roles", None) is not None and isinstance(
db_data["members_with_roles"], list
):
db_data["members_with_roles"] = json.dumps(
db_data["members_with_roles"]
)
new_team_row = await self.db.litellm_teamtable.upsert( new_team_row = await self.db.litellm_teamtable.upsert(
where={"team_id": data["team_id"]}, where={"team_id": data["team_id"]},
data={ data={
@ -915,6 +927,19 @@ class PrismaClient:
update_key_values = db_data update_key_values = db_data
if "team_id" not in db_data and team_id is not None: if "team_id" not in db_data and team_id is not None:
db_data["team_id"] = team_id db_data["team_id"] = team_id
if "members_with_roles" in db_data and isinstance(
db_data["members_with_roles"], list
):
db_data["members_with_roles"] = json.dumps(
db_data["members_with_roles"]
)
if "members_with_roles" in update_key_values and isinstance(
update_key_values["members_with_roles"], list
):
update_key_values["members_with_roles"] = json.dumps(
update_key_values["members_with_roles"]
)
update_team_row = await self.db.litellm_teamtable.upsert( update_team_row = await self.db.litellm_teamtable.upsert(
where={"team_id": team_id}, # type: ignore where={"team_id": team_id}, # type: ignore
data={ data={
@ -929,7 +954,7 @@ class PrismaClient:
+ f"DB Team Table - update succeeded {update_team_row}" + f"DB Team Table - update succeeded {update_team_row}"
+ "\033[0m" + "\033[0m"
) )
return {"team_id": team_id, "data": db_data} return {"team_id": team_id, "data": update_team_row}
elif ( elif (
table_name is not None table_name is not None
and table_name == "key" and table_name == "key"
@ -1001,22 +1026,45 @@ class PrismaClient:
max_time=10, # maximum total time to retry for max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff on_backoff=on_backoff, # specifying the function to call on backoff
) )
async def delete_data(self, tokens: List): async def delete_data(
self,
tokens: Optional[List] = None,
team_id_list: Optional[List] = None,
table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None,
):
""" """
Allow user to delete a key(s) Allow user to delete a key(s)
""" """
try: try:
hashed_tokens = [] if tokens is not None and isinstance(tokens, List):
for token in tokens: hashed_tokens = []
if isinstance(token, str) and token.startswith("sk-"): for token in tokens:
hashed_token = self.hash_token(token=token) if isinstance(token, str) and token.startswith("sk-"):
else: hashed_token = self.hash_token(token=token)
hashed_token = token else:
hashed_tokens.append(hashed_token) hashed_token = token
await self.db.litellm_verificationtoken.delete_many( hashed_tokens.append(hashed_token)
where={"token": {"in": hashed_tokens}} await self.db.litellm_verificationtoken.delete_many(
) where={"token": {"in": hashed_tokens}}
return {"deleted_keys": tokens} )
return {"deleted_keys": tokens}
elif (
table_name == "team"
and team_id_list is not None
and isinstance(team_id_list, List)
):
await self.db.litellm_teamtable.delete_many(
where={"team_id": {"in": team_id_list}}
)
return {"deleted_teams": team_id_list}
elif (
table_name == "key"
and team_id_list is not None
and isinstance(team_id_list, List)
):
await self.db.litellm_verificationtoken.delete_many(
where={"team_id": {"in": team_id_list}}
)
except Exception as e: except Exception as e:
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e) self.proxy_logging_obj.failure_handler(original_exception=e)

View file

@ -13,6 +13,7 @@ model LiteLLM_TeamTable {
team_alias String? team_alias String?
admins String[] admins String[]
members String[] members String[]
members_with_roles Json @default("{}")
metadata Json @default("{}") metadata Json @default("{}")
max_budget Float? max_budget Float?
spend Float @default(0.0) spend Float @default(0.0)
@ -32,6 +33,7 @@ model LiteLLM_TeamTable {
model LiteLLM_UserTable { model LiteLLM_UserTable {
user_id String @unique user_id String @unique
team_id String? team_id String?
teams String[] @default([])
user_role String? user_role String?
max_budget Float? max_budget Float?
spend Float @default(0.0) spend Float @default(0.0)
@ -104,4 +106,4 @@ model LiteLLM_UserNotifications {
models String[] models String[]
justification String justification String
status String // approved, disapproved, pending status String // approved, disapproved, pending
} }

View file

@ -3,21 +3,176 @@
import pytest import pytest
import asyncio import asyncio
import aiohttp import aiohttp
import time import time, uuid
from openai import AsyncOpenAI from openai import AsyncOpenAI
async def new_team( async def new_user(session, i, user_id=None, budget=None, budget_duration=None):
url = "http://0.0.0.0:4000/user/new"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {
"models": ["azure-models"],
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
"duration": None,
"max_budget": budget,
"budget_duration": budget_duration,
}
if user_id is not None:
data["user_id"] = user_id
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
print(f"Response {i} (Status code: {status}):")
print(response_text)
print()
if status != 200:
raise Exception(f"Request {i} did not return a 200 status code: {status}")
return await response.json()
async def generate_key(
session, session,
i, i,
budget=None,
budget_duration=None,
models=["azure-models", "gpt-4", "dall-e-3"],
team_id=None,
): ):
url = "http://0.0.0.0:4000/key/generate"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {
"models": models,
"duration": None,
"max_budget": budget,
"budget_duration": budget_duration,
}
if team_id is not None:
data["team_id"] = team_id
print(f"data: {data}")
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
print(f"Response {i} (Status code: {status}):")
print(response_text)
print()
if status != 200:
raise Exception(f"Request {i} did not return a 200 status code: {status}")
return await response.json()
async def chat_completion(session, key, model="gpt-4"):
url = "http://0.0.0.0:4000/chat/completions"
headers = {
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
}
data = {
"model": model,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
],
}
for i in range(3):
try:
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
print(response_text)
print()
if status != 200:
raise Exception(
f"Request did not return a 200 status code: {status}. Response: {response_text}"
)
return await response.json()
except Exception as e:
if "Request did not return a 200 status code" in str(e):
raise e
else:
pass
async def new_team(session, i, user_id=None, member_list=None):
url = "http://0.0.0.0:4000/team/new" url = "http://0.0.0.0:4000/team/new"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = { data = {
"team_alias": "my-new-team", "team_alias": "my-new-team",
"admins": ["user-1234"],
"members": ["user-1234"],
} }
if user_id is not None:
data["members_with_roles"] = [{"role": "user", "user_id": user_id}]
elif member_list is not None:
data["members_with_roles"] = member_list
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
print(f"Response {i} (Status code: {status}):")
print(response_text)
print()
if status != 200:
raise Exception(f"Request {i} did not return a 200 status code: {status}")
return await response.json()
async def update_team(
session,
i,
team_id,
user_id=None,
member_list=None,
):
url = "http://0.0.0.0:4000/team/update"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {
"team_id": team_id,
}
if user_id is not None:
data["members_with_roles"] = [{"role": "user", "user_id": user_id}]
elif member_list is not None:
data["members_with_roles"] = member_list
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
print(f"Response {i} (Status code: {status}):")
print(response_text)
print()
if status != 200:
raise Exception(f"Request {i} did not return a 200 status code: {status}")
return await response.json()
async def delete_team(
session,
i,
team_id,
):
url = "http://0.0.0.0:4000/team/delete"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {
"team_ids": [team_id],
}
async with session.post(url, headers=headers, json=data) as response: async with session.post(url, headers=headers, json=data) as response:
status = response.status status = response.status
response_text = await response.text() response_text = await response.text()
@ -37,8 +192,10 @@ async def test_team_new():
""" """
Make 20 parallel calls to /user/new. Assert all worked. Make 20 parallel calls to /user/new. Assert all worked.
""" """
user_id = f"{uuid.uuid4()}"
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
tasks = [new_team(session, i) for i in range(1, 11)] new_user(session=session, i=0, user_id=user_id)
tasks = [new_team(session, i, user_id=user_id) for i in range(1, 11)]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
@ -70,3 +227,66 @@ async def test_team_info():
team_id = new_team_data["team_id"] team_id = new_team_data["team_id"]
## as admin ## ## as admin ##
await get_team_info(session=session, get_team=team_id, call_key="sk-1234") await get_team_info(session=session, get_team=team_id, call_key="sk-1234")
@pytest.mark.asyncio
async def test_team_update():
"""
- Create team with 1 admin, 1 user
- Create new user
- Replace existing user with new user in team
"""
async with aiohttp.ClientSession() as session:
## Create admin
admin_user = f"{uuid.uuid4()}"
await new_user(session=session, i=0, user_id=admin_user)
## Create normal user
normal_user = f"{uuid.uuid4()}"
await new_user(session=session, i=0, user_id=normal_user)
## Create team with 1 admin and 1 user
member_list = [
{"role": "admin", "user_id": admin_user},
{"role": "user", "user_id": normal_user},
]
team_data = await new_team(session=session, i=0, member_list=member_list)
## Create new normal user
new_normal_user = f"{uuid.uuid4()}"
await new_user(session=session, i=0, user_id=new_normal_user)
## Update member list
member_list = [
{"role": "admin", "user_id": admin_user},
{"role": "user", "user_id": new_normal_user},
]
team_data = await update_team(
session=session, i=0, member_list=member_list, team_id=team_data["team_id"]
)
@pytest.mark.asyncio
async def test_team_delete():
"""
- Create team
- Create key for team
- Check if key works
- Delete team
"""
async with aiohttp.ClientSession() as session:
## Create admin
admin_user = f"{uuid.uuid4()}"
await new_user(session=session, i=0, user_id=admin_user)
## Create normal user
normal_user = f"{uuid.uuid4()}"
await new_user(session=session, i=0, user_id=normal_user)
## Create team with 1 admin and 1 user
member_list = [
{"role": "admin", "user_id": admin_user},
{"role": "user", "user_id": normal_user},
]
team_data = await new_team(session=session, i=0, member_list=member_list)
## Create key
key_gen = await generate_key(session=session, i=0, team_id=team_data["team_id"])
key = key_gen["key"]
## Test key
response = await chat_completion(session=session, key=key)
## Delete team
await delete_team(session=session, i=0, team_id=team_data["team_id"])