[Feat-Proxy] Add upperbound key duration param (#5727)

* add upperbound key duration param

* use the upper bound value when the request leaves that param as None

* docs upperbound params
This commit is contained in:
Ishaan Jaff 2024-09-16 16:28:36 -07:00 committed by GitHub
parent 0d18292549
commit a252c95c7f
5 changed files with 147 additions and 47 deletions

View file

@ -28,6 +28,7 @@ from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import _duration_in_seconds
from litellm.secret_managers.main import get_secret
router = APIRouter()
@ -103,7 +104,10 @@ async def generate_key_fn(
verbose_proxy_logger.debug("entered /key/generate")
if user_custom_key_generate is not None:
result = await user_custom_key_generate(data)
if asyncio.iscoroutinefunction(user_custom_key_generate):
result = await user_custom_key_generate(data) # type: ignore
else:
raise ValueError("user_custom_key_generate must be a coroutine")
decision = result.get("decision", True)
message = result.get("message", "Authentication Failed - Custom Auth Rule")
if not decision:
@ -134,43 +138,42 @@ async def generate_key_fn(
# check if user set default key/generate params on config.yaml
if litellm.upperbound_key_generate_params is not None:
for elem in data:
# if key is set in litellm.upperbound_key_generate_params, a value above litellm.upperbound_key_generate_params[key] is rejected with a 400 (it is not clamped to the upper bound)
key, value = elem
if (
value is not None
and getattr(litellm.upperbound_key_generate_params, key, None)
is not None
):
# if value is float/int
if key in [
"max_budget",
"max_parallel_requests",
"tpm_limit",
"rpm_limit",
]:
if value > getattr(litellm.upperbound_key_generate_params, key):
raise HTTPException(
status_code=400,
detail={
"error": f"{key} is over max limit set in config - user_value={value}; max_value={getattr(litellm.upperbound_key_generate_params, key)}"
},
)
elif key == "budget_duration":
# budgets are in 1s, 1m, 1h, 1d, 1m (30s, 30m, 30h, 30d, 30m)
# compare the duration in seconds and max duration in seconds
upperbound_budget_duration = _duration_in_seconds(
duration=getattr(
litellm.upperbound_key_generate_params, key
)
)
user_set_budget_duration = _duration_in_seconds(duration=value)
if user_set_budget_duration > upperbound_budget_duration:
raise HTTPException(
status_code=400,
detail={
"error": f"Budget duration is over max limit set in config - user_value={user_set_budget_duration}; max_value={upperbound_budget_duration}"
},
upperbound_value = getattr(
litellm.upperbound_key_generate_params, key, None
)
if upperbound_value is not None:
if value is None:
# Use the upperbound value if user didn't provide a value
setattr(data, key, upperbound_value)
else:
# Compare with upperbound for numeric fields
if key in [
"max_budget",
"max_parallel_requests",
"tpm_limit",
"rpm_limit",
]:
if value > upperbound_value:
raise HTTPException(
status_code=400,
detail={
"error": f"{key} is over max limit set in config - user_value={value}; max_value={upperbound_value}"
},
)
# Compare durations
elif key in ["budget_duration", "duration"]:
upperbound_duration = _duration_in_seconds(
duration=upperbound_value
)
user_duration = _duration_in_seconds(duration=value)
if user_duration > upperbound_duration:
raise HTTPException(
status_code=400,
detail={
"error": f"{key} is over max limit set in config - user_value={value}; max_value={upperbound_value}"
},
)
# TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable
_budget_id = None
@ -418,6 +421,9 @@ async def update_key_fn(
)
)
if response is None:
raise ValueError("Failed to update key got response = None")
return {"key": key, **response["data"]}
# update based on remaining passed in values
except Exception as e:
@ -503,6 +509,14 @@ async def delete_key_fn(
token=key, table_name="key", query_type="find_unique"
)
if key_row is None:
raise ProxyException(
message=f"Key {key} not found",
type=ProxyErrorTypes.bad_request_error,
param="key",
code=status.HTTP_404_NOT_FOUND,
)
key_row = key_row.json(exclude_none=True)
_key_row = json.dumps(key_row, default=str)
@ -527,6 +541,13 @@ async def delete_key_fn(
number_deleted_keys = await delete_verification_token(
tokens=keys, user_id=user_id
)
if number_deleted_keys is None:
raise ProxyException(
message="Failed to delete keys got None response from delete_verification_token",
type=ProxyErrorTypes.internal_server_error,
param="keys",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
verbose_proxy_logger.debug(
f"/key/delete - deleted_keys={number_deleted_keys['deleted_keys']}"
)
@ -617,6 +638,11 @@ async def info_key_fn_v2(
key_info = await prisma_client.get_data(
token=data.keys, table_name="key", query_type="find_all"
)
if key_info is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={"message": "No keys found"},
)
filtered_key_info = []
for k in key_info:
try:
@ -691,6 +717,11 @@ async def info_key_fn(
if key == None:
key = user_api_key_dict.api_key
key_info = await prisma_client.get_data(token=key)
if key_info is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={"message": "No keys found"},
)
## REMOVE HASHED TOKEN INFO BEFORE RETURNING ##
try:
key_info = key_info.model_dump() # noqa
@ -864,7 +895,7 @@ async def generate_key_helper_fn(
}
if (
litellm.get_secret("DISABLE_KEY_NAME", False) is True
get_secret("DISABLE_KEY_NAME", False) is True
): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much)
pass
else:
@ -904,9 +935,12 @@ async def generate_key_helper_fn(
user_row = await prisma_client.insert_data(
data=user_data, table_name="user"
)
if user_row is None:
raise Exception("Failed to create user")
## use default user model list if no key-specific model list provided
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
key_data["models"] = user_row.models
key_data["models"] = user_row.models # type: ignore
elif query_type == "update_data":
user_row = await prisma_client.update_data(
data=user_data,
@ -964,6 +998,10 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
deleted_tokens = await prisma_client.delete_data(
tokens=tokens, user_id=user_id
)
if deleted_tokens is None:
raise Exception(
"Failed to delete tokens got response None when deleting tokens"
)
_num_deleted_tokens = deleted_tokens.get("deleted_keys", 0)
if _num_deleted_tokens != len(tokens):
raise Exception(
@ -1168,10 +1206,6 @@ async def list_keys(
return response
except Exception as e:
logging.error(f"Error in list_keys: {str(e)}")
logging.error(f"Error type: {type(e)}")
logging.error(f"Error traceback: {traceback.format_exc()}")
raise ProxyException(
message=f"Error listing keys: {str(e)}",
type=ProxyErrorTypes.internal_server_error, # Use the enum value