diff --git a/litellm/litellm_core_utils/duration_parser.py b/litellm/litellm_core_utils/duration_parser.py new file mode 100644 index 000000000..c8c6bea83 --- /dev/null +++ b/litellm/litellm_core_utils/duration_parser.py @@ -0,0 +1,92 @@ +""" +Helper utilities for parsing durations - 1s, 1d, 10d, 30d, 1mo, 2mo + +duration_in_seconds is used in diff parts of the code base, example +- Router - Provider budget routing +- Proxy - Key, Team Generation +""" + +import re +import time +from datetime import datetime, timedelta +from typing import Tuple + + +def _extract_from_regex(duration: str) -> Tuple[int, str]: + match = re.match(r"(\d+)(mo|[smhd]?)", duration) + + if not match: + raise ValueError("Invalid duration format") + + value, unit = match.groups() + value = int(value) + + return value, unit + + +def get_last_day_of_month(year, month): + # Handle December case + if month == 12: + return 31 + # Next month is January, so subtract a day from March 1st + next_month = datetime(year=year, month=month + 1, day=1) + last_day_of_month = (next_month - timedelta(days=1)).day + return last_day_of_month + + +def duration_in_seconds(duration: str) -> int: + """ + Parameters: + - duration: + - "s" - seconds + - "m" - minutes + - "h" - hours + - "d" - days + - "mo" - months + + Returns time in seconds till when budget needs to be reset + """ + value, unit = _extract_from_regex(duration=duration) + + if unit == "s": + return value + elif unit == "m": + return value * 60 + elif unit == "h": + return value * 3600 + elif unit == "d": + return value * 86400 + elif unit == "mo": + now = time.time() + current_time = datetime.fromtimestamp(now) + + if current_time.month == 12: + target_year = current_time.year + 1 + target_month = 1 + else: + target_year = current_time.year + target_month = current_time.month + value + + # Determine the day to set for next month + target_day = current_time.day + last_day_of_target_month = get_last_day_of_month(target_year, target_month) + + if target_day > last_day_of_target_month: + target_day = last_day_of_target_month + + next_month = datetime( + year=target_year, + month=target_month, + day=target_day, + hour=current_time.hour, + minute=current_time.minute, + second=current_time.second, + microsecond=current_time.microsecond, + ) + + # Calculate the duration until the first day of the next month + duration_until_next_month = next_month - current_time + return int(duration_until_next_month.total_seconds()) + + else: + raise ValueError(f"Unsupported duration unit, passed duration: {duration}") diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index c69e255f2..c41975f50 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -30,7 +30,7 @@ from litellm._logging import verbose_proxy_logger from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.management_endpoints.key_management_endpoints import ( - _duration_in_seconds, + duration_in_seconds, generate_key_helper_fn, ) from litellm.proxy.management_helpers.utils import ( @@ -516,7 +516,7 @@ async def user_update( is_internal_user = True if "budget_duration" in non_default_values: - duration_s = _duration_in_seconds( + duration_s = duration_in_seconds( duration=non_default_values["budget_duration"] ) user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) @@ -535,7 +535,7 @@ async def user_update( non_default_values["budget_duration"] = ( litellm.internal_user_budget_duration ) - duration_s = _duration_in_seconds( + duration_s = duration_in_seconds( duration=non_default_values["budget_duration"] ) user_reset_at = datetime.now(timezone.utc) + timedelta( @@ -725,8 +725,8 @@ async def delete_user( - user_ids: List[str] - The list of user id's to be deleted. """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, user_api_key_cache, diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 511e5a940..8353b5d91 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -34,8 +34,8 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks from litellm.proxy.management_helpers.utils import management_endpoint_wrapper from litellm.proxy.utils import ( - _duration_in_seconds, _hash_token_if_needed, + duration_in_seconds, handle_exception_on_proxy, ) from litellm.secret_managers.main import get_secret @@ -321,10 +321,10 @@ async def generate_key_fn( # noqa: PLR0915 ) # Compare durations elif key in ["budget_duration", "duration"]: - upperbound_duration = _duration_in_seconds( + upperbound_duration = duration_in_seconds( duration=upperbound_value ) - user_duration = _duration_in_seconds(duration=value) + user_duration = duration_in_seconds(duration=value) if user_duration > upperbound_duration: raise HTTPException( status_code=400, @@ -421,7 +421,7 @@ def prepare_key_update_data( if "duration" in non_default_values: duration = non_default_values.pop("duration") if duration and (isinstance(duration, str)) and len(duration) > 0: - duration_s = _duration_in_seconds(duration=duration) + duration_s = duration_in_seconds(duration=duration) expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s) non_default_values["expires"] = expires @@ -432,7 +432,7 @@ def prepare_key_update_data( and (isinstance(budget_duration, str)) and len(budget_duration) > 0 ): - duration_s = _duration_in_seconds(duration=budget_duration) + duration_s = duration_in_seconds(duration=budget_duration) key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) non_default_values["budget_reset_at"] = key_reset_at @@ -932,19 +932,19 @@ async def generate_key_helper_fn( # noqa: PLR0915 if duration is None: # allow tokens that never expire expires = None else: - duration_s = _duration_in_seconds(duration=duration) + duration_s = duration_in_seconds(duration=duration) expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s) if key_budget_duration is None: # one-time budget key_reset_at = None else: - duration_s = _duration_in_seconds(duration=key_budget_duration) + duration_s = duration_in_seconds(duration=key_budget_duration) key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) if budget_duration is None: # one-time budget reset_at = None else: - duration_s = _duration_in_seconds(duration=budget_duration) + duration_s = duration_in_seconds(duration=budget_duration) reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) aliases_json = json.dumps(aliases) diff --git a/litellm/proxy/management_endpoints/team_callback_endpoints.py b/litellm/proxy/management_endpoints/team_callback_endpoints.py index b60ea3acf..6c5fa80a2 100644 --- a/litellm/proxy/management_endpoints/team_callback_endpoints.py +++ b/litellm/proxy/management_endpoints/team_callback_endpoints.py @@ -90,8 +90,8 @@ async def add_team_callbacks( """ try: from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index dc1ec444d..575e71119 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -169,8 +169,8 @@ async def new_team( # noqa: PLR0915 ``` """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -289,7 +289,7 @@ async def new_team( # noqa: PLR0915 # If budget_duration is set, set `budget_reset_at` if complete_team_data.budget_duration is not None: - duration_s = _duration_in_seconds(duration=complete_team_data.budget_duration) + duration_s = duration_in_seconds(duration=complete_team_data.budget_duration) reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) complete_team_data.budget_reset_at = reset_at @@ -396,8 +396,8 @@ async def update_team( """ from litellm.proxy.auth.auth_checks import _cache_team_object from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, proxy_logging_obj, @@ -425,7 +425,7 @@ async def update_team( # Check budget_duration and budget_reset_at if data.budget_duration is not None: - duration_s = _duration_in_seconds(duration=data.budget_duration) + duration_s = duration_in_seconds(duration=data.budget_duration) reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) # set the budget_reset_at in DB @@ -709,8 +709,8 @@ async def team_member_delete( ``` """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -829,8 +829,8 @@ async def team_member_update( Update team member budgets """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -965,8 +965,8 @@ async def delete_team( ``` """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -1054,8 +1054,8 @@ async def team_info( ``` """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -1203,8 +1203,8 @@ async def block_team( """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -1251,8 +1251,8 @@ async def unblock_team( ``` """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -1294,8 +1294,8 @@ async def list_team( - user_id: str - Optional. If passed will only return teams that the user_id is a member of. """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 70bf5b523..4a867f46a 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -182,8 +182,8 @@ from litellm.proxy.management_endpoints.internal_user_endpoints import ( ) from litellm.proxy.management_endpoints.internal_user_endpoints import user_update from litellm.proxy.management_endpoints.key_management_endpoints import ( - _duration_in_seconds, delete_verification_token, + duration_in_seconds, generate_key_helper_fn, ) from litellm.proxy.management_endpoints.key_management_endpoints import ( diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 0f7d6c3e0..a75a09b29 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -26,6 +26,7 @@ from typing import ( overload, ) +from litellm.litellm_core_utils.duration_parser import duration_in_seconds from litellm.proxy._types import ProxyErrorTypes, ProxyException try: @@ -2429,86 +2430,6 @@ def _hash_token_if_needed(token: str) -> str: return token -def _extract_from_regex(duration: str) -> Tuple[int, str]: - match = re.match(r"(\d+)(mo|[smhd]?)", duration) - - if not match: - raise ValueError("Invalid duration format") - - value, unit = match.groups() - value = int(value) - - return value, unit - - -def get_last_day_of_month(year, month): - # Handle December case - if month == 12: - return 31 - # Next month is January, so subtract a day from March 1st - next_month = datetime(year=year, month=month + 1, day=1) - last_day_of_month = (next_month - timedelta(days=1)).day - return last_day_of_month - - -def _duration_in_seconds(duration: str) -> int: - """ - Parameters: - - duration: - - "s" - seconds - - "m" - minutes - - "h" - hours - - "d" - days - - "mo" - months - - Returns time in seconds till when budget needs to be reset - """ - value, unit = _extract_from_regex(duration=duration) - - if unit == "s": - return value - elif unit == "m": - return value * 60 - elif unit == "h": - return value * 3600 - elif unit == "d": - return value * 86400 - elif unit == "mo": - now = time.time() - current_time = datetime.fromtimestamp(now) - - if current_time.month == 12: - target_year = current_time.year + 1 - target_month = 1 - else: - target_year = current_time.year - target_month = current_time.month + value - - # Determine the day to set for next month - target_day = current_time.day - last_day_of_target_month = get_last_day_of_month(target_year, target_month) - - if target_day > last_day_of_target_month: - target_day = last_day_of_target_month - - next_month = datetime( - year=target_year, - month=target_month, - day=target_day, - hour=current_time.hour, - minute=current_time.minute, - second=current_time.second, - microsecond=current_time.microsecond, - ) - - # Calculate the duration until the first day of the next month - duration_until_next_month = next_month - current_time - return int(duration_until_next_month.total_seconds()) - - else: - raise ValueError("Unsupported duration unit") - - async def reset_budget(prisma_client: PrismaClient): """ Gets all the non-expired keys for a db, which need spend to be reset @@ -2527,7 +2448,7 @@ async def reset_budget(prisma_client: PrismaClient): if keys_to_reset is not None and len(keys_to_reset) > 0: for key in keys_to_reset: key.spend = 0.0 - duration_s = _duration_in_seconds(duration=key.budget_duration) + duration_s = duration_in_seconds(duration=key.budget_duration) key.budget_reset_at = now + timedelta(seconds=duration_s) await prisma_client.update_data( @@ -2543,7 +2464,7 @@ async def reset_budget(prisma_client: PrismaClient): if users_to_reset is not None and len(users_to_reset) > 0: for user in users_to_reset: user.spend = 0.0 - duration_s = _duration_in_seconds(duration=user.budget_duration) + duration_s = duration_in_seconds(duration=user.budget_duration) user.budget_reset_at = now + timedelta(seconds=duration_s) await prisma_client.update_data( @@ -2561,7 +2482,7 @@ async def reset_budget(prisma_client: PrismaClient): if teams_to_reset is not None and len(teams_to_reset) > 0: team_reset_requests = [] for team in teams_to_reset: - duration_s = _duration_in_seconds(duration=team.budget_duration) + duration_s = duration_in_seconds(duration=team.budget_duration) reset_team_budget_request = ResetTeamBudgetRequest( team_id=team.team_id, spend=0.0, diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py index cf1db27e8..70a7eff59 100644 --- a/tests/local_testing/test_utils.py +++ b/tests/local_testing/test_utils.py @@ -17,7 +17,7 @@ import pytest import litellm from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, headers from litellm.proxy.utils import ( - _duration_in_seconds, + duration_in_seconds, _extract_from_regex, get_last_day_of_month, ) @@ -556,7 +556,7 @@ def test_extract_from_regex(duration, unit): assert _unit == unit -def test_duration_in_seconds(): +def testduration_in_seconds(): """ Test if duration int is correctly calculated for different str """ @@ -593,7 +593,7 @@ def test_duration_in_seconds(): duration_until_next_month = next_month - current_time expected_duration = int(duration_until_next_month.total_seconds()) - value = _duration_in_seconds(duration="1mo") + value = duration_in_seconds(duration="1mo") assert value - expected_duration < 2