From a7229c9253478e8d2399da2f6cdfea285ce820b7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 21 Feb 2024 16:53:12 -0800 Subject: [PATCH] fix(proxy_server.py): enable proxy /team/delete endpoint --- litellm/proxy/_types.py | 4 + litellm/proxy/proxy_server.py | 20 ++- litellm/proxy/utils.py | 47 +++++-- tests/test_team.py | 230 +++++++++++++++++++++++++++++++++- 4 files changed, 281 insertions(+), 20 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index a96f410a7..f0f384094 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -247,6 +247,10 @@ class UpdateTeamRequest(LiteLLMBase): metadata: Optional[dict] = None +class DeleteTeamRequest(LiteLLMBase): + team_ids: List[str] # required + + class LiteLLM_TeamTable(NewTeamRequest): max_budget: Optional[float] = None spend: Optional[float] = None diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index dcaa8be6a..c81dec20a 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -4443,11 +4443,25 @@ async def update_team( @router.post( "/team/delete", tags=["team management"], dependencies=[Depends(user_api_key_auth)] ) -async def delete_team(): +async def delete_team( + data: DeleteTeamRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): """ - delete team and team keys + delete team and associated team keys """ - pass + global prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_ids is None: + raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) + + ## DELETE ASSOCIATED KEYS + await prisma_client.delete_data(team_id_list=data.team_ids, table_name="key") + ## DELETE TEAMS + await prisma_client.delete_data(team_id_list=data.team_ids, table_name="team") @router.get( diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 6de5fc6e3..9edf87be4 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1026,22 +1026,45 @@ class PrismaClient: max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) - async def delete_data(self, tokens: List): + async def delete_data( + self, + tokens: Optional[List] = None, + team_id_list: Optional[List] = None, + table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None, + ): """ Allow user to delete a key(s) """ try: - hashed_tokens = [] - for token in tokens: - if isinstance(token, str) and token.startswith("sk-"): - hashed_token = self.hash_token(token=token) - else: - hashed_token = token - hashed_tokens.append(hashed_token) - await self.db.litellm_verificationtoken.delete_many( - where={"token": {"in": hashed_tokens}} - ) - return {"deleted_keys": tokens} + if tokens is not None and isinstance(tokens, List): + hashed_tokens = [] + for token in tokens: + if isinstance(token, str) and token.startswith("sk-"): + hashed_token = self.hash_token(token=token) + else: + hashed_token = token + hashed_tokens.append(hashed_token) + await self.db.litellm_verificationtoken.delete_many( + where={"token": {"in": hashed_tokens}} + ) + return {"deleted_keys": tokens} + elif ( + table_name == "team" + and team_id_list is not None + and isinstance(team_id_list, List) + ): + await self.db.litellm_teamtable.delete_many( + where={"team_id": {"in": team_id_list}} + ) + return {"deleted_teams": team_id_list} + elif ( + table_name == "key" + and team_id_list is not None + and isinstance(team_id_list, List) + ): + await self.db.litellm_verificationtoken.delete_many( + where={"team_id": {"in": team_id_list}} + ) except Exception as e: asyncio.create_task( self.proxy_logging_obj.failure_handler(original_exception=e) diff --git a/tests/test_team.py b/tests/test_team.py index b7da1bf9e..9cb98a8c2 100644 --- a/tests/test_team.py +++ b/tests/test_team.py @@ -3,21 +3,176 @@ import pytest import asyncio import aiohttp -import time +import time, uuid from openai import AsyncOpenAI -async def new_team( +async def new_user(session, i, user_id=None, budget=None, budget_duration=None): + url = "http://0.0.0.0:4000/user/new" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "models": ["azure-models"], + "aliases": {"mistral-7b": "gpt-3.5-turbo"}, + "duration": None, + "max_budget": budget, + "budget_duration": budget_duration, + } + + if user_id is not None: + data["user_id"] = user_id + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(f"Response {i} (Status code: {status}):") + print(response_text) + print() + + if status != 200: + raise Exception(f"Request {i} did not return a 200 status code: {status}") + + return await response.json() + + +async def generate_key( session, i, + budget=None, + budget_duration=None, + models=["azure-models", "gpt-4", "dall-e-3"], + team_id=None, ): + url = "http://0.0.0.0:4000/key/generate" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "models": models, + "duration": None, + "max_budget": budget, + "budget_duration": budget_duration, + } + if team_id is not None: + data["team_id"] = team_id + + print(f"data: {data}") + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(f"Response {i} (Status code: {status}):") + print(response_text) + print() + + if status != 200: + raise Exception(f"Request {i} did not return a 200 status code: {status}") + + return await response.json() + + +async def chat_completion(session, key, model="gpt-4"): + url = "http://0.0.0.0:4000/chat/completions" + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + data = { + "model": model, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + ], + } + + for i in range(3): + try: + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception( + f"Request did not return a 200 status code: {status}. Response: {response_text}" + ) + + return await response.json() + except Exception as e: + if "Request did not return a 200 status code" in str(e): + raise e + else: + pass + + +async def new_team(session, i, user_id=None, member_list=None): url = "http://0.0.0.0:4000/team/new" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} data = { "team_alias": "my-new-team", - "admins": ["user-1234"], - "members": ["user-1234"], } + if user_id is not None: + data["members_with_roles"] = [{"role": "user", "user_id": user_id}] + elif member_list is not None: + data["members_with_roles"] = member_list + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(f"Response {i} (Status code: {status}):") + print(response_text) + print() + + if status != 200: + raise Exception(f"Request {i} did not return a 200 status code: {status}") + + return await response.json() + + +async def update_team( + session, + i, + team_id, + user_id=None, + member_list=None, +): + url = "http://0.0.0.0:4000/team/update" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "team_id": team_id, + } + if user_id is not None: + data["members_with_roles"] = [{"role": "user", "user_id": user_id}] + elif member_list is not None: + data["members_with_roles"] = member_list + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(f"Response {i} (Status code: {status}):") + print(response_text) + print() + + if status != 200: + raise Exception(f"Request {i} did not return a 200 status code: {status}") + + return await response.json() + + +async def delete_team( + session, + i, + team_id, +): + url = "http://0.0.0.0:4000/team/delete" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "team_ids": [team_id], + } + async with session.post(url, headers=headers, json=data) as response: status = response.status response_text = await response.text() @@ -37,8 +192,10 @@ async def test_team_new(): """ Make 20 parallel calls to /user/new. Assert all worked. """ + user_id = f"{uuid.uuid4()}" async with aiohttp.ClientSession() as session: - tasks = [new_team(session, i) for i in range(1, 11)] + new_user(session=session, i=0, user_id=user_id) + tasks = [new_team(session, i, user_id=user_id) for i in range(1, 11)] await asyncio.gather(*tasks) @@ -70,3 +227,66 @@ async def test_team_info(): team_id = new_team_data["team_id"] ## as admin ## await get_team_info(session=session, get_team=team_id, call_key="sk-1234") + + +@pytest.mark.asyncio +async def test_team_update(): + """ + - Create team with 1 admin, 1 user + - Create new user + - Replace existing user with new user in team + """ + async with aiohttp.ClientSession() as session: + ## Create admin + admin_user = f"{uuid.uuid4()}" + await new_user(session=session, i=0, user_id=admin_user) + ## Create normal user + normal_user = f"{uuid.uuid4()}" + await new_user(session=session, i=0, user_id=normal_user) + ## Create team with 1 admin and 1 user + member_list = [ + {"role": "admin", "user_id": admin_user}, + {"role": "user", "user_id": normal_user}, + ] + team_data = await new_team(session=session, i=0, member_list=member_list) + ## Create new normal user + new_normal_user = f"{uuid.uuid4()}" + await new_user(session=session, i=0, user_id=new_normal_user) + ## Update member list + member_list = [ + {"role": "admin", "user_id": admin_user}, + {"role": "user", "user_id": new_normal_user}, + ] + team_data = await update_team( + session=session, i=0, member_list=member_list, team_id=team_data["team_id"] + ) + + +@pytest.mark.asyncio +async def test_team_delete(): + """ + - Create team + - Create key for team + - Check if key works + - Delete team + """ + async with aiohttp.ClientSession() as session: + ## Create admin + admin_user = f"{uuid.uuid4()}" + await new_user(session=session, i=0, user_id=admin_user) + ## Create normal user + normal_user = f"{uuid.uuid4()}" + await new_user(session=session, i=0, user_id=normal_user) + ## Create team with 1 admin and 1 user + member_list = [ + {"role": "admin", "user_id": admin_user}, + {"role": "user", "user_id": normal_user}, + ] + team_data = await new_team(session=session, i=0, member_list=member_list) + ## Create key + key_gen = await generate_key(session=session, i=0, team_id=team_data["team_id"]) + key = key_gen["key"] + ## Test key + response = await chat_completion(session=session, key=key) + ## Delete team + await delete_team(session=session, i=0, team_id=team_data["team_id"])