feat(proxy_server.py): add new /key/update endpoint

This commit is contained in:
Krrish Dholakia 2023-12-12 17:18:51 -08:00
parent 23a4ac724b
commit 693292a64c
3 changed files with 74 additions and 3 deletions

View file

@ -121,11 +121,24 @@ class GenerateKeyRequest(LiteLLMBase):
user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None
class UpdateKeyRequest(LiteLLMBase):
key: str
duration: Optional[str] = None
models: Optional[list] = None
aliases: Optional[dict] = None
config: Optional[dict] = None
spend: Optional[float] = None
user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None
class GenerateKeyResponse(LiteLLMBase):
key: str
expires: datetime
user_id: str
class _DeleteKeyObject(LiteLLMBase):
key: str

View file

@ -252,15 +252,15 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
if api_key is None: # only require api key if master key is set
raise Exception(f"No api key passed in.")
route = request.url.path
route: str = request.url.path
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
is_master_key_valid = secrets.compare_digest(api_key, master_key)
if is_master_key_valid:
return UserAPIKeyAuth(api_key=master_key)
if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid:
raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys")
if route.startswith("/key/") and not is_master_key_valid:
raise Exception(f"If master key is set, only master key can be used to generate, delete, update or get info for new keys")
if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
raise Exception("No connected db.")
@ -676,6 +676,8 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {"token": token, "expires": new_verification_token.expires, "user_id": user_id}
async def delete_verification_token(tokens: List):
global prisma_client
try:
@ -1140,6 +1142,30 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
response = await generate_key_helper_fn(**data_json)
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
@router.post("/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
async def update_key_fn(request: Request, data: UpdateKeyRequest):
"""
Update an existing key
"""
global prisma_client
try:
data_json: dict = data.json()
key = data_json.pop("key")
# get the row from db
if prisma_client is None:
raise Exception("Not connected to DB!")
non_default_values = {k: v for k, v in data_json.items() if v is not None}
print(f"non_default_values: {non_default_values}")
response = await prisma_client.update_data(token=key, data={**non_default_values, "token": key})
return {"key": key, **non_default_values}
# update based on remaining passed in values
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": str(e)},
)
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
async def delete_key_fn(request: Request, data: DeleteKeyRequest):
try:

View file

@ -71,6 +71,38 @@ def test_add_new_key(client):
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
def test_update_new_key(client):
try:
# Your test data
test_data = {
"models": ["gpt-3.5-turbo", "gpt-4", "claude-2", "azure-model"],
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
"duration": "20m"
}
print("testing proxy server")
# Your bearer token
token = os.getenv("PROXY_MASTER_KEY")
headers = {
"Authorization": f"Bearer {token}"
}
response = client.post("/key/generate", json=test_data, headers=headers)
print(f"response: {response.text}")
assert response.status_code == 200
result = response.json()
assert result["key"].startswith("sk-")
def _post_data():
json_data = {'models': ['bedrock-models'], "key": result["key"]}
response = client.post("/key/update", json=json_data, headers=headers)
print(f"response text: {response.text}")
assert response.status_code == 200
return response
_post_data()
print(f"Received response: {result}")
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
# # Run the test - only runs via pytest