allow passing expiry time to /key/regenerate

This commit is contained in:
Ishaan Jaff 2024-09-06 08:36:34 -07:00
parent b5349e97c7
commit aed59abe35
4 changed files with 342 additions and 160 deletions

View file

@ -280,6 +280,52 @@ async def generate_key_fn(
)
async def prepare_key_update_data(data: UpdateKeyRequest, existing_key_row):
data_json: dict = data.dict(exclude_unset=True)
key = data_json.pop("key", None)
_metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"]
non_default_values = {}
for k, v in data_json.items():
if k in _metadata_fields:
continue
if v is not None and v not in ([], {}, 0):
non_default_values[k] = v
if "duration" in non_default_values:
duration = non_default_values.pop("duration")
duration_s = _duration_in_seconds(duration=duration)
expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
non_default_values["expires"] = expires
if "budget_duration" in non_default_values:
duration_s = _duration_in_seconds(
duration=non_default_values["budget_duration"]
)
key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
non_default_values["budget_reset_at"] = key_reset_at
_metadata = existing_key_row.metadata or {}
if data.model_tpm_limit:
if "model_tpm_limit" not in _metadata:
_metadata["model_tpm_limit"] = {}
_metadata["model_tpm_limit"].update(data.model_tpm_limit)
non_default_values["metadata"] = _metadata
if data.model_rpm_limit:
if "model_rpm_limit" not in _metadata:
_metadata["model_rpm_limit"] = {}
_metadata["model_rpm_limit"].update(data.model_rpm_limit)
non_default_values["metadata"] = _metadata
if data.guardrails:
_metadata["guardrails"] = data.guardrails
non_default_values["metadata"] = _metadata
return non_default_values
@router.post(
"/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
)
@ -323,59 +369,9 @@ async def update_key_fn(
detail={"error": f"Team not found, passed team_id={data.team_id}"},
)
_metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"]
# get non default values for key
non_default_values = {}
for k, v in data_json.items():
# this field gets stored in metadata
if key in _metadata_fields:
continue
if v is not None and v not in (
[],
{},
0,
): # models default to [], spend defaults to 0, we should not reset these values
non_default_values[k] = v
if "duration" in non_default_values:
duration = non_default_values.pop("duration")
duration_s = _duration_in_seconds(duration=duration)
expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
non_default_values["expires"] = expires
if "budget_duration" in non_default_values:
duration_s = _duration_in_seconds(
duration=non_default_values["budget_duration"]
)
key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
non_default_values["budget_reset_at"] = key_reset_at
# Update metadata for virtual Key
if data.model_tpm_limit:
_metadata = existing_key_row.metadata or {}
if "model_tpm_limit" not in _metadata:
_metadata["model_tpm_limit"] = {}
_metadata["model_tpm_limit"].update(data.model_tpm_limit)
non_default_values["metadata"] = _metadata
non_default_values.pop("model_tpm_limit", None)
if data.model_rpm_limit:
_metadata = existing_key_row.metadata or {}
if "model_rpm_limit" not in _metadata:
_metadata["model_rpm_limit"] = {}
_metadata["model_rpm_limit"].update(data.model_rpm_limit)
non_default_values["metadata"] = _metadata
non_default_values.pop("model_rpm_limit", None)
if data.guardrails:
_metadata = existing_key_row.metadata or {}
_metadata["guardrails"] = data.guardrails
# update values that will be written to the DB
non_default_values["metadata"] = _metadata
non_default_values.pop("guardrails", None)
non_default_values = await prepare_key_update_data(
data=data, existing_key_row=existing_key_row
)
response = await prisma_client.update_data(
token=key, data={**non_default_values, "token": key}
@ -983,6 +979,7 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
@management_endpoint_wrapper
async def regenerate_key_fn(
key: str,
data: Optional[RegenerateKeyRequest] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
litellm_changed_by: Optional[str] = Header(
None,
@ -1041,14 +1038,26 @@ async def regenerate_key_fn(
new_token_hash = hash_token(new_token)
new_token_key_name = f"sk-...{new_token[-4:]}"
# update new token in DB
# Prepare the update data
update_data = {
"token": new_token_hash,
"key_name": new_token_key_name,
}
non_default_values = {}
if data is not None:
# Update with any provided parameters from GenerateKeyRequest
non_default_values = await prepare_key_update_data(
data=data, existing_key_row=_key_in_db
)
update_data.update(non_default_values)
# Update the token in the database
updated_token = await prisma_client.db.litellm_verificationtoken.update(
where={"token": hashed_api_key},
data={
"token": new_token_hash,
"key_name": new_token_key_name,
},
data=update_data,
)
updated_token_dict = {}
if updated_token is not None:
updated_token_dict = dict(updated_token)