forked from phoenix/litellm-mirror
[Feat-Proxy] Add upperbound key duration param (#5727)
* add upperbound key duration param * use upper bound values when None set * docs upperbound params
This commit is contained in:
parent
3a5039e284
commit
8fbe2abb89
5 changed files with 147 additions and 47 deletions
|
@ -72,8 +72,13 @@ Control the upperbound that users can use for `max_budget`, `budget_duration` or
|
|||
```yaml
|
||||
litellm_settings:
|
||||
upperbound_key_generate_params:
|
||||
max_budget: 100 # upperbound of $100, for all /key/generate requests
|
||||
duration: "30d" # upperbound of 30 days for all /key/generate requests
|
||||
max_budget: 100 # Optional[float], optional): upperbound of $100, for all /key/generate requests
|
||||
budget_duration: "10d" # Optional[str], optional): upperbound of 10 days for budget_duration values
|
||||
duration: "30d" # Optional[str], optional): upperbound of 30 days for all /key/generate requests
|
||||
max_parallel_requests: 1000 # (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None.
|
||||
tpm_limit: 1000 #(Optional[int], optional): Tpm limit. Defaults to None.
|
||||
rpm_limit: 1000 #(Optional[int], optional): Rpm limit. Defaults to None.
|
||||
|
||||
```
|
||||
|
||||
** Expected Behavior **
|
||||
|
|
|
@ -744,8 +744,12 @@ Set `litellm_settings:upperbound_key_generate_params`:
|
|||
```yaml
|
||||
litellm_settings:
|
||||
upperbound_key_generate_params:
|
||||
max_budget: 100 # upperbound of $100, for all /key/generate requests
|
||||
duration: "30d" # upperbound of 30 days for all /key/generate requests
|
||||
max_budget: 100 # Optional[float], optional): upperbound of $100, for all /key/generate requests
|
||||
budget_duration: "10d" # Optional[str], optional): upperbound of 10 days for budget_duration values
|
||||
duration: "30d" # Optional[str], optional): upperbound of 30 days for all /key/generate requests
|
||||
max_parallel_requests: 1000 # (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None.
|
||||
tpm_limit: 1000 #(Optional[int], optional): Tpm limit. Defaults to None.
|
||||
rpm_limit: 1000 #(Optional[int], optional): Rpm limit. Defaults to None.
|
||||
```
|
||||
|
||||
** Expected Behavior **
|
||||
|
|
|
@ -160,10 +160,19 @@ class LiteLLMBase(BaseModel):
|
|||
class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
|
||||
"""
|
||||
Set default upperbound to max budget a key called via `/key/generate` can be.
|
||||
|
||||
Args:
|
||||
max_budget (Optional[float], optional): Max budget a key can be. Defaults to None.
|
||||
budget_duration (Optional[str], optional): Duration of the budget. Defaults to None.
|
||||
duration (Optional[str], optional): Duration of the key. Defaults to None.
|
||||
max_parallel_requests (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None.
|
||||
tpm_limit (Optional[int], optional): Tpm limit. Defaults to None.
|
||||
rpm_limit (Optional[int], optional): Rpm limit. Defaults to None.
|
||||
"""
|
||||
|
||||
max_budget: Optional[float] = None
|
||||
budget_duration: Optional[str] = None
|
||||
duration: Optional[str] = None
|
||||
max_parallel_requests: Optional[int] = None
|
||||
tpm_limit: Optional[int] = None
|
||||
rpm_limit: Optional[int] = None
|
||||
|
|
|
@ -28,6 +28,7 @@ from litellm.proxy._types import *
|
|||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||
from litellm.proxy.utils import _duration_in_seconds
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
@ -103,7 +104,10 @@ async def generate_key_fn(
|
|||
verbose_proxy_logger.debug("entered /key/generate")
|
||||
|
||||
if user_custom_key_generate is not None:
|
||||
result = await user_custom_key_generate(data)
|
||||
if asyncio.iscoroutinefunction(user_custom_key_generate):
|
||||
result = await user_custom_key_generate(data) # type: ignore
|
||||
else:
|
||||
raise ValueError("user_custom_key_generate must be a coroutine")
|
||||
decision = result.get("decision", True)
|
||||
message = result.get("message", "Authentication Failed - Custom Auth Rule")
|
||||
if not decision:
|
||||
|
@ -134,43 +138,42 @@ async def generate_key_fn(
|
|||
# check if user set default key/generate params on config.yaml
|
||||
if litellm.upperbound_key_generate_params is not None:
|
||||
for elem in data:
|
||||
# if key in litellm.upperbound_key_generate_params, use the min of value and litellm.upperbound_key_generate_params[key]
|
||||
key, value = elem
|
||||
if (
|
||||
value is not None
|
||||
and getattr(litellm.upperbound_key_generate_params, key, None)
|
||||
is not None
|
||||
):
|
||||
# if value is float/int
|
||||
if key in [
|
||||
"max_budget",
|
||||
"max_parallel_requests",
|
||||
"tpm_limit",
|
||||
"rpm_limit",
|
||||
]:
|
||||
if value > getattr(litellm.upperbound_key_generate_params, key):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"{key} is over max limit set in config - user_value={value}; max_value={getattr(litellm.upperbound_key_generate_params, key)}"
|
||||
},
|
||||
)
|
||||
elif key == "budget_duration":
|
||||
# budgets are in 1s, 1m, 1h, 1d, 1m (30s, 30m, 30h, 30d, 30m)
|
||||
# compare the duration in seconds and max duration in seconds
|
||||
upperbound_budget_duration = _duration_in_seconds(
|
||||
duration=getattr(
|
||||
litellm.upperbound_key_generate_params, key
|
||||
)
|
||||
)
|
||||
user_set_budget_duration = _duration_in_seconds(duration=value)
|
||||
if user_set_budget_duration > upperbound_budget_duration:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"Budget duration is over max limit set in config - user_value={user_set_budget_duration}; max_value={upperbound_budget_duration}"
|
||||
},
|
||||
upperbound_value = getattr(
|
||||
litellm.upperbound_key_generate_params, key, None
|
||||
)
|
||||
if upperbound_value is not None:
|
||||
if value is None:
|
||||
# Use the upperbound value if user didn't provide a value
|
||||
setattr(data, key, upperbound_value)
|
||||
else:
|
||||
# Compare with upperbound for numeric fields
|
||||
if key in [
|
||||
"max_budget",
|
||||
"max_parallel_requests",
|
||||
"tpm_limit",
|
||||
"rpm_limit",
|
||||
]:
|
||||
if value > upperbound_value:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"{key} is over max limit set in config - user_value={value}; max_value={upperbound_value}"
|
||||
},
|
||||
)
|
||||
# Compare durations
|
||||
elif key in ["budget_duration", "duration"]:
|
||||
upperbound_duration = _duration_in_seconds(
|
||||
duration=upperbound_value
|
||||
)
|
||||
user_duration = _duration_in_seconds(duration=value)
|
||||
if user_duration > upperbound_duration:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"{key} is over max limit set in config - user_value={value}; max_value={upperbound_value}"
|
||||
},
|
||||
)
|
||||
|
||||
# TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable
|
||||
_budget_id = None
|
||||
|
@ -418,6 +421,9 @@ async def update_key_fn(
|
|||
)
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise ValueError("Failed to update key got response = None")
|
||||
|
||||
return {"key": key, **response["data"]}
|
||||
# update based on remaining passed in values
|
||||
except Exception as e:
|
||||
|
@ -503,6 +509,14 @@ async def delete_key_fn(
|
|||
token=key, table_name="key", query_type="find_unique"
|
||||
)
|
||||
|
||||
if key_row is None:
|
||||
raise ProxyException(
|
||||
message=f"Key {key} not found",
|
||||
type=ProxyErrorTypes.bad_request_error,
|
||||
param="key",
|
||||
code=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
key_row = key_row.json(exclude_none=True)
|
||||
_key_row = json.dumps(key_row, default=str)
|
||||
|
||||
|
@ -527,6 +541,13 @@ async def delete_key_fn(
|
|||
number_deleted_keys = await delete_verification_token(
|
||||
tokens=keys, user_id=user_id
|
||||
)
|
||||
if number_deleted_keys is None:
|
||||
raise ProxyException(
|
||||
message="Failed to delete keys got None response from delete_verification_token",
|
||||
type=ProxyErrorTypes.internal_server_error,
|
||||
param="keys",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"/key/delete - deleted_keys={number_deleted_keys['deleted_keys']}"
|
||||
)
|
||||
|
@ -617,6 +638,11 @@ async def info_key_fn_v2(
|
|||
key_info = await prisma_client.get_data(
|
||||
token=data.keys, table_name="key", query_type="find_all"
|
||||
)
|
||||
if key_info is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={"message": "No keys found"},
|
||||
)
|
||||
filtered_key_info = []
|
||||
for k in key_info:
|
||||
try:
|
||||
|
@ -691,6 +717,11 @@ async def info_key_fn(
|
|||
if key == None:
|
||||
key = user_api_key_dict.api_key
|
||||
key_info = await prisma_client.get_data(token=key)
|
||||
if key_info is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={"message": "No keys found"},
|
||||
)
|
||||
## REMOVE HASHED TOKEN INFO BEFORE RETURNING ##
|
||||
try:
|
||||
key_info = key_info.model_dump() # noqa
|
||||
|
@ -864,7 +895,7 @@ async def generate_key_helper_fn(
|
|||
}
|
||||
|
||||
if (
|
||||
litellm.get_secret("DISABLE_KEY_NAME", False) is True
|
||||
get_secret("DISABLE_KEY_NAME", False) is True
|
||||
): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much)
|
||||
pass
|
||||
else:
|
||||
|
@ -904,9 +935,12 @@ async def generate_key_helper_fn(
|
|||
user_row = await prisma_client.insert_data(
|
||||
data=user_data, table_name="user"
|
||||
)
|
||||
|
||||
if user_row is None:
|
||||
raise Exception("Failed to create user")
|
||||
## use default user model list if no key-specific model list provided
|
||||
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
|
||||
key_data["models"] = user_row.models
|
||||
key_data["models"] = user_row.models # type: ignore
|
||||
elif query_type == "update_data":
|
||||
user_row = await prisma_client.update_data(
|
||||
data=user_data,
|
||||
|
@ -964,6 +998,10 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
|
|||
deleted_tokens = await prisma_client.delete_data(
|
||||
tokens=tokens, user_id=user_id
|
||||
)
|
||||
if deleted_tokens is None:
|
||||
raise Exception(
|
||||
"Failed to delete tokens got response None when deleting tokens"
|
||||
)
|
||||
_num_deleted_tokens = deleted_tokens.get("deleted_keys", 0)
|
||||
if _num_deleted_tokens != len(tokens):
|
||||
raise Exception(
|
||||
|
@ -1168,10 +1206,6 @@ async def list_keys(
|
|||
return response
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in list_keys: {str(e)}")
|
||||
logging.error(f"Error type: {type(e)}")
|
||||
logging.error(f"Error traceback: {traceback.format_exc()}")
|
||||
|
||||
raise ProxyException(
|
||||
message=f"Error listing keys: {str(e)}",
|
||||
type=ProxyErrorTypes.internal_server_error, # Use the enum value
|
||||
|
|
|
@ -2045,7 +2045,7 @@ async def test_default_key_params(prisma_client):
|
|||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_upperbound_key_params(prisma_client):
|
||||
async def test_upperbound_key_param_larger_budget(prisma_client):
|
||||
"""
|
||||
- create key
|
||||
- get key info
|
||||
|
@ -2068,6 +2068,54 @@ async def test_upperbound_key_params(prisma_client):
|
|||
assert e.code == str(400)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_upperbound_key_param_larger_duration(prisma_client):
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
litellm.upperbound_key_generate_params = LiteLLM_UpperboundKeyGenerateParams(
|
||||
max_budget=100, duration="14d"
|
||||
)
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
try:
|
||||
request = GenerateKeyRequest(
|
||||
max_budget=10,
|
||||
duration="30d",
|
||||
)
|
||||
key = await generate_key_fn(request)
|
||||
pytest.fail("Expected this to fail but it passed")
|
||||
# print(result)
|
||||
except Exception as e:
|
||||
assert e.code == str(400)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_upperbound_key_param_none_duration(prisma_client):
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
litellm.upperbound_key_generate_params = LiteLLM_UpperboundKeyGenerateParams(
|
||||
max_budget=100, duration="14d"
|
||||
)
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
try:
|
||||
request = GenerateKeyRequest()
|
||||
key = await generate_key_fn(request)
|
||||
|
||||
print(key)
|
||||
# print(result)
|
||||
|
||||
assert key.max_budget == 100
|
||||
assert key.expires is not None
|
||||
|
||||
_date_key_expires = key.expires.date()
|
||||
_fourteen_days_from_now = (datetime.now() + timedelta(days=14)).date()
|
||||
|
||||
assert _date_key_expires == _fourteen_days_from_now
|
||||
except Exception as e:
|
||||
pytest.fail(f"Got exception {e}")
|
||||
|
||||
|
||||
def test_get_bearer_token():
|
||||
from litellm.proxy.auth.user_api_key_auth import _get_bearer_token
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue