From 8fbe2abb89aa37a83c1829976c2871cc15fd703e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 16 Sep 2024 16:28:36 -0700 Subject: [PATCH] [Feat-Proxy] Add upperbound key duration param (#5727) * add upperbound key duration param * use upper bound values when None set * docs upperbound params --- docs/my-website/docs/proxy/ui.md | 9 +- docs/my-website/docs/proxy/virtual_keys.md | 8 +- litellm/proxy/_types.py | 9 ++ .../key_management_endpoints.py | 118 +++++++++++------- litellm/tests/test_key_generate_prisma.py | 50 +++++++- 5 files changed, 147 insertions(+), 47 deletions(-) diff --git a/docs/my-website/docs/proxy/ui.md b/docs/my-website/docs/proxy/ui.md index 375ef5ebf..d678d550c 100644 --- a/docs/my-website/docs/proxy/ui.md +++ b/docs/my-website/docs/proxy/ui.md @@ -72,8 +72,13 @@ Control the upperbound that users can use for `max_budget`, `budget_duration` or ```yaml litellm_settings: upperbound_key_generate_params: - max_budget: 100 # upperbound of $100, for all /key/generate requests - duration: "30d" # upperbound of 30 days for all /key/generate requests + max_budget: 100 # Optional[float], optional): upperbound of $100, for all /key/generate requests + budget_duration: "10d" # Optional[str], optional): upperbound of 10 days for budget_duration values + duration: "30d" # Optional[str], optional): upperbound of 30 days for all /key/generate requests + max_parallel_requests: 1000 # (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None. + tpm_limit: 1000 #(Optional[int], optional): Tpm limit. Defaults to None. + rpm_limit: 1000 #(Optional[int], optional): Rpm limit. Defaults to None. + ``` ** Expected Behavior ** diff --git a/docs/my-website/docs/proxy/virtual_keys.md b/docs/my-website/docs/proxy/virtual_keys.md index fa2da9f28..692d153fc 100644 --- a/docs/my-website/docs/proxy/virtual_keys.md +++ b/docs/my-website/docs/proxy/virtual_keys.md @@ -744,8 +744,12 @@ Set `litellm_settings:upperbound_key_generate_params`: ```yaml litellm_settings: upperbound_key_generate_params: - max_budget: 100 # upperbound of $100, for all /key/generate requests - duration: "30d" # upperbound of 30 days for all /key/generate requests + max_budget: 100 # Optional[float], optional): upperbound of $100, for all /key/generate requests + budget_duration: "10d" # Optional[str], optional): upperbound of 10 days for budget_duration values + duration: "30d" # Optional[str], optional): upperbound of 30 days for all /key/generate requests + max_parallel_requests: 1000 # (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None. + tpm_limit: 1000 #(Optional[int], optional): Tpm limit. Defaults to None. + rpm_limit: 1000 #(Optional[int], optional): Rpm limit. Defaults to None. ``` ** Expected Behavior ** diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index db3304745..c75ea4760 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -160,10 +160,19 @@ class LiteLLMBase(BaseModel): class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase): """ Set default upperbound to max budget a key called via `/key/generate` can be. + + Args: + max_budget (Optional[float], optional): Max budget a key can be. Defaults to None. + budget_duration (Optional[str], optional): Duration of the budget. Defaults to None. + duration (Optional[str], optional): Duration of the key. Defaults to None. + max_parallel_requests (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None. + tpm_limit (Optional[int], optional): Tpm limit. Defaults to None. + rpm_limit (Optional[int], optional): Rpm limit. Defaults to None. """ max_budget: Optional[float] = None budget_duration: Optional[str] = None + duration: Optional[str] = None max_parallel_requests: Optional[int] = None tpm_limit: Optional[int] = None rpm_limit: Optional[int] = None diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 90e7728d0..4accfbc09 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -28,6 +28,7 @@ from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.management_helpers.utils import management_endpoint_wrapper from litellm.proxy.utils import _duration_in_seconds +from litellm.secret_managers.main import get_secret router = APIRouter() @@ -103,7 +104,10 @@ async def generate_key_fn( verbose_proxy_logger.debug("entered /key/generate") if user_custom_key_generate is not None: - result = await user_custom_key_generate(data) + if asyncio.iscoroutinefunction(user_custom_key_generate): + result = await user_custom_key_generate(data) # type: ignore + else: + raise ValueError("user_custom_key_generate must be a coroutine") decision = result.get("decision", True) message = result.get("message", "Authentication Failed - Custom Auth Rule") if not decision: @@ -134,43 +138,42 @@ async def generate_key_fn( # check if user set default key/generate params on config.yaml if litellm.upperbound_key_generate_params is not None: for elem in data: - # if key in litellm.upperbound_key_generate_params, use the min of value and litellm.upperbound_key_generate_params[key] key, value = elem - if ( - value is not None - and getattr(litellm.upperbound_key_generate_params, key, None) - is not None - ): - # if value is float/int - if key in [ - "max_budget", - "max_parallel_requests", - "tpm_limit", - "rpm_limit", - ]: - if value > getattr(litellm.upperbound_key_generate_params, key): - raise HTTPException( - status_code=400, - detail={ - "error": f"{key} is over max limit set in config - user_value={value}; max_value={getattr(litellm.upperbound_key_generate_params, key)}" - }, - ) - elif key == "budget_duration": - # budgets are in 1s, 1m, 1h, 1d, 1m (30s, 30m, 30h, 30d, 30m) - # compare the duration in seconds and max duration in seconds - upperbound_budget_duration = _duration_in_seconds( - duration=getattr( - litellm.upperbound_key_generate_params, key - ) - ) - user_set_budget_duration = _duration_in_seconds(duration=value) - if user_set_budget_duration > upperbound_budget_duration: - raise HTTPException( - status_code=400, - detail={ - "error": f"Budget duration is over max limit set in config - user_value={user_set_budget_duration}; max_value={upperbound_budget_duration}" - }, + upperbound_value = getattr( + litellm.upperbound_key_generate_params, key, None + ) + if upperbound_value is not None: + if value is None: + # Use the upperbound value if user didn't provide a value + setattr(data, key, upperbound_value) + else: + # Compare with upperbound for numeric fields + if key in [ + "max_budget", + "max_parallel_requests", + "tpm_limit", + "rpm_limit", + ]: + if value > upperbound_value: + raise HTTPException( + status_code=400, + detail={ + "error": f"{key} is over max limit set in config - user_value={value}; max_value={upperbound_value}" + }, + ) + # Compare durations + elif key in ["budget_duration", "duration"]: + upperbound_duration = _duration_in_seconds( + duration=upperbound_value ) + user_duration = _duration_in_seconds(duration=value) + if user_duration > upperbound_duration: + raise HTTPException( + status_code=400, + detail={ + "error": f"{key} is over max limit set in config - user_value={value}; max_value={upperbound_value}" + }, + ) # TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable _budget_id = None @@ -418,6 +421,9 @@ async def update_key_fn( ) ) + if response is None: + raise ValueError("Failed to update key got response = None") + return {"key": key, **response["data"]} # update based on remaining passed in values except Exception as e: @@ -503,6 +509,14 @@ async def delete_key_fn( token=key, table_name="key", query_type="find_unique" ) + if key_row is None: + raise ProxyException( + message=f"Key {key} not found", + type=ProxyErrorTypes.bad_request_error, + param="key", + code=status.HTTP_404_NOT_FOUND, + ) + key_row = key_row.json(exclude_none=True) _key_row = json.dumps(key_row, default=str) @@ -527,6 +541,13 @@ async def delete_key_fn( number_deleted_keys = await delete_verification_token( tokens=keys, user_id=user_id ) + if number_deleted_keys is None: + raise ProxyException( + message="Failed to delete keys got None response from delete_verification_token", + type=ProxyErrorTypes.internal_server_error, + param="keys", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) verbose_proxy_logger.debug( f"/key/delete - deleted_keys={number_deleted_keys['deleted_keys']}" ) @@ -617,6 +638,11 @@ async def info_key_fn_v2( key_info = await prisma_client.get_data( token=data.keys, table_name="key", query_type="find_all" ) + if key_info is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"message": "No keys found"}, + ) filtered_key_info = [] for k in key_info: try: @@ -691,6 +717,11 @@ async def info_key_fn( if key == None: key = user_api_key_dict.api_key key_info = await prisma_client.get_data(token=key) + if key_info is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"message": "No keys found"}, + ) ## REMOVE HASHED TOKEN INFO BEFORE RETURNING ## try: key_info = key_info.model_dump() # noqa @@ -864,7 +895,7 @@ async def generate_key_helper_fn( } if ( - litellm.get_secret("DISABLE_KEY_NAME", False) is True + get_secret("DISABLE_KEY_NAME", False) is True ): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much) pass else: @@ -904,9 +935,12 @@ async def generate_key_helper_fn( user_row = await prisma_client.insert_data( data=user_data, table_name="user" ) + + if user_row is None: + raise Exception("Failed to create user") ## use default user model list if no key-specific model list provided if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore - key_data["models"] = user_row.models + key_data["models"] = user_row.models # type: ignore elif query_type == "update_data": user_row = await prisma_client.update_data( data=user_data, @@ -964,6 +998,10 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None) deleted_tokens = await prisma_client.delete_data( tokens=tokens, user_id=user_id ) + if deleted_tokens is None: + raise Exception( + "Failed to delete tokens got response None when deleting tokens" + ) _num_deleted_tokens = deleted_tokens.get("deleted_keys", 0) if _num_deleted_tokens != len(tokens): raise Exception( @@ -1168,10 +1206,6 @@ async def list_keys( return response except Exception as e: - logging.error(f"Error in list_keys: {str(e)}") - logging.error(f"Error type: {type(e)}") - logging.error(f"Error traceback: {traceback.format_exc()}") - raise ProxyException( message=f"Error listing keys: {str(e)}", type=ProxyErrorTypes.internal_server_error, # Use the enum value diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 3a81cc27a..06b087ee7 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -2045,7 +2045,7 @@ async def test_default_key_params(prisma_client): @pytest.mark.asyncio() -async def test_upperbound_key_params(prisma_client): +async def test_upperbound_key_param_larger_budget(prisma_client): """ - create key - get key info @@ -2068,6 +2068,54 @@ async def test_upperbound_key_params(prisma_client): assert e.code == str(400) +@pytest.mark.asyncio() +async def test_upperbound_key_param_larger_duration(prisma_client): + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + litellm.upperbound_key_generate_params = LiteLLM_UpperboundKeyGenerateParams( + max_budget=100, duration="14d" + ) + await litellm.proxy.proxy_server.prisma_client.connect() + try: + request = GenerateKeyRequest( + max_budget=10, + duration="30d", + ) + key = await generate_key_fn(request) + pytest.fail("Expected this to fail but it passed") + # print(result) + except Exception as e: + assert e.code == str(400) + + +@pytest.mark.asyncio() +async def test_upperbound_key_param_none_duration(prisma_client): + from datetime import datetime, timedelta + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + litellm.upperbound_key_generate_params = LiteLLM_UpperboundKeyGenerateParams( + max_budget=100, duration="14d" + ) + await litellm.proxy.proxy_server.prisma_client.connect() + try: + request = GenerateKeyRequest() + key = await generate_key_fn(request) + + print(key) + # print(result) + + assert key.max_budget == 100 + assert key.expires is not None + + _date_key_expires = key.expires.date() + _fourteen_days_from_now = (datetime.now() + timedelta(days=14)).date() + + assert _date_key_expires == _fourteen_days_from_now + except Exception as e: + pytest.fail(f"Got exception {e}") + + def test_get_bearer_token(): from litellm.proxy.auth.user_api_key_auth import _get_bearer_token