From 9a444338448e79c134d69acef73c7bd324bd04bf Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 23 Nov 2023 21:37:46 -0800 Subject: [PATCH] feat(proxy_server.py): /key/delete endpoint --- litellm/proxy/proxy_server.py | 39 ++++++++++++++++++++++++-- litellm/tests/test_router_fallbacks.py | 3 ++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 17f703fe69..c9193d7504 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2,7 +2,7 @@ import sys, os, platform, time, copy, re, asyncio import threading, ast import shutil, random, traceback, requests from datetime import datetime, timedelta -from typing import Optional +from typing import Optional, List import secrets, subprocess messages: list = [] sys.path.insert( @@ -155,7 +155,7 @@ async def user_api_key_auth(request: Request): if api_key == master_key: return - if route == "/key/generate" and api_key != master_key: + if (route == "/key/generate" or route == "/key/delete") and api_key != master_key: raise Exception(f"If master key is set, only master key can be used to generate new keys") if prisma_client: @@ -352,6 +352,19 @@ async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) return {"token": new_verification_token.token, "expires": new_verification_token.expires} +async def delete_verification_token(tokens: List[str]): + global prisma_client + try: + # Assuming 'db' is your Prisma Client instance + deleted_tokens = await prisma_client.litellm_verificationtoken.delete_many( + where={"token": {"in": tokens}} + ) + except Exception as e: + traceback.print_exc() + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + return deleted_tokens + + async def generate_key_cli_task(duration_str): task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str)) await task @@ -652,6 +665,28 @@ async def generate_key_fn(request: Request): detail={"error": "models param must be a list"}, ) +@router.post("/key/delete", dependencies=[Depends(user_api_key_auth)]) +async def generate_key_fn(request: Request): + try: + data = await request.json() + + keys = data.get("keys", []) + + if not isinstance(keys, list): + if isinstance(keys, str): + keys = [keys] + else: + raise Exception(f"keys must be an instance of either a string or a list") + + deleted_keys = await delete_verification_token(tokens=keys) + assert len(keys) == deleted_keys + return {"deleted_keys": keys} + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": str(e)}, + ) + @router.get("/test") async def test_endpoint(request: Request): return {"route": request.url.path} diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py index abb7e5eb87..f50f8f7330 100644 --- a/litellm/tests/test_router_fallbacks.py +++ b/litellm/tests/test_router_fallbacks.py @@ -79,6 +79,7 @@ def test_sync_fallbacks(): litellm.set_verbose = True response = router.completion(**kwargs) print(f"response: {response}") + router.flush_cache() except Exception as e: print(e) test_sync_fallbacks() @@ -96,6 +97,7 @@ def test_async_fallbacks(): pass except Exception as e: pytest.fail(f"An exception occurred: {e}") + router.flush_cache() asyncio.run(test_get_response()) @@ -110,5 +112,6 @@ def test_sync_context_window_fallbacks(): print(f"response: {response}") except Exception as e: print(e) + router.flush_cache() # test_sync_context_window_fallbacks() \ No newline at end of file