mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
[Feat-Proxy] Add upperbound key duration param (#5727)
* add upperbound key duration param * use upper bound values when None set * docs upperbound params
This commit is contained in:
parent
0d18292549
commit
a252c95c7f
5 changed files with 147 additions and 47 deletions
|
@ -28,6 +28,7 @@ from litellm.proxy._types import *
|
|||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||
from litellm.proxy.utils import _duration_in_seconds
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
@ -103,7 +104,10 @@ async def generate_key_fn(
|
|||
verbose_proxy_logger.debug("entered /key/generate")
|
||||
|
||||
if user_custom_key_generate is not None:
|
||||
result = await user_custom_key_generate(data)
|
||||
if asyncio.iscoroutinefunction(user_custom_key_generate):
|
||||
result = await user_custom_key_generate(data) # type: ignore
|
||||
else:
|
||||
raise ValueError("user_custom_key_generate must be a coroutine")
|
||||
decision = result.get("decision", True)
|
||||
message = result.get("message", "Authentication Failed - Custom Auth Rule")
|
||||
if not decision:
|
||||
|
@ -134,43 +138,42 @@ async def generate_key_fn(
|
|||
# check if user set default key/generate params on config.yaml
|
||||
if litellm.upperbound_key_generate_params is not None:
|
||||
for elem in data:
|
||||
# if key in litellm.upperbound_key_generate_params, use the min of value and litellm.upperbound_key_generate_params[key]
|
||||
key, value = elem
|
||||
if (
|
||||
value is not None
|
||||
and getattr(litellm.upperbound_key_generate_params, key, None)
|
||||
is not None
|
||||
):
|
||||
# if value is float/int
|
||||
if key in [
|
||||
"max_budget",
|
||||
"max_parallel_requests",
|
||||
"tpm_limit",
|
||||
"rpm_limit",
|
||||
]:
|
||||
if value > getattr(litellm.upperbound_key_generate_params, key):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"{key} is over max limit set in config - user_value={value}; max_value={getattr(litellm.upperbound_key_generate_params, key)}"
|
||||
},
|
||||
)
|
||||
elif key == "budget_duration":
|
||||
# budgets are in 1s, 1m, 1h, 1d, 1m (30s, 30m, 30h, 30d, 30m)
|
||||
# compare the duration in seconds and max duration in seconds
|
||||
upperbound_budget_duration = _duration_in_seconds(
|
||||
duration=getattr(
|
||||
litellm.upperbound_key_generate_params, key
|
||||
)
|
||||
)
|
||||
user_set_budget_duration = _duration_in_seconds(duration=value)
|
||||
if user_set_budget_duration > upperbound_budget_duration:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"Budget duration is over max limit set in config - user_value={user_set_budget_duration}; max_value={upperbound_budget_duration}"
|
||||
},
|
||||
upperbound_value = getattr(
|
||||
litellm.upperbound_key_generate_params, key, None
|
||||
)
|
||||
if upperbound_value is not None:
|
||||
if value is None:
|
||||
# Use the upperbound value if user didn't provide a value
|
||||
setattr(data, key, upperbound_value)
|
||||
else:
|
||||
# Compare with upperbound for numeric fields
|
||||
if key in [
|
||||
"max_budget",
|
||||
"max_parallel_requests",
|
||||
"tpm_limit",
|
||||
"rpm_limit",
|
||||
]:
|
||||
if value > upperbound_value:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"{key} is over max limit set in config - user_value={value}; max_value={upperbound_value}"
|
||||
},
|
||||
)
|
||||
# Compare durations
|
||||
elif key in ["budget_duration", "duration"]:
|
||||
upperbound_duration = _duration_in_seconds(
|
||||
duration=upperbound_value
|
||||
)
|
||||
user_duration = _duration_in_seconds(duration=value)
|
||||
if user_duration > upperbound_duration:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"{key} is over max limit set in config - user_value={value}; max_value={upperbound_value}"
|
||||
},
|
||||
)
|
||||
|
||||
# TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable
|
||||
_budget_id = None
|
||||
|
@ -418,6 +421,9 @@ async def update_key_fn(
|
|||
)
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise ValueError("Failed to update key got response = None")
|
||||
|
||||
return {"key": key, **response["data"]}
|
||||
# update based on remaining passed in values
|
||||
except Exception as e:
|
||||
|
@ -503,6 +509,14 @@ async def delete_key_fn(
|
|||
token=key, table_name="key", query_type="find_unique"
|
||||
)
|
||||
|
||||
if key_row is None:
|
||||
raise ProxyException(
|
||||
message=f"Key {key} not found",
|
||||
type=ProxyErrorTypes.bad_request_error,
|
||||
param="key",
|
||||
code=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
key_row = key_row.json(exclude_none=True)
|
||||
_key_row = json.dumps(key_row, default=str)
|
||||
|
||||
|
@ -527,6 +541,13 @@ async def delete_key_fn(
|
|||
number_deleted_keys = await delete_verification_token(
|
||||
tokens=keys, user_id=user_id
|
||||
)
|
||||
if number_deleted_keys is None:
|
||||
raise ProxyException(
|
||||
message="Failed to delete keys got None response from delete_verification_token",
|
||||
type=ProxyErrorTypes.internal_server_error,
|
||||
param="keys",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"/key/delete - deleted_keys={number_deleted_keys['deleted_keys']}"
|
||||
)
|
||||
|
@ -617,6 +638,11 @@ async def info_key_fn_v2(
|
|||
key_info = await prisma_client.get_data(
|
||||
token=data.keys, table_name="key", query_type="find_all"
|
||||
)
|
||||
if key_info is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={"message": "No keys found"},
|
||||
)
|
||||
filtered_key_info = []
|
||||
for k in key_info:
|
||||
try:
|
||||
|
@ -691,6 +717,11 @@ async def info_key_fn(
|
|||
if key == None:
|
||||
key = user_api_key_dict.api_key
|
||||
key_info = await prisma_client.get_data(token=key)
|
||||
if key_info is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={"message": "No keys found"},
|
||||
)
|
||||
## REMOVE HASHED TOKEN INFO BEFORE RETURNING ##
|
||||
try:
|
||||
key_info = key_info.model_dump() # noqa
|
||||
|
@ -864,7 +895,7 @@ async def generate_key_helper_fn(
|
|||
}
|
||||
|
||||
if (
|
||||
litellm.get_secret("DISABLE_KEY_NAME", False) is True
|
||||
get_secret("DISABLE_KEY_NAME", False) is True
|
||||
): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much)
|
||||
pass
|
||||
else:
|
||||
|
@ -904,9 +935,12 @@ async def generate_key_helper_fn(
|
|||
user_row = await prisma_client.insert_data(
|
||||
data=user_data, table_name="user"
|
||||
)
|
||||
|
||||
if user_row is None:
|
||||
raise Exception("Failed to create user")
|
||||
## use default user model list if no key-specific model list provided
|
||||
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
|
||||
key_data["models"] = user_row.models
|
||||
key_data["models"] = user_row.models # type: ignore
|
||||
elif query_type == "update_data":
|
||||
user_row = await prisma_client.update_data(
|
||||
data=user_data,
|
||||
|
@ -964,6 +998,10 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
|
|||
deleted_tokens = await prisma_client.delete_data(
|
||||
tokens=tokens, user_id=user_id
|
||||
)
|
||||
if deleted_tokens is None:
|
||||
raise Exception(
|
||||
"Failed to delete tokens got response None when deleting tokens"
|
||||
)
|
||||
_num_deleted_tokens = deleted_tokens.get("deleted_keys", 0)
|
||||
if _num_deleted_tokens != len(tokens):
|
||||
raise Exception(
|
||||
|
@ -1168,10 +1206,6 @@ async def list_keys(
|
|||
return response
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in list_keys: {str(e)}")
|
||||
logging.error(f"Error type: {type(e)}")
|
||||
logging.error(f"Error traceback: {traceback.format_exc()}")
|
||||
|
||||
raise ProxyException(
|
||||
message=f"Error listing keys: {str(e)}",
|
||||
type=ProxyErrorTypes.internal_server_error, # Use the enum value
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue