forked from phoenix/litellm-mirror
feat(proxy_server.py): /key/delete endpoint
This commit is contained in:
parent
0c210cc96c
commit
9a44433844
2 changed files with 40 additions and 2 deletions
|
@ -2,7 +2,7 @@ import sys, os, platform, time, copy, re, asyncio
|
||||||
import threading, ast
|
import threading, ast
|
||||||
import shutil, random, traceback, requests
|
import shutil, random, traceback, requests
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Optional
|
from typing import Optional, List
|
||||||
import secrets, subprocess
|
import secrets, subprocess
|
||||||
messages: list = []
|
messages: list = []
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
|
@ -155,7 +155,7 @@ async def user_api_key_auth(request: Request):
|
||||||
if api_key == master_key:
|
if api_key == master_key:
|
||||||
return
|
return
|
||||||
|
|
||||||
if route == "/key/generate" and api_key != master_key:
|
if (route == "/key/generate" or route == "/key/delete") and api_key != master_key:
|
||||||
raise Exception(f"If master key is set, only master key can be used to generate new keys")
|
raise Exception(f"If master key is set, only master key can be used to generate new keys")
|
||||||
|
|
||||||
if prisma_client:
|
if prisma_client:
|
||||||
|
@ -352,6 +352,19 @@ async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict,
|
||||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||||
return {"token": new_verification_token.token, "expires": new_verification_token.expires}
|
return {"token": new_verification_token.token, "expires": new_verification_token.expires}
|
||||||
|
|
||||||
|
async def delete_verification_token(tokens: List[str]):
|
||||||
|
global prisma_client
|
||||||
|
try:
|
||||||
|
# Assuming 'db' is your Prisma Client instance
|
||||||
|
deleted_tokens = await prisma_client.litellm_verificationtoken.delete_many(
|
||||||
|
where={"token": {"in": tokens}}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||||
|
return deleted_tokens
|
||||||
|
|
||||||
|
|
||||||
async def generate_key_cli_task(duration_str):
|
async def generate_key_cli_task(duration_str):
|
||||||
task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
|
task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
|
||||||
await task
|
await task
|
||||||
|
@ -652,6 +665,28 @@ async def generate_key_fn(request: Request):
|
||||||
detail={"error": "models param must be a list"},
|
detail={"error": "models param must be a list"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@router.post("/key/delete", dependencies=[Depends(user_api_key_auth)])
|
||||||
|
async def generate_key_fn(request: Request):
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
|
||||||
|
keys = data.get("keys", [])
|
||||||
|
|
||||||
|
if not isinstance(keys, list):
|
||||||
|
if isinstance(keys, str):
|
||||||
|
keys = [keys]
|
||||||
|
else:
|
||||||
|
raise Exception(f"keys must be an instance of either a string or a list")
|
||||||
|
|
||||||
|
deleted_keys = await delete_verification_token(tokens=keys)
|
||||||
|
assert len(keys) == deleted_keys
|
||||||
|
return {"deleted_keys": keys}
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail={"error": str(e)},
|
||||||
|
)
|
||||||
|
|
||||||
@router.get("/test")
|
@router.get("/test")
|
||||||
async def test_endpoint(request: Request):
|
async def test_endpoint(request: Request):
|
||||||
return {"route": request.url.path}
|
return {"route": request.url.path}
|
||||||
|
|
|
@ -79,6 +79,7 @@ def test_sync_fallbacks():
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
response = router.completion(**kwargs)
|
response = router.completion(**kwargs)
|
||||||
print(f"response: {response}")
|
print(f"response: {response}")
|
||||||
|
router.flush_cache()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
test_sync_fallbacks()
|
test_sync_fallbacks()
|
||||||
|
@ -96,6 +97,7 @@ def test_async_fallbacks():
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"An exception occurred: {e}")
|
pytest.fail(f"An exception occurred: {e}")
|
||||||
|
router.flush_cache()
|
||||||
|
|
||||||
asyncio.run(test_get_response())
|
asyncio.run(test_get_response())
|
||||||
|
|
||||||
|
@ -110,5 +112,6 @@ def test_sync_context_window_fallbacks():
|
||||||
print(f"response: {response}")
|
print(f"response: {response}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
router.flush_cache()
|
||||||
|
|
||||||
# test_sync_context_window_fallbacks()
|
# test_sync_context_window_fallbacks()
|
Loading…
Add table
Add a link
Reference in a new issue