feat(proxy_server.py): /key/delete endpoint

This commit is contained in:
Krrish Dholakia 2023-11-23 21:37:46 -08:00
parent 375e7e8e6e
commit 8030a9b8d1
2 changed files with 40 additions and 2 deletions

View file

@ -2,7 +2,7 @@ import sys, os, platform, time, copy, re, asyncio
import threading, ast import threading, ast
import shutil, random, traceback, requests import shutil, random, traceback, requests
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional from typing import Optional, List
import secrets, subprocess import secrets, subprocess
messages: list = [] messages: list = []
sys.path.insert( sys.path.insert(
@ -155,7 +155,7 @@ async def user_api_key_auth(request: Request):
if api_key == master_key: if api_key == master_key:
return return
if route == "/key/generate" and api_key != master_key: if (route == "/key/generate" or route == "/key/delete") and api_key != master_key:
raise Exception(f"If master key is set, only master key can be used to generate new keys") raise Exception(f"If master key is set, only master key can be used to generate new keys")
if prisma_client: if prisma_client:
@ -352,6 +352,19 @@ async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict,
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {"token": new_verification_token.token, "expires": new_verification_token.expires} return {"token": new_verification_token.token, "expires": new_verification_token.expires}
async def delete_verification_token(tokens: List[str]) -> int:
    """Delete the given verification tokens from the database.

    Issues a single ``delete_many`` against the ``litellm_verificationtoken``
    table through the module-level ``prisma_client``, matching every row
    whose ``token`` is in ``tokens``.

    Args:
        tokens: Token strings to delete.

    Returns:
        The value returned by ``delete_many``; the caller treats it as the
        number of rows deleted.

    Raises:
        HTTPException: 500 if the delete raises for any reason. This includes
            ``prisma_client`` not being usable (presumably ``None`` when no
            database is configured -- confirm against the startup code).
    """
    try:
        deleted_count = await prisma_client.litellm_verificationtoken.delete_many(
            where={"token": {"in": tokens}}
        )
    except Exception as e:
        # Print the full traceback server-side; the client only sees a bare 500.
        traceback.print_exc()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        ) from e
    return deleted_count
async def generate_key_cli_task(duration_str): async def generate_key_cli_task(duration_str):
task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str)) task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
await task await task
@ -652,6 +665,28 @@ async def generate_key_fn(request: Request):
detail={"error": "models param must be a list"}, detail={"error": "models param must be a list"},
) )
@router.post("/key/delete", dependencies=[Depends(user_api_key_auth)])
async def delete_key_fn(request: Request):
    """Delete one or more API keys.

    Expects a JSON body with a ``keys`` field holding either a single key
    string or a list of keys. Duplicate entries are collapsed before the
    delete so that a repeated key does not make the count check fail.

    Returns:
        ``{"deleted_keys": [...]}`` listing the keys that were deleted.

    Raises:
        HTTPException: 400 if the body is malformed, ``keys`` has the wrong
            type, or the number of deleted rows does not match the number of
            requested keys. An ``HTTPException`` raised by
            ``delete_verification_token`` (500 on database failure) is passed
            through unchanged.
    """
    # NOTE: this handler was previously also named ``generate_key_fn``, which
    # rebound the module-level name of the /key/generate handler.
    try:
        data = await request.json()
        keys = data.get("keys", [])
        if isinstance(keys, str):
            keys = [keys]
        elif not isinstance(keys, list):
            raise Exception("keys must be an instance of either a string or a list")

        # Collapse duplicates (order preserved); each row can only be deleted once.
        keys = list(dict.fromkeys(keys))

        deleted_keys = await delete_verification_token(tokens=keys)
        # Explicit check instead of ``assert``: asserts are stripped under -O
        # and produce an empty error message. The delete has already run by
        # this point, so a mismatch means some requested keys were not found.
        if deleted_keys != len(keys):
            raise Exception(
                f"requested deletion of {len(keys)} key(s) but {deleted_keys} were deleted"
            )
        return {"deleted_keys": keys}
    except HTTPException:
        # Keep the status code chosen further down (e.g. 500 on DB failure).
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": str(e)},
        ) from e
@router.get("/test") @router.get("/test")
async def test_endpoint(request: Request): async def test_endpoint(request: Request):
return {"route": request.url.path} return {"route": request.url.path}

View file

@ -79,6 +79,7 @@ def test_sync_fallbacks():
litellm.set_verbose = True litellm.set_verbose = True
response = router.completion(**kwargs) response = router.completion(**kwargs)
print(f"response: {response}") print(f"response: {response}")
router.flush_cache()
except Exception as e: except Exception as e:
print(e) print(e)
test_sync_fallbacks() test_sync_fallbacks()
@ -96,6 +97,7 @@ def test_async_fallbacks():
pass pass
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
router.flush_cache()
asyncio.run(test_get_response()) asyncio.run(test_get_response())
@ -110,5 +112,6 @@ def test_sync_context_window_fallbacks():
print(f"response: {response}") print(f"response: {response}")
except Exception as e: except Exception as e:
print(e) print(e)
router.flush_cache()
# test_sync_context_window_fallbacks() # test_sync_context_window_fallbacks()