mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
allow passing expiry time to /key/regenerate
This commit is contained in:
parent
b5349e97c7
commit
aed59abe35
4 changed files with 342 additions and 160 deletions
|
@ -280,6 +280,52 @@ async def generate_key_fn(
|
|||
)
|
||||
|
||||
|
||||
async def prepare_key_update_data(data: UpdateKeyRequest, existing_key_row):
|
||||
data_json: dict = data.dict(exclude_unset=True)
|
||||
key = data_json.pop("key", None)
|
||||
|
||||
_metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"]
|
||||
non_default_values = {}
|
||||
for k, v in data_json.items():
|
||||
if k in _metadata_fields:
|
||||
continue
|
||||
if v is not None and v not in ([], {}, 0):
|
||||
non_default_values[k] = v
|
||||
|
||||
if "duration" in non_default_values:
|
||||
duration = non_default_values.pop("duration")
|
||||
duration_s = _duration_in_seconds(duration=duration)
|
||||
expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||
non_default_values["expires"] = expires
|
||||
|
||||
if "budget_duration" in non_default_values:
|
||||
duration_s = _duration_in_seconds(
|
||||
duration=non_default_values["budget_duration"]
|
||||
)
|
||||
key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||
non_default_values["budget_reset_at"] = key_reset_at
|
||||
|
||||
_metadata = existing_key_row.metadata or {}
|
||||
|
||||
if data.model_tpm_limit:
|
||||
if "model_tpm_limit" not in _metadata:
|
||||
_metadata["model_tpm_limit"] = {}
|
||||
_metadata["model_tpm_limit"].update(data.model_tpm_limit)
|
||||
non_default_values["metadata"] = _metadata
|
||||
|
||||
if data.model_rpm_limit:
|
||||
if "model_rpm_limit" not in _metadata:
|
||||
_metadata["model_rpm_limit"] = {}
|
||||
_metadata["model_rpm_limit"].update(data.model_rpm_limit)
|
||||
non_default_values["metadata"] = _metadata
|
||||
|
||||
if data.guardrails:
|
||||
_metadata["guardrails"] = data.guardrails
|
||||
non_default_values["metadata"] = _metadata
|
||||
|
||||
return non_default_values
|
||||
|
||||
|
||||
@router.post(
|
||||
"/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
|
||||
)
|
||||
|
@ -323,59 +369,9 @@ async def update_key_fn(
|
|||
detail={"error": f"Team not found, passed team_id={data.team_id}"},
|
||||
)
|
||||
|
||||
_metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"]
|
||||
# get non default values for key
|
||||
non_default_values = {}
|
||||
for k, v in data_json.items():
|
||||
# this field gets stored in metadata
|
||||
if key in _metadata_fields:
|
||||
continue
|
||||
if v is not None and v not in (
|
||||
[],
|
||||
{},
|
||||
0,
|
||||
): # models default to [], spend defaults to 0, we should not reset these values
|
||||
non_default_values[k] = v
|
||||
|
||||
if "duration" in non_default_values:
|
||||
duration = non_default_values.pop("duration")
|
||||
duration_s = _duration_in_seconds(duration=duration)
|
||||
expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||
non_default_values["expires"] = expires
|
||||
|
||||
if "budget_duration" in non_default_values:
|
||||
duration_s = _duration_in_seconds(
|
||||
duration=non_default_values["budget_duration"]
|
||||
)
|
||||
key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||
non_default_values["budget_reset_at"] = key_reset_at
|
||||
|
||||
# Update metadata for virtual Key
|
||||
if data.model_tpm_limit:
|
||||
_metadata = existing_key_row.metadata or {}
|
||||
if "model_tpm_limit" not in _metadata:
|
||||
_metadata["model_tpm_limit"] = {}
|
||||
|
||||
_metadata["model_tpm_limit"].update(data.model_tpm_limit)
|
||||
non_default_values["metadata"] = _metadata
|
||||
non_default_values.pop("model_tpm_limit", None)
|
||||
|
||||
if data.model_rpm_limit:
|
||||
_metadata = existing_key_row.metadata or {}
|
||||
if "model_rpm_limit" not in _metadata:
|
||||
_metadata["model_rpm_limit"] = {}
|
||||
|
||||
_metadata["model_rpm_limit"].update(data.model_rpm_limit)
|
||||
non_default_values["metadata"] = _metadata
|
||||
non_default_values.pop("model_rpm_limit", None)
|
||||
|
||||
if data.guardrails:
|
||||
_metadata = existing_key_row.metadata or {}
|
||||
_metadata["guardrails"] = data.guardrails
|
||||
|
||||
# update values that will be written to the DB
|
||||
non_default_values["metadata"] = _metadata
|
||||
non_default_values.pop("guardrails", None)
|
||||
non_default_values = await prepare_key_update_data(
|
||||
data=data, existing_key_row=existing_key_row
|
||||
)
|
||||
|
||||
response = await prisma_client.update_data(
|
||||
token=key, data={**non_default_values, "token": key}
|
||||
|
@ -983,6 +979,7 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
|
|||
@management_endpoint_wrapper
|
||||
async def regenerate_key_fn(
|
||||
key: str,
|
||||
data: Optional[RegenerateKeyRequest] = None,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
litellm_changed_by: Optional[str] = Header(
|
||||
None,
|
||||
|
@ -1041,14 +1038,26 @@ async def regenerate_key_fn(
|
|||
new_token_hash = hash_token(new_token)
|
||||
new_token_key_name = f"sk-...{new_token[-4:]}"
|
||||
|
||||
# update new token in DB
|
||||
# Prepare the update data
|
||||
update_data = {
|
||||
"token": new_token_hash,
|
||||
"key_name": new_token_key_name,
|
||||
}
|
||||
|
||||
non_default_values = {}
|
||||
if data is not None:
|
||||
# Update with any provided parameters from GenerateKeyRequest
|
||||
non_default_values = await prepare_key_update_data(
|
||||
data=data, existing_key_row=_key_in_db
|
||||
)
|
||||
|
||||
update_data.update(non_default_values)
|
||||
# Update the token in the database
|
||||
updated_token = await prisma_client.db.litellm_verificationtoken.update(
|
||||
where={"token": hashed_api_key},
|
||||
data={
|
||||
"token": new_token_hash,
|
||||
"key_name": new_token_key_name,
|
||||
},
|
||||
data=update_data,
|
||||
)
|
||||
|
||||
updated_token_dict = {}
|
||||
if updated_token is not None:
|
||||
updated_token_dict = dict(updated_token)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue