From 34bfebe47033c2db5cab24edba65553f41d63209 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 16:59:46 -0800 Subject: [PATCH] (QOL improvement) Provider budget routing - allow using 1s, 1d, 1mo, 2mo etc (#6885) * use 1 file for duration_in_seconds * add to readme.md * re use duration_in_seconds * fix importing _extract_from_regex, get_last_day_of_month * fix import * update provider budget routing * fix - remove dup test --- .../docs/proxy/provider_budget_routing.md | 24 ++++- litellm/litellm_core_utils/README.md | 1 + litellm/litellm_core_utils/duration_parser.py | 92 +++++++++++++++++++ .../internal_user_endpoints.py | 8 +- .../key_management_endpoints.py | 16 ++-- .../team_callback_endpoints.py | 2 +- .../management_endpoints/team_endpoints.py | 22 ++--- litellm/proxy/proxy_server.py | 2 +- litellm/proxy/utils.py | 91 ++---------------- litellm/router_strategy/provider_budgets.py | 12 +-- .../test_router_provider_budgets.py | 17 ---- tests/local_testing/test_utils.py | 4 +- 12 files changed, 149 insertions(+), 142 deletions(-) create mode 100644 litellm/litellm_core_utils/duration_parser.py diff --git a/docs/my-website/docs/proxy/provider_budget_routing.md b/docs/my-website/docs/proxy/provider_budget_routing.md index fea3f483c..293f9e9d8 100644 --- a/docs/my-website/docs/proxy/provider_budget_routing.md +++ b/docs/my-website/docs/proxy/provider_budget_routing.md @@ -22,14 +22,14 @@ router_settings: provider_budget_config: openai: budget_limit: 0.000000000001 # float of $ value budget for time period - time_period: 1d # can be 1d, 2d, 30d + time_period: 1d # can be 1d, 2d, 30d, 1mo, 2mo azure: budget_limit: 100 time_period: 1d anthropic: budget_limit: 100 time_period: 10d - vertexai: + vertex_ai: budget_limit: 100 time_period: 12d gemini: @@ -112,8 +112,11 @@ Expected response on failure - If all providers exceed budget, raises an error 3. **Supported Time Periods**: - - Format: "Xd" where X is number of days - - Examples: "1d" (1 day), "30d" (30 days) + - Seconds: "Xs" (e.g., "30s") + - Minutes: "Xm" (e.g., "10m") + - Hours: "Xh" (e.g., "24h") + - Days: "Xd" (e.g., "1d", "30d") + - Months: "Xmo" (e.g., "1mo", "2mo") 4. **Requirements**: - Redis required for tracking spend across instances @@ -136,7 +139,12 @@ The `provider_budget_config` is a dictionary where: - **Key**: Provider name (string) - Must be a valid [LiteLLM provider name](https://docs.litellm.ai/docs/providers) - **Value**: Budget configuration object with the following parameters: - `budget_limit`: Float value representing the budget in USD - - `time_period`: String in the format "Xd" where X is the number of days (e.g., "1d", "30d") + - `time_period`: Duration string in one of the following formats: + - Seconds: `"Xs"` (e.g., "30s") + - Minutes: `"Xm"` (e.g., "10m") + - Hours: `"Xh"` (e.g., "24h") + - Days: `"Xd"` (e.g., "1d", "30d") + - Months: `"Xmo"` (e.g., "1mo", "2mo") Example structure: ```yaml @@ -147,4 +155,10 @@ provider_budget_config: azure: budget_limit: 500.0 # $500 USD time_period: "30d" # 30 day period + anthropic: + budget_limit: 200.0 # $200 USD + time_period: "1mo" # 1 month period + gemini: + budget_limit: 50.0 # $50 USD + time_period: "24h" # 24 hour period ``` \ No newline at end of file diff --git a/litellm/litellm_core_utils/README.md b/litellm/litellm_core_utils/README.md index 9cd351453..649404129 100644 --- a/litellm/litellm_core_utils/README.md +++ b/litellm/litellm_core_utils/README.md @@ -8,4 +8,5 @@ Core files: - `exception_mapping_utils.py`: utils for mapping exceptions to openai-compatible error types. - `default_encoding.py`: code for loading the default encoding (tiktoken) - `get_llm_provider_logic.py`: code for inferring the LLM provider from a given model name. +- `duration_parser.py`: code for parsing durations - e.g. "1d", "1mo", "10s" diff --git a/litellm/litellm_core_utils/duration_parser.py b/litellm/litellm_core_utils/duration_parser.py new file mode 100644 index 000000000..c8c6bea83 --- /dev/null +++ b/litellm/litellm_core_utils/duration_parser.py @@ -0,0 +1,92 @@ +""" +Helper utilities for parsing durations - 1s, 1d, 10d, 30d, 1mo, 2mo + +duration_in_seconds is used in diff parts of the code base, example +- Router - Provider budget routing +- Proxy - Key, Team Generation +""" + +import re +import time +from datetime import datetime, timedelta +from typing import Tuple + + +def _extract_from_regex(duration: str) -> Tuple[int, str]: + match = re.match(r"(\d+)(mo|[smhd]?)", duration) + + if not match: + raise ValueError("Invalid duration format") + + value, unit = match.groups() + value = int(value) + + return value, unit + + +def get_last_day_of_month(year, month): + # Handle December case + if month == 12: + return 31 + # Next month is January, so subtract a day from March 1st + next_month = datetime(year=year, month=month + 1, day=1) + last_day_of_month = (next_month - timedelta(days=1)).day + return last_day_of_month + + +def duration_in_seconds(duration: str) -> int: + """ + Parameters: + - duration: + - "s" - seconds + - "m" - minutes + - "h" - hours + - "d" - days + - "mo" - months + + Returns time in seconds till when budget needs to be reset + """ + value, unit = _extract_from_regex(duration=duration) + + if unit == "s": + return value + elif unit == "m": + return value * 60 + elif unit == "h": + return value * 3600 + elif unit == "d": + return value * 86400 + elif unit == "mo": + now = time.time() + current_time = datetime.fromtimestamp(now) + + if current_time.month == 12: + target_year = current_time.year + 1 + target_month = 1 + else: + target_year = current_time.year + target_month = current_time.month + value + + # Determine the day to set for next month + target_day = current_time.day + last_day_of_target_month = get_last_day_of_month(target_year, target_month) + + if target_day > last_day_of_target_month: + target_day = last_day_of_target_month + + next_month = datetime( + year=target_year, + month=target_month, + day=target_day, + hour=current_time.hour, + minute=current_time.minute, + second=current_time.second, + microsecond=current_time.microsecond, + ) + + # Calculate the duration until the first day of the next month + duration_until_next_month = next_month - current_time + return int(duration_until_next_month.total_seconds()) + + else: + raise ValueError(f"Unsupported duration unit, passed duration: {duration}") diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index c69e255f2..c41975f50 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -30,7 +30,7 @@ from litellm._logging import verbose_proxy_logger from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.management_endpoints.key_management_endpoints import ( - _duration_in_seconds, + duration_in_seconds, generate_key_helper_fn, ) from litellm.proxy.management_helpers.utils import ( @@ -516,7 +516,7 @@ async def user_update( is_internal_user = True if "budget_duration" in non_default_values: - duration_s = _duration_in_seconds( + duration_s = duration_in_seconds( duration=non_default_values["budget_duration"] ) user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) @@ -535,7 +535,7 @@ async def user_update( non_default_values["budget_duration"] = ( litellm.internal_user_budget_duration ) - duration_s = _duration_in_seconds( + duration_s = duration_in_seconds( duration=non_default_values["budget_duration"] ) user_reset_at = datetime.now(timezone.utc) + timedelta( @@ -725,8 +725,8 @@ async def delete_user( - user_ids: List[str] - The list of user id's to be deleted. """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, user_api_key_cache, diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 511e5a940..8353b5d91 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -34,8 +34,8 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks from litellm.proxy.management_helpers.utils import management_endpoint_wrapper from litellm.proxy.utils import ( - _duration_in_seconds, _hash_token_if_needed, + duration_in_seconds, handle_exception_on_proxy, ) from litellm.secret_managers.main import get_secret @@ -321,10 +321,10 @@ async def generate_key_fn( # noqa: PLR0915 ) # Compare durations elif key in ["budget_duration", "duration"]: - upperbound_duration = _duration_in_seconds( + upperbound_duration = duration_in_seconds( duration=upperbound_value ) - user_duration = _duration_in_seconds(duration=value) + user_duration = duration_in_seconds(duration=value) if user_duration > upperbound_duration: raise HTTPException( status_code=400, @@ -421,7 +421,7 @@ def prepare_key_update_data( if "duration" in non_default_values: duration = non_default_values.pop("duration") if duration and (isinstance(duration, str)) and len(duration) > 0: - duration_s = _duration_in_seconds(duration=duration) + duration_s = duration_in_seconds(duration=duration) expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s) non_default_values["expires"] = expires @@ -432,7 +432,7 @@ def prepare_key_update_data( and (isinstance(budget_duration, str)) and len(budget_duration) > 0 ): - duration_s = _duration_in_seconds(duration=budget_duration) + duration_s = duration_in_seconds(duration=budget_duration) key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) non_default_values["budget_reset_at"] = key_reset_at @@ -932,19 +932,19 @@ async def generate_key_helper_fn( # noqa: PLR0915 if duration is None: # allow tokens that never expire expires = None else: - duration_s = _duration_in_seconds(duration=duration) + duration_s = duration_in_seconds(duration=duration) expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s) if key_budget_duration is None: # one-time budget key_reset_at = None else: - duration_s = _duration_in_seconds(duration=key_budget_duration) + duration_s = duration_in_seconds(duration=key_budget_duration) key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) if budget_duration is None: # one-time budget reset_at = None else: - duration_s = _duration_in_seconds(duration=budget_duration) + duration_s = duration_in_seconds(duration=budget_duration) reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) aliases_json = json.dumps(aliases) diff --git a/litellm/proxy/management_endpoints/team_callback_endpoints.py b/litellm/proxy/management_endpoints/team_callback_endpoints.py index b60ea3acf..6c5fa80a2 100644 --- a/litellm/proxy/management_endpoints/team_callback_endpoints.py +++ b/litellm/proxy/management_endpoints/team_callback_endpoints.py @@ -90,8 +90,8 @@ async def add_team_callbacks( """ try: from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index dc1ec444d..575e71119 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -169,8 +169,8 @@ async def new_team( # noqa: PLR0915 ``` """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -289,7 +289,7 @@ async def new_team( # noqa: PLR0915 # If budget_duration is set, set `budget_reset_at` if complete_team_data.budget_duration is not None: - duration_s = _duration_in_seconds(duration=complete_team_data.budget_duration) + duration_s = duration_in_seconds(duration=complete_team_data.budget_duration) reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) complete_team_data.budget_reset_at = reset_at @@ -396,8 +396,8 @@ async def update_team( """ from litellm.proxy.auth.auth_checks import _cache_team_object from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, proxy_logging_obj, @@ -425,7 +425,7 @@ async def update_team( # Check budget_duration and budget_reset_at if data.budget_duration is not None: - duration_s = _duration_in_seconds(duration=data.budget_duration) + duration_s = duration_in_seconds(duration=data.budget_duration) reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) # set the budget_reset_at in DB @@ -709,8 +709,8 @@ async def team_member_delete( ``` """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -829,8 +829,8 @@ async def team_member_update( Update team member budgets """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -965,8 +965,8 @@ async def delete_team( ``` """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -1054,8 +1054,8 @@ async def team_info( ``` """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -1203,8 +1203,8 @@ async def block_team( """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -1251,8 +1251,8 @@ async def unblock_team( ``` """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -1294,8 +1294,8 @@ async def list_team( - user_id: str - Optional. If passed will only return teams that the user_id is a member of. """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 70bf5b523..4a867f46a 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -182,8 +182,8 @@ from litellm.proxy.management_endpoints.internal_user_endpoints import ( ) from litellm.proxy.management_endpoints.internal_user_endpoints import user_update from litellm.proxy.management_endpoints.key_management_endpoints import ( - _duration_in_seconds, delete_verification_token, + duration_in_seconds, generate_key_helper_fn, ) from litellm.proxy.management_endpoints.key_management_endpoints import ( diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 0f7d6c3e0..2a298af21 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -26,6 +26,11 @@ from typing import ( overload, ) +from litellm.litellm_core_utils.duration_parser import ( + _extract_from_regex, + duration_in_seconds, + get_last_day_of_month, +) from litellm.proxy._types import ProxyErrorTypes, ProxyException try: @@ -2429,86 +2434,6 @@ def _hash_token_if_needed(token: str) -> str: return token -def _extract_from_regex(duration: str) -> Tuple[int, str]: - match = re.match(r"(\d+)(mo|[smhd]?)", duration) - - if not match: - raise ValueError("Invalid duration format") - - value, unit = match.groups() - value = int(value) - - return value, unit - - -def get_last_day_of_month(year, month): - # Handle December case - if month == 12: - return 31 - # Next month is January, so subtract a day from March 1st - next_month = datetime(year=year, month=month + 1, day=1) - last_day_of_month = (next_month - timedelta(days=1)).day - return last_day_of_month - - -def _duration_in_seconds(duration: str) -> int: - """ - Parameters: - - duration: - - "s" - seconds - - "m" - minutes - - "h" - hours - - "d" - days - - "mo" - months - - Returns time in seconds till when budget needs to be reset - """ - value, unit = _extract_from_regex(duration=duration) - - if unit == "s": - return value - elif unit == "m": - return value * 60 - elif unit == "h": - return value * 3600 - elif unit == "d": - return value * 86400 - elif unit == "mo": - now = time.time() - current_time = datetime.fromtimestamp(now) - - if current_time.month == 12: - target_year = current_time.year + 1 - target_month = 1 - else: - target_year = current_time.year - target_month = current_time.month + value - - # Determine the day to set for next month - target_day = current_time.day - last_day_of_target_month = get_last_day_of_month(target_year, target_month) - - if target_day > last_day_of_target_month: - target_day = last_day_of_target_month - - next_month = datetime( - year=target_year, - month=target_month, - day=target_day, - hour=current_time.hour, - minute=current_time.minute, - second=current_time.second, - microsecond=current_time.microsecond, - ) - - # Calculate the duration until the first day of the next month - duration_until_next_month = next_month - current_time - return int(duration_until_next_month.total_seconds()) - - else: - raise ValueError("Unsupported duration unit") - - async def reset_budget(prisma_client: PrismaClient): """ Gets all the non-expired keys for a db, which need spend to be reset @@ -2527,7 +2452,7 @@ async def reset_budget(prisma_client: PrismaClient): if keys_to_reset is not None and len(keys_to_reset) > 0: for key in keys_to_reset: key.spend = 0.0 - duration_s = _duration_in_seconds(duration=key.budget_duration) + duration_s = duration_in_seconds(duration=key.budget_duration) key.budget_reset_at = now + timedelta(seconds=duration_s) await prisma_client.update_data( @@ -2543,7 +2468,7 @@ async def reset_budget(prisma_client: PrismaClient): if users_to_reset is not None and len(users_to_reset) > 0: for user in users_to_reset: user.spend = 0.0 - duration_s = _duration_in_seconds(duration=user.budget_duration) + duration_s = duration_in_seconds(duration=user.budget_duration) user.budget_reset_at = now + timedelta(seconds=duration_s) await prisma_client.update_data( @@ -2561,7 +2486,7 @@ async def reset_budget(prisma_client: PrismaClient): if teams_to_reset is not None and len(teams_to_reset) > 0: team_reset_requests = [] for team in teams_to_reset: - duration_s = _duration_in_seconds(duration=team.budget_duration) + duration_s = duration_in_seconds(duration=team.budget_duration) reset_team_budget_request = ResetTeamBudgetRequest( team_id=team.team_id, spend=0.0, diff --git a/litellm/router_strategy/provider_budgets.py b/litellm/router_strategy/provider_budgets.py index 23d8b6c39..ea26d2c0f 100644 --- a/litellm/router_strategy/provider_budgets.py +++ b/litellm/router_strategy/provider_budgets.py @@ -25,6 +25,7 @@ from litellm._logging import verbose_router_logger from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs +from litellm.litellm_core_utils.duration_parser import duration_in_seconds from litellm.router_utils.cooldown_callbacks import ( _get_prometheus_logger_from_callbacks, ) @@ -207,7 +208,7 @@ class ProviderBudgetLimiting(CustomLogger): ) spend_key = f"provider_spend:{custom_llm_provider}:{budget_config.time_period}" - ttl_seconds = self.get_ttl_seconds(budget_config.time_period) + ttl_seconds = duration_in_seconds(duration=budget_config.time_period) verbose_router_logger.debug( f"Incrementing spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}" ) @@ -242,15 +243,6 @@ class ProviderBudgetLimiting(CustomLogger): return None return custom_llm_provider - def get_ttl_seconds(self, time_period: str) -> int: - """ - Convert time period (e.g., '1d', '30d') to seconds for Redis TTL - """ - if time_period.endswith("d"): - days = int(time_period[:-1]) - return days * 24 * 60 * 60 - raise ValueError(f"Unsupported time period format: {time_period}") - def _track_provider_remaining_budget_prometheus( self, provider: str, spend: float, budget_limit: float ): diff --git a/tests/local_testing/test_router_provider_budgets.py b/tests/local_testing/test_router_provider_budgets.py index 46b9ee29e..a6574ba4b 100644 --- a/tests/local_testing/test_router_provider_budgets.py +++ b/tests/local_testing/test_router_provider_budgets.py @@ -142,23 +142,6 @@ async def test_provider_budgets_e2e_test_expect_to_fail(): assert "Exceeded budget for provider" in str(exc_info.value) -def test_get_ttl_seconds(): - """ - Test the get_ttl_seconds helper method" - - """ - provider_budget = ProviderBudgetLimiting( - router_cache=DualCache(), provider_budget_config={} - ) - - assert provider_budget.get_ttl_seconds("1d") == 86400 # 1 day in seconds - assert provider_budget.get_ttl_seconds("7d") == 604800 # 7 days in seconds - assert provider_budget.get_ttl_seconds("30d") == 2592000 # 30 days in seconds - - with pytest.raises(ValueError, match="Unsupported time period format"): - provider_budget.get_ttl_seconds("1h") - - def test_get_llm_provider_for_deployment(): """ Test the _get_llm_provider_for_deployment helper method diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py index cf1db27e8..7c349a658 100644 --- a/tests/local_testing/test_utils.py +++ b/tests/local_testing/test_utils.py @@ -17,7 +17,7 @@ import pytest import litellm from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, headers from litellm.proxy.utils import ( - _duration_in_seconds, + duration_in_seconds, _extract_from_regex, get_last_day_of_month, ) @@ -593,7 +593,7 @@ def test_duration_in_seconds(): duration_until_next_month = next_month - current_time expected_duration = int(duration_until_next_month.total_seconds()) - value = _duration_in_seconds(duration="1mo") + value = duration_in_seconds(duration="1mo") assert value - expected_duration < 2