feat(proxy_server.py): /key/delete endpoint

This commit is contained in:
Krrish Dholakia 2023-11-23 21:37:46 -08:00
parent 0c210cc96c
commit 9a44433844
2 changed files with 40 additions and 2 deletions

View file

@ -2,7 +2,7 @@ import sys, os, platform, time, copy, re, asyncio
import threading, ast
import shutil, random, traceback, requests
from datetime import datetime, timedelta
from typing import Optional
from typing import Optional, List
import secrets, subprocess
messages: list = []
sys.path.insert(
@ -155,7 +155,7 @@ async def user_api_key_auth(request: Request):
if api_key == master_key:
return
if route == "/key/generate" and api_key != master_key:
if (route == "/key/generate" or route == "/key/delete") and api_key != master_key:
raise Exception(f"If master key is set, only master key can be used to generate new keys")
if prisma_client:
@ -352,6 +352,19 @@ async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict,
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {"token": new_verification_token.token, "expires": new_verification_token.expires}
async def delete_verification_token(tokens: List[str]):
global prisma_client
try:
# Assuming 'db' is your Prisma Client instance
deleted_tokens = await prisma_client.litellm_verificationtoken.delete_many(
where={"token": {"in": tokens}}
)
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return deleted_tokens
async def generate_key_cli_task(duration_str):
task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
await task
@ -652,6 +665,28 @@ async def generate_key_fn(request: Request):
detail={"error": "models param must be a list"},
)
@router.post("/key/delete", dependencies=[Depends(user_api_key_auth)])
async def generate_key_fn(request: Request):
try:
data = await request.json()
keys = data.get("keys", [])
if not isinstance(keys, list):
if isinstance(keys, str):
keys = [keys]
else:
raise Exception(f"keys must be an instance of either a string or a list")
deleted_keys = await delete_verification_token(tokens=keys)
assert len(keys) == deleted_keys
return {"deleted_keys": keys}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": str(e)},
)
@router.get("/test")
async def test_endpoint(request: Request):
return {"route": request.url.path}

View file

@ -79,6 +79,7 @@ def test_sync_fallbacks():
litellm.set_verbose = True
response = router.completion(**kwargs)
print(f"response: {response}")
router.flush_cache()
except Exception as e:
print(e)
test_sync_fallbacks()
@ -96,6 +97,7 @@ def test_async_fallbacks():
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
router.flush_cache()
asyncio.run(test_get_response())
@ -110,5 +112,6 @@ def test_sync_context_window_fallbacks():
print(f"response: {response}")
except Exception as e:
print(e)
router.flush_cache()
# test_sync_context_window_fallbacks()