diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index bbe20d047..f0f384094 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -224,14 +224,33 @@ class UpdateUserRequest(GenerateRequestBase): max_budget: Optional[float] = None +class Member(LiteLLMBase): + role: Literal["admin", "user"] + user_id: str + + class NewTeamRequest(LiteLLMBase): team_alias: Optional[str] = None team_id: Optional[str] = None admins: list = [] members: list = [] + members_with_roles: List[Member] = [] metadata: Optional[dict] = None +class UpdateTeamRequest(LiteLLMBase): + team_id: str # required + team_alias: Optional[str] = None + admins: Optional[list] = None + members: Optional[list] = None + members_with_roles: Optional[List[Member]] = None + metadata: Optional[dict] = None + + +class DeleteTeamRequest(LiteLLMBase): + team_ids: List[str] # required + + class LiteLLM_TeamTable(NewTeamRequest): max_budget: Optional[float] = None spend: Optional[float] = None diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index f086ebc70..d28248a9f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -4060,9 +4060,30 @@ async def user_info( else: user_info = None ## GET ALL TEAMS ## - teams = await prisma_client.get_data( + team_list = [] + team_id_list = [] + # _DEPRECATED_ check if user in 'member' field + teams_1 = await prisma_client.get_data( user_id=user_id, table_name="team", query_type="find_all" ) + + if teams_1 is not None and isinstance(teams_1, list): + team_list = teams_1 + for team in teams_1: + team_id_list.append(team.team_id) + + if user_info is not None: + # *NEW* get all teams in user 'teams' field + teams_2 = await prisma_client.get_data( + team_id_list=user_info.teams, table_name="team", query_type="find_all" + ) + + if teams_2 is not None and isinstance(teams_2, list): + for team in teams_2: + if team.team_id not in team_id_list: + team_list.append(team) + team_id_list.append(team.team_id) + ## GET ALL KEYS ## keys = await prisma_client.get_data( user_id=user_id, @@ -4090,9 +4111,10 @@ async def user_info( "user_id": user_id, "user_info": user_info, "keys": keys, - "teams": teams, + "teams": team_list, } except Exception as e: + traceback.print_exc() if isinstance(e, HTTPException): raise ProxyException( message=getattr(e, "detail", f"Authentication Error({str(e)})"), @@ -4274,12 +4296,31 @@ async def new_team( Parameters: - team_alias: Optional[str] - User defined team alias - team_id: Optional[str] - The team id of the user. If none passed, we'll generate it. - - admins: list - A list of user IDs that will be owning the team - - members: list - A list of user IDs that will be members of the team + - members_with_roles: list - A list of dictionaries, mapping user_id to role in team (either 'admin' or 'user') - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } Returns: - team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id. + + _deprecated_params: + - admins: list - A list of user_id's for the admin role + - users: list - A list of user_id's for the user role + + Example Request: + ``` + curl --location 'http://0.0.0.0:8000/team/new' \ + + --header 'Authorization: Bearer sk-1234' \ + + --header 'Content-Type: application/json' \ + + --data '{ + "team_alias": "my-new-team_2", + "members_with_roles": [{"role": "admin", "user_id": "user-1234"}, + {"role": "user", "user_id": "user-2434"}] + }' + + ``` """ global prisma_client @@ -4303,27 +4344,124 @@ async def new_team( team_row = await prisma_client.insert_data( data=complete_team_data.json(exclude_none=True), table_name="team" ) + + ## ADD TEAM ID TO USER TABLE ## + for user in complete_team_data.members_with_roles: + ## add team id to user row ## + await prisma_client.update_data( + user_id=user.user_id, + data={"user_id": user.user_id, "teams": [team_row.team_id]}, + update_key_values={ + "teams": { + "push ": [team_row.team_id], + } + }, + ) return team_row @router.post( "/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)] ) -async def update_team(): +async def update_team( + data: UpdateTeamRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): """ - update team and members + add new members to the team """ - pass + global prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_id is None: + raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) + + existing_team_row = await prisma_client.get_data( + team_id=data.team_id, table_name="team", query_type="find_unique" + ) + + updated_kv = data.json(exclude_none=True) + team_row = await prisma_client.update_data( + update_key_values=updated_kv, + data=updated_kv, + table_name="team", + team_id=data.team_id, + ) + + ## ADD NEW USERS ## + existing_user_id_list = [] + ## Get new users + for user in existing_team_row.members_with_roles: + existing_user_id_list.append(user["user_id"]) + + ## Update new user rows with team id (info used by /user/info to show all teams, user is a part of) + if data.members_with_roles is not None: + for user in data.members_with_roles: + if user.user_id not in existing_user_id_list: + await prisma_client.update_data( + user_id=user.user_id, + data={"user_id": user.user_id, "teams": [team_row["team_id"]]}, + update_key_values={ + "teams": { + "push": [team_row["team_id"]], + } + }, + ) + + ## REMOVE DELETED USERS ## + ### Get list of deleted users (old list - new list) + deleted_user_id_list = [] + existing_user_id_list = [] + ## Get old user list + for user in existing_team_row.members_with_roles: + existing_user_id_list.append(user["user_id"]) + ## Get diff + if data.members_with_roles is not None: + for user in data.members_with_roles: + if user.user_id not in existing_user_id_list: + deleted_user_id_list.append(user.user_id) + + ## SET UPDATED LIST + if len(deleted_user_id_list) > 0: + # get the deleted users + existing_user_rows = await prisma_client.get_data( + user_id_list=deleted_user_id_list, table_name="user", query_type="find_all" + ) + for user in existing_user_rows: + if data.team_id in user["teams"]: + user["teams"].remove(data.team_id) + await prisma_client.update_data( + user_id=user["user_id"], + data=user, + update_key_values={"user_id": user["user_id"], "teams": user["teams"]}, + ) + return team_row @router.post( "/team/delete", tags=["team management"], dependencies=[Depends(user_api_key_auth)] ) -async def delete_team(): +async def delete_team( + data: DeleteTeamRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): """ - delete team and team keys + delete team and associated team keys """ - pass + global prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_ids is None: + raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) + + ## DELETE ASSOCIATED KEYS + await prisma_client.delete_data(team_id_list=data.team_ids, table_name="key") + ## DELETE TEAMS + await prisma_client.delete_data(team_id_list=data.team_ids, table_name="team") @router.get( diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 5377fe90b..7e8e26fcd 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -13,6 +13,7 @@ model LiteLLM_TeamTable { team_alias String? admins String[] members String[] + members_with_roles Json @default("{}") metadata Json @default("{}") max_budget Float? spend Float @default(0.0) @@ -32,6 +33,7 @@ model LiteLLM_TeamTable { model LiteLLM_UserTable { user_id String @unique team_id String? + teams String[] @default([]) user_role String? max_budget Float? spend Float @default(0.0) @@ -103,5 +105,5 @@ model LiteLLM_UserNotifications { user_id String models String[] justification String - status String // approved, disapproved, pending -} \ No newline at end of file + status String // approved, disapproved, pending +} diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 34a77a973..4e08f8fb9 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -532,6 +532,7 @@ class PrismaClient: user_id: Optional[str] = None, user_id_list: Optional[list] = None, team_id: Optional[str] = None, + team_id_list: Optional[list] = None, key_val: Optional[dict] = None, table_name: Optional[ Literal["user", "key", "config", "spend", "team", "user_notification"] @@ -641,13 +642,12 @@ class PrismaClient: } ) elif query_type == "find_all" and user_id_list is not None: - user_id_values = str(tuple(user_id_list)) + user_id_values = ", ".join(f"'{item}'" for item in user_id_list) sql_query = f""" SELECT * FROM "LiteLLM_UserTable" - WHERE "user_id" IN {user_id_values} + WHERE "user_id" IN ({user_id_values}) """ - # Execute the raw query # The asterisk before `user_id_list` unpacks the list into separate arguments response = await self.db.query_raw(sql_query) @@ -697,7 +697,13 @@ class PrismaClient: ) elif query_type == "find_all" and user_id is not None: response = await self.db.litellm_teamtable.find_many( - where={"members": {"has": user_id}} + where={ + "members": {"has": user_id}, + }, + ) + elif query_type == "find_all" and team_id_list is not None: + response = await self.db.litellm_teamtable.find_many( + where={"team_id": {"in": team_id_list}} ) return response elif table_name == "user_notification": @@ -769,6 +775,12 @@ class PrismaClient: return new_user_row elif table_name == "team": db_data = self.jsonify_object(data=data) + if db_data.get("members_with_roles", None) is not None and isinstance( + db_data["members_with_roles"], list + ): + db_data["members_with_roles"] = json.dumps( + db_data["members_with_roles"] + ) new_team_row = await self.db.litellm_teamtable.upsert( where={"team_id": data["team_id"]}, data={ @@ -915,6 +927,19 @@ class PrismaClient: update_key_values = db_data if "team_id" not in db_data and team_id is not None: db_data["team_id"] = team_id + if "members_with_roles" in db_data and isinstance( + db_data["members_with_roles"], list + ): + db_data["members_with_roles"] = json.dumps( + db_data["members_with_roles"] + ) + if "members_with_roles" in update_key_values and isinstance( + update_key_values["members_with_roles"], list + ): + update_key_values["members_with_roles"] = json.dumps( + update_key_values["members_with_roles"] + ) + update_team_row = await self.db.litellm_teamtable.upsert( where={"team_id": team_id}, # type: ignore data={ @@ -929,7 +954,7 @@ class PrismaClient: + f"DB Team Table - update succeeded {update_team_row}" + "\033[0m" ) - return {"team_id": team_id, "data": db_data} + return {"team_id": team_id, "data": update_team_row} elif ( table_name is not None and table_name == "key" @@ -1001,22 +1026,45 @@ class PrismaClient: max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) - async def delete_data(self, tokens: List): + async def delete_data( + self, + tokens: Optional[List] = None, + team_id_list: Optional[List] = None, + table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None, + ): """ Allow user to delete a key(s) """ try: - hashed_tokens = [] - for token in tokens: - if isinstance(token, str) and token.startswith("sk-"): - hashed_token = self.hash_token(token=token) - else: - hashed_token = token - hashed_tokens.append(hashed_token) - await self.db.litellm_verificationtoken.delete_many( - where={"token": {"in": hashed_tokens}} - ) - return {"deleted_keys": tokens} + if tokens is not None and isinstance(tokens, List): + hashed_tokens = [] + for token in tokens: + if isinstance(token, str) and token.startswith("sk-"): + hashed_token = self.hash_token(token=token) + else: + hashed_token = token + hashed_tokens.append(hashed_token) + await self.db.litellm_verificationtoken.delete_many( + where={"token": {"in": hashed_tokens}} + ) + return {"deleted_keys": tokens} + elif ( + table_name == "team" + and team_id_list is not None + and isinstance(team_id_list, List) + ): + await self.db.litellm_teamtable.delete_many( + where={"team_id": {"in": team_id_list}} + ) + return {"deleted_teams": team_id_list} + elif ( + table_name == "key" + and team_id_list is not None + and isinstance(team_id_list, List) + ): + await self.db.litellm_verificationtoken.delete_many( + where={"team_id": {"in": team_id_list}} + ) except Exception as e: asyncio.create_task( self.proxy_logging_obj.failure_handler(original_exception=e) diff --git a/schema.prisma b/schema.prisma index 8663db1b0..7e8e26fcd 100644 --- a/schema.prisma +++ b/schema.prisma @@ -13,6 +13,7 @@ model LiteLLM_TeamTable { team_alias String? admins String[] members String[] + members_with_roles Json @default("{}") metadata Json @default("{}") max_budget Float? spend Float @default(0.0) @@ -32,6 +33,7 @@ model LiteLLM_TeamTable { model LiteLLM_UserTable { user_id String @unique team_id String? + teams String[] @default([]) user_role String? max_budget Float? spend Float @default(0.0) @@ -104,4 +106,4 @@ model LiteLLM_UserNotifications { models String[] justification String status String // approved, disapproved, pending -} \ No newline at end of file +} diff --git a/tests/test_team.py b/tests/test_team.py index b7da1bf9e..9cb98a8c2 100644 --- a/tests/test_team.py +++ b/tests/test_team.py @@ -3,21 +3,176 @@ import pytest import asyncio import aiohttp -import time +import time, uuid from openai import AsyncOpenAI -async def new_team( +async def new_user(session, i, user_id=None, budget=None, budget_duration=None): + url = "http://0.0.0.0:4000/user/new" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "models": ["azure-models"], + "aliases": {"mistral-7b": "gpt-3.5-turbo"}, + "duration": None, + "max_budget": budget, + "budget_duration": budget_duration, + } + + if user_id is not None: + data["user_id"] = user_id + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(f"Response {i} (Status code: {status}):") + print(response_text) + print() + + if status != 200: + raise Exception(f"Request {i} did not return a 200 status code: {status}") + + return await response.json() + + +async def generate_key( session, i, + budget=None, + budget_duration=None, + models=["azure-models", "gpt-4", "dall-e-3"], + team_id=None, ): + url = "http://0.0.0.0:4000/key/generate" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "models": models, + "duration": None, + "max_budget": budget, + "budget_duration": budget_duration, + } + if team_id is not None: + data["team_id"] = team_id + + print(f"data: {data}") + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(f"Response {i} (Status code: {status}):") + print(response_text) + print() + + if status != 200: + raise Exception(f"Request {i} did not return a 200 status code: {status}") + + return await response.json() + + +async def chat_completion(session, key, model="gpt-4"): + url = "http://0.0.0.0:4000/chat/completions" + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + data = { + "model": model, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + ], + } + + for i in range(3): + try: + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception( + f"Request did not return a 200 status code: {status}. Response: {response_text}" + ) + + return await response.json() + except Exception as e: + if "Request did not return a 200 status code" in str(e): + raise e + else: + pass + + +async def new_team(session, i, user_id=None, member_list=None): url = "http://0.0.0.0:4000/team/new" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} data = { "team_alias": "my-new-team", - "admins": ["user-1234"], - "members": ["user-1234"], } + if user_id is not None: + data["members_with_roles"] = [{"role": "user", "user_id": user_id}] + elif member_list is not None: + data["members_with_roles"] = member_list + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(f"Response {i} (Status code: {status}):") + print(response_text) + print() + + if status != 200: + raise Exception(f"Request {i} did not return a 200 status code: {status}") + + return await response.json() + + +async def update_team( + session, + i, + team_id, + user_id=None, + member_list=None, +): + url = "http://0.0.0.0:4000/team/update" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "team_id": team_id, + } + if user_id is not None: + data["members_with_roles"] = [{"role": "user", "user_id": user_id}] + elif member_list is not None: + data["members_with_roles"] = member_list + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(f"Response {i} (Status code: {status}):") + print(response_text) + print() + + if status != 200: + raise Exception(f"Request {i} did not return a 200 status code: {status}") + + return await response.json() + + +async def delete_team( + session, + i, + team_id, +): + url = "http://0.0.0.0:4000/team/delete" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "team_ids": [team_id], + } + async with session.post(url, headers=headers, json=data) as response: status = response.status response_text = await response.text() @@ -37,8 +192,10 @@ async def test_team_new(): """ Make 20 parallel calls to /user/new. Assert all worked. """ + user_id = f"{uuid.uuid4()}" async with aiohttp.ClientSession() as session: - tasks = [new_team(session, i) for i in range(1, 11)] + new_user(session=session, i=0, user_id=user_id) + tasks = [new_team(session, i, user_id=user_id) for i in range(1, 11)] await asyncio.gather(*tasks) @@ -70,3 +227,66 @@ async def test_team_info(): team_id = new_team_data["team_id"] ## as admin ## await get_team_info(session=session, get_team=team_id, call_key="sk-1234") + + +@pytest.mark.asyncio +async def test_team_update(): + """ + - Create team with 1 admin, 1 user + - Create new user + - Replace existing user with new user in team + """ + async with aiohttp.ClientSession() as session: + ## Create admin + admin_user = f"{uuid.uuid4()}" + await new_user(session=session, i=0, user_id=admin_user) + ## Create normal user + normal_user = f"{uuid.uuid4()}" + await new_user(session=session, i=0, user_id=normal_user) + ## Create team with 1 admin and 1 user + member_list = [ + {"role": "admin", "user_id": admin_user}, + {"role": "user", "user_id": normal_user}, + ] + team_data = await new_team(session=session, i=0, member_list=member_list) + ## Create new normal user + new_normal_user = f"{uuid.uuid4()}" + await new_user(session=session, i=0, user_id=new_normal_user) + ## Update member list + member_list = [ + {"role": "admin", "user_id": admin_user}, + {"role": "user", "user_id": new_normal_user}, + ] + team_data = await update_team( + session=session, i=0, member_list=member_list, team_id=team_data["team_id"] + ) + + +@pytest.mark.asyncio +async def test_team_delete(): + """ + - Create team + - Create key for team + - Check if key works + - Delete team + """ + async with aiohttp.ClientSession() as session: + ## Create admin + admin_user = f"{uuid.uuid4()}" + await new_user(session=session, i=0, user_id=admin_user) + ## Create normal user + normal_user = f"{uuid.uuid4()}" + await new_user(session=session, i=0, user_id=normal_user) + ## Create team with 1 admin and 1 user + member_list = [ + {"role": "admin", "user_id": admin_user}, + {"role": "user", "user_id": normal_user}, + ] + team_data = await new_team(session=session, i=0, member_list=member_list) + ## Create key + key_gen = await generate_key(session=session, i=0, team_id=team_data["team_id"]) + key = key_gen["key"] + ## Test key + response = await chat_completion(session=session, key=key) + ## Delete team + await delete_team(session=session, i=0, team_id=team_data["team_id"])