[Feat-Proxy] Add upperbound key duration param (#5727)

* add upperbound key duration param

* use upper bound values when None set

* docs upperbound params
This commit is contained in:
Ishaan Jaff 2024-09-16 16:28:36 -07:00 committed by GitHub
parent 3a5039e284
commit 8fbe2abb89
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 147 additions and 47 deletions

View file

@ -72,8 +72,13 @@ Control the upperbound that users can use for `max_budget`, `budget_duration` or
```yaml
litellm_settings:
upperbound_key_generate_params:
max_budget: 100 # upperbound of $100, for all /key/generate requests
duration: "30d" # upperbound of 30 days for all /key/generate requests
max_budget: 100 # Optional[float], optional): upperbound of $100, for all /key/generate requests
budget_duration: "10d" # Optional[str], optional): upperbound of 10 days for budget_duration values
duration: "30d" # Optional[str], optional): upperbound of 30 days for all /key/generate requests
max_parallel_requests: 1000 # (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None.
tpm_limit: 1000 #(Optional[int], optional): Tpm limit. Defaults to None.
rpm_limit: 1000 #(Optional[int], optional): Rpm limit. Defaults to None.
```
** Expected Behavior **

View file

@ -744,8 +744,12 @@ Set `litellm_settings:upperbound_key_generate_params`:
```yaml
litellm_settings:
upperbound_key_generate_params:
max_budget: 100 # upperbound of $100, for all /key/generate requests
duration: "30d" # upperbound of 30 days for all /key/generate requests
max_budget: 100 # Optional[float], optional): upperbound of $100, for all /key/generate requests
budget_duration: "10d" # Optional[str], optional): upperbound of 10 days for budget_duration values
duration: "30d" # Optional[str], optional): upperbound of 30 days for all /key/generate requests
max_parallel_requests: 1000 # (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None.
tpm_limit: 1000 #(Optional[int], optional): Tpm limit. Defaults to None.
rpm_limit: 1000 #(Optional[int], optional): Rpm limit. Defaults to None.
```
** Expected Behavior **

View file

@ -160,10 +160,19 @@ class LiteLLMBase(BaseModel):
class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
"""
Set default upperbound to max budget a key called via `/key/generate` can be.
Args:
max_budget (Optional[float], optional): Max budget a key can be. Defaults to None.
budget_duration (Optional[str], optional): Duration of the budget. Defaults to None.
duration (Optional[str], optional): Duration of the key. Defaults to None.
max_parallel_requests (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None.
tpm_limit (Optional[int], optional): Tpm limit. Defaults to None.
rpm_limit (Optional[int], optional): Rpm limit. Defaults to None.
"""
max_budget: Optional[float] = None
budget_duration: Optional[str] = None
duration: Optional[str] = None
max_parallel_requests: Optional[int] = None
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None

View file

@ -28,6 +28,7 @@ from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import _duration_in_seconds
from litellm.secret_managers.main import get_secret
router = APIRouter()
@ -103,7 +104,10 @@ async def generate_key_fn(
verbose_proxy_logger.debug("entered /key/generate")
if user_custom_key_generate is not None:
result = await user_custom_key_generate(data)
if asyncio.iscoroutinefunction(user_custom_key_generate):
result = await user_custom_key_generate(data) # type: ignore
else:
raise ValueError("user_custom_key_generate must be a coroutine")
decision = result.get("decision", True)
message = result.get("message", "Authentication Failed - Custom Auth Rule")
if not decision:
@ -134,43 +138,42 @@ async def generate_key_fn(
# check if user set default key/generate params on config.yaml
if litellm.upperbound_key_generate_params is not None:
for elem in data:
# if key in litellm.upperbound_key_generate_params, use the min of value and litellm.upperbound_key_generate_params[key]
key, value = elem
if (
value is not None
and getattr(litellm.upperbound_key_generate_params, key, None)
is not None
):
# if value is float/int
if key in [
"max_budget",
"max_parallel_requests",
"tpm_limit",
"rpm_limit",
]:
if value > getattr(litellm.upperbound_key_generate_params, key):
raise HTTPException(
status_code=400,
detail={
"error": f"{key} is over max limit set in config - user_value={value}; max_value={getattr(litellm.upperbound_key_generate_params, key)}"
},
)
elif key == "budget_duration":
# budgets are in 1s, 1m, 1h, 1d, 1m (30s, 30m, 30h, 30d, 30m)
# compare the duration in seconds and max duration in seconds
upperbound_budget_duration = _duration_in_seconds(
duration=getattr(
litellm.upperbound_key_generate_params, key
)
)
user_set_budget_duration = _duration_in_seconds(duration=value)
if user_set_budget_duration > upperbound_budget_duration:
raise HTTPException(
status_code=400,
detail={
"error": f"Budget duration is over max limit set in config - user_value={user_set_budget_duration}; max_value={upperbound_budget_duration}"
},
upperbound_value = getattr(
litellm.upperbound_key_generate_params, key, None
)
if upperbound_value is not None:
if value is None:
# Use the upperbound value if user didn't provide a value
setattr(data, key, upperbound_value)
else:
# Compare with upperbound for numeric fields
if key in [
"max_budget",
"max_parallel_requests",
"tpm_limit",
"rpm_limit",
]:
if value > upperbound_value:
raise HTTPException(
status_code=400,
detail={
"error": f"{key} is over max limit set in config - user_value={value}; max_value={upperbound_value}"
},
)
# Compare durations
elif key in ["budget_duration", "duration"]:
upperbound_duration = _duration_in_seconds(
duration=upperbound_value
)
user_duration = _duration_in_seconds(duration=value)
if user_duration > upperbound_duration:
raise HTTPException(
status_code=400,
detail={
"error": f"{key} is over max limit set in config - user_value={value}; max_value={upperbound_value}"
},
)
# TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable
_budget_id = None
@ -418,6 +421,9 @@ async def update_key_fn(
)
)
if response is None:
raise ValueError("Failed to update key got response = None")
return {"key": key, **response["data"]}
# update based on remaining passed in values
except Exception as e:
@ -503,6 +509,14 @@ async def delete_key_fn(
token=key, table_name="key", query_type="find_unique"
)
if key_row is None:
raise ProxyException(
message=f"Key {key} not found",
type=ProxyErrorTypes.bad_request_error,
param="key",
code=status.HTTP_404_NOT_FOUND,
)
key_row = key_row.json(exclude_none=True)
_key_row = json.dumps(key_row, default=str)
@ -527,6 +541,13 @@ async def delete_key_fn(
number_deleted_keys = await delete_verification_token(
tokens=keys, user_id=user_id
)
if number_deleted_keys is None:
raise ProxyException(
message="Failed to delete keys got None response from delete_verification_token",
type=ProxyErrorTypes.internal_server_error,
param="keys",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
verbose_proxy_logger.debug(
f"/key/delete - deleted_keys={number_deleted_keys['deleted_keys']}"
)
@ -617,6 +638,11 @@ async def info_key_fn_v2(
key_info = await prisma_client.get_data(
token=data.keys, table_name="key", query_type="find_all"
)
if key_info is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={"message": "No keys found"},
)
filtered_key_info = []
for k in key_info:
try:
@ -691,6 +717,11 @@ async def info_key_fn(
if key == None:
key = user_api_key_dict.api_key
key_info = await prisma_client.get_data(token=key)
if key_info is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={"message": "No keys found"},
)
## REMOVE HASHED TOKEN INFO BEFORE RETURNING ##
try:
key_info = key_info.model_dump() # noqa
@ -864,7 +895,7 @@ async def generate_key_helper_fn(
}
if (
litellm.get_secret("DISABLE_KEY_NAME", False) is True
get_secret("DISABLE_KEY_NAME", False) is True
): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much)
pass
else:
@ -904,9 +935,12 @@ async def generate_key_helper_fn(
user_row = await prisma_client.insert_data(
data=user_data, table_name="user"
)
if user_row is None:
raise Exception("Failed to create user")
## use default user model list if no key-specific model list provided
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
key_data["models"] = user_row.models
key_data["models"] = user_row.models # type: ignore
elif query_type == "update_data":
user_row = await prisma_client.update_data(
data=user_data,
@ -964,6 +998,10 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
deleted_tokens = await prisma_client.delete_data(
tokens=tokens, user_id=user_id
)
if deleted_tokens is None:
raise Exception(
"Failed to delete tokens got response None when deleting tokens"
)
_num_deleted_tokens = deleted_tokens.get("deleted_keys", 0)
if _num_deleted_tokens != len(tokens):
raise Exception(
@ -1168,10 +1206,6 @@ async def list_keys(
return response
except Exception as e:
logging.error(f"Error in list_keys: {str(e)}")
logging.error(f"Error type: {type(e)}")
logging.error(f"Error traceback: {traceback.format_exc()}")
raise ProxyException(
message=f"Error listing keys: {str(e)}",
type=ProxyErrorTypes.internal_server_error, # Use the enum value

View file

@ -2045,7 +2045,7 @@ async def test_default_key_params(prisma_client):
@pytest.mark.asyncio()
async def test_upperbound_key_params(prisma_client):
async def test_upperbound_key_param_larger_budget(prisma_client):
"""
- create key
- get key info
@ -2068,6 +2068,54 @@ async def test_upperbound_key_params(prisma_client):
assert e.code == str(400)
@pytest.mark.asyncio()
async def test_upperbound_key_param_larger_duration(prisma_client):
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
litellm.upperbound_key_generate_params = LiteLLM_UpperboundKeyGenerateParams(
max_budget=100, duration="14d"
)
await litellm.proxy.proxy_server.prisma_client.connect()
try:
request = GenerateKeyRequest(
max_budget=10,
duration="30d",
)
key = await generate_key_fn(request)
pytest.fail("Expected this to fail but it passed")
# print(result)
except Exception as e:
assert e.code == str(400)
@pytest.mark.asyncio()
async def test_upperbound_key_param_none_duration(prisma_client):
from datetime import datetime, timedelta
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
litellm.upperbound_key_generate_params = LiteLLM_UpperboundKeyGenerateParams(
max_budget=100, duration="14d"
)
await litellm.proxy.proxy_server.prisma_client.connect()
try:
request = GenerateKeyRequest()
key = await generate_key_fn(request)
print(key)
# print(result)
assert key.max_budget == 100
assert key.expires is not None
_date_key_expires = key.expires.date()
_fourteen_days_from_now = (datetime.now() + timedelta(days=14)).date()
assert _date_key_expires == _fourteen_days_from_now
except Exception as e:
pytest.fail(f"Got exception {e}")
def test_get_bearer_token():
from litellm.proxy.auth.user_api_key_auth import _get_bearer_token