feat(proxy_server.py): add new /key/update endpoint

This commit is contained in:
Krrish Dholakia 2023-12-12 17:18:51 -08:00
parent 23a4ac724b
commit 693292a64c
3 changed files with 74 additions and 3 deletions

View file

@ -121,11 +121,24 @@ class GenerateKeyRequest(LiteLLMBase):
user_id: Optional[str] = None user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None max_parallel_requests: Optional[int] = None
class UpdateKeyRequest(LiteLLMBase):
key: str
duration: Optional[str] = None
models: Optional[list] = None
aliases: Optional[dict] = None
config: Optional[dict] = None
spend: Optional[float] = None
user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None
class GenerateKeyResponse(LiteLLMBase): class GenerateKeyResponse(LiteLLMBase):
key: str key: str
expires: datetime expires: datetime
user_id: str user_id: str
class _DeleteKeyObject(LiteLLMBase): class _DeleteKeyObject(LiteLLMBase):
key: str key: str

View file

@ -252,15 +252,15 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
if api_key is None: # only require api key if master key is set if api_key is None: # only require api key if master key is set
raise Exception(f"No api key passed in.") raise Exception(f"No api key passed in.")
route = request.url.path route: str = request.url.path
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
is_master_key_valid = secrets.compare_digest(api_key, master_key) is_master_key_valid = secrets.compare_digest(api_key, master_key)
if is_master_key_valid: if is_master_key_valid:
return UserAPIKeyAuth(api_key=master_key) return UserAPIKeyAuth(api_key=master_key)
if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid: if route.startswith("/key/") and not is_master_key_valid:
raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys") raise Exception(f"If master key is set, only master key can be used to generate, delete, update or get info for new keys")
if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
raise Exception("No connected db.") raise Exception("No connected db.")
@ -676,6 +676,8 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {"token": token, "expires": new_verification_token.expires, "user_id": user_id} return {"token": token, "expires": new_verification_token.expires, "user_id": user_id}
async def delete_verification_token(tokens: List): async def delete_verification_token(tokens: List):
global prisma_client global prisma_client
try: try:
@ -1140,6 +1142,30 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
response = await generate_key_helper_fn(**data_json) response = await generate_key_helper_fn(**data_json)
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"]) return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
@router.post("/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
async def update_key_fn(request: Request, data: UpdateKeyRequest):
"""
Update an existing key
"""
global prisma_client
try:
data_json: dict = data.json()
key = data_json.pop("key")
# get the row from db
if prisma_client is None:
raise Exception("Not connected to DB!")
non_default_values = {k: v for k, v in data_json.items() if v is not None}
print(f"non_default_values: {non_default_values}")
response = await prisma_client.update_data(token=key, data={**non_default_values, "token": key})
return {"key": key, **non_default_values}
# update based on remaining passed in values
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": str(e)},
)
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)]) @router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
async def delete_key_fn(request: Request, data: DeleteKeyRequest): async def delete_key_fn(request: Request, data: DeleteKeyRequest):
try: try:

View file

@ -71,6 +71,38 @@ def test_add_new_key(client):
except Exception as e: except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}") pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
def test_update_new_key(client):
try:
# Your test data
test_data = {
"models": ["gpt-3.5-turbo", "gpt-4", "claude-2", "azure-model"],
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
"duration": "20m"
}
print("testing proxy server")
# Your bearer token
token = os.getenv("PROXY_MASTER_KEY")
headers = {
"Authorization": f"Bearer {token}"
}
response = client.post("/key/generate", json=test_data, headers=headers)
print(f"response: {response.text}")
assert response.status_code == 200
result = response.json()
assert result["key"].startswith("sk-")
def _post_data():
json_data = {'models': ['bedrock-models'], "key": result["key"]}
response = client.post("/key/update", json=json_data, headers=headers)
print(f"response text: {response.text}")
assert response.status_code == 200
return response
_post_data()
print(f"Received response: {result}")
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
# # Run the test - only runs via pytest # # Run the test - only runs via pytest