use 1 file for duration_in_seconds

This commit is contained in:
Ishaan Jaff 2024-11-23 12:42:33 -08:00
parent 50314a66ca
commit 37462ea55c
8 changed files with 124 additions and 111 deletions

View file

@ -0,0 +1,92 @@
"""
Helper utilities for parsing durations - 1s, 1d, 10d, 30d, 1mo, 2mo
duration_in_seconds is used in diff parts of the code base, example
- Router - Provider budget routing
- Proxy - Key, Team Generation
"""
import re
import time
from datetime import datetime, timedelta
from typing import Tuple
def _extract_from_regex(duration: str) -> Tuple[int, str]:
match = re.match(r"(\d+)(mo|[smhd]?)", duration)
if not match:
raise ValueError("Invalid duration format")
value, unit = match.groups()
value = int(value)
return value, unit
def get_last_day_of_month(year, month):
# Handle December case
if month == 12:
return 31
# Next month is January, so subtract a day from March 1st
next_month = datetime(year=year, month=month + 1, day=1)
last_day_of_month = (next_month - timedelta(days=1)).day
return last_day_of_month
def duration_in_seconds(duration: str) -> int:
"""
Parameters:
- duration:
- "<number>s" - seconds
- "<number>m" - minutes
- "<number>h" - hours
- "<number>d" - days
- "<number>mo" - months
Returns time in seconds till when budget needs to be reset
"""
value, unit = _extract_from_regex(duration=duration)
if unit == "s":
return value
elif unit == "m":
return value * 60
elif unit == "h":
return value * 3600
elif unit == "d":
return value * 86400
elif unit == "mo":
now = time.time()
current_time = datetime.fromtimestamp(now)
if current_time.month == 12:
target_year = current_time.year + 1
target_month = 1
else:
target_year = current_time.year
target_month = current_time.month + value
# Determine the day to set for next month
target_day = current_time.day
last_day_of_target_month = get_last_day_of_month(target_year, target_month)
if target_day > last_day_of_target_month:
target_day = last_day_of_target_month
next_month = datetime(
year=target_year,
month=target_month,
day=target_day,
hour=current_time.hour,
minute=current_time.minute,
second=current_time.second,
microsecond=current_time.microsecond,
)
# Calculate the duration until the first day of the next month
duration_until_next_month = next_month - current_time
return int(duration_until_next_month.total_seconds())
else:
raise ValueError(f"Unsupported duration unit, passed duration: {duration}")

View file

@ -30,7 +30,7 @@ from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.key_management_endpoints import (
_duration_in_seconds,
duration_in_seconds,
generate_key_helper_fn,
)
from litellm.proxy.management_helpers.utils import (
@ -516,7 +516,7 @@ async def user_update(
is_internal_user = True
if "budget_duration" in non_default_values:
duration_s = _duration_in_seconds(
duration_s = duration_in_seconds(
duration=non_default_values["budget_duration"]
)
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
@ -535,7 +535,7 @@ async def user_update(
non_default_values["budget_duration"] = (
litellm.internal_user_budget_duration
)
duration_s = _duration_in_seconds(
duration_s = duration_in_seconds(
duration=non_default_values["budget_duration"]
)
user_reset_at = datetime.now(timezone.utc) + timedelta(
@ -725,8 +725,8 @@ async def delete_user(
- user_ids: List[str] - The list of user id's to be deleted.
"""
from litellm.proxy.proxy_server import (
_duration_in_seconds,
create_audit_log_for_update,
duration_in_seconds,
litellm_proxy_admin_name,
prisma_client,
user_api_key_cache,

View file

@ -34,8 +34,8 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import (
_duration_in_seconds,
_hash_token_if_needed,
duration_in_seconds,
handle_exception_on_proxy,
)
from litellm.secret_managers.main import get_secret
@ -321,10 +321,10 @@ async def generate_key_fn( # noqa: PLR0915
)
# Compare durations
elif key in ["budget_duration", "duration"]:
upperbound_duration = _duration_in_seconds(
upperbound_duration = duration_in_seconds(
duration=upperbound_value
)
user_duration = _duration_in_seconds(duration=value)
user_duration = duration_in_seconds(duration=value)
if user_duration > upperbound_duration:
raise HTTPException(
status_code=400,
@ -421,7 +421,7 @@ def prepare_key_update_data(
if "duration" in non_default_values:
duration = non_default_values.pop("duration")
if duration and (isinstance(duration, str)) and len(duration) > 0:
duration_s = _duration_in_seconds(duration=duration)
duration_s = duration_in_seconds(duration=duration)
expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
non_default_values["expires"] = expires
@ -432,7 +432,7 @@ def prepare_key_update_data(
and (isinstance(budget_duration, str))
and len(budget_duration) > 0
):
duration_s = _duration_in_seconds(duration=budget_duration)
duration_s = duration_in_seconds(duration=budget_duration)
key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
non_default_values["budget_reset_at"] = key_reset_at
@ -932,19 +932,19 @@ async def generate_key_helper_fn( # noqa: PLR0915
if duration is None: # allow tokens that never expire
expires = None
else:
duration_s = _duration_in_seconds(duration=duration)
duration_s = duration_in_seconds(duration=duration)
expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
if key_budget_duration is None: # one-time budget
key_reset_at = None
else:
duration_s = _duration_in_seconds(duration=key_budget_duration)
duration_s = duration_in_seconds(duration=key_budget_duration)
key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
if budget_duration is None: # one-time budget
reset_at = None
else:
duration_s = _duration_in_seconds(duration=budget_duration)
duration_s = duration_in_seconds(duration=budget_duration)
reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
aliases_json = json.dumps(aliases)

View file

@ -90,8 +90,8 @@ async def add_team_callbacks(
"""
try:
from litellm.proxy.proxy_server import (
_duration_in_seconds,
create_audit_log_for_update,
duration_in_seconds,
litellm_proxy_admin_name,
prisma_client,
)

View file

@ -169,8 +169,8 @@ async def new_team( # noqa: PLR0915
```
"""
from litellm.proxy.proxy_server import (
_duration_in_seconds,
create_audit_log_for_update,
duration_in_seconds,
litellm_proxy_admin_name,
prisma_client,
)
@ -289,7 +289,7 @@ async def new_team( # noqa: PLR0915
# If budget_duration is set, set `budget_reset_at`
if complete_team_data.budget_duration is not None:
duration_s = _duration_in_seconds(duration=complete_team_data.budget_duration)
duration_s = duration_in_seconds(duration=complete_team_data.budget_duration)
reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
complete_team_data.budget_reset_at = reset_at
@ -396,8 +396,8 @@ async def update_team(
"""
from litellm.proxy.auth.auth_checks import _cache_team_object
from litellm.proxy.proxy_server import (
_duration_in_seconds,
create_audit_log_for_update,
duration_in_seconds,
litellm_proxy_admin_name,
prisma_client,
proxy_logging_obj,
@ -425,7 +425,7 @@ async def update_team(
# Check budget_duration and budget_reset_at
if data.budget_duration is not None:
duration_s = _duration_in_seconds(duration=data.budget_duration)
duration_s = duration_in_seconds(duration=data.budget_duration)
reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
# set the budget_reset_at in DB
@ -709,8 +709,8 @@ async def team_member_delete(
```
"""
from litellm.proxy.proxy_server import (
_duration_in_seconds,
create_audit_log_for_update,
duration_in_seconds,
litellm_proxy_admin_name,
prisma_client,
)
@ -829,8 +829,8 @@ async def team_member_update(
Update team member budgets
"""
from litellm.proxy.proxy_server import (
_duration_in_seconds,
create_audit_log_for_update,
duration_in_seconds,
litellm_proxy_admin_name,
prisma_client,
)
@ -965,8 +965,8 @@ async def delete_team(
```
"""
from litellm.proxy.proxy_server import (
_duration_in_seconds,
create_audit_log_for_update,
duration_in_seconds,
litellm_proxy_admin_name,
prisma_client,
)
@ -1054,8 +1054,8 @@ async def team_info(
```
"""
from litellm.proxy.proxy_server import (
_duration_in_seconds,
create_audit_log_for_update,
duration_in_seconds,
litellm_proxy_admin_name,
prisma_client,
)
@ -1203,8 +1203,8 @@ async def block_team(
"""
from litellm.proxy.proxy_server import (
_duration_in_seconds,
create_audit_log_for_update,
duration_in_seconds,
litellm_proxy_admin_name,
prisma_client,
)
@ -1251,8 +1251,8 @@ async def unblock_team(
```
"""
from litellm.proxy.proxy_server import (
_duration_in_seconds,
create_audit_log_for_update,
duration_in_seconds,
litellm_proxy_admin_name,
prisma_client,
)
@ -1294,8 +1294,8 @@ async def list_team(
- user_id: str - Optional. If passed will only return teams that the user_id is a member of.
"""
from litellm.proxy.proxy_server import (
_duration_in_seconds,
create_audit_log_for_update,
duration_in_seconds,
litellm_proxy_admin_name,
prisma_client,
)

View file

@ -182,8 +182,8 @@ from litellm.proxy.management_endpoints.internal_user_endpoints import (
)
from litellm.proxy.management_endpoints.internal_user_endpoints import user_update
from litellm.proxy.management_endpoints.key_management_endpoints import (
_duration_in_seconds,
delete_verification_token,
duration_in_seconds,
generate_key_helper_fn,
)
from litellm.proxy.management_endpoints.key_management_endpoints import (

View file

@ -26,6 +26,7 @@ from typing import (
overload,
)
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.proxy._types import ProxyErrorTypes, ProxyException
try:
@ -2429,86 +2430,6 @@ def _hash_token_if_needed(token: str) -> str:
return token
def _extract_from_regex(duration: str) -> Tuple[int, str]:
match = re.match(r"(\d+)(mo|[smhd]?)", duration)
if not match:
raise ValueError("Invalid duration format")
value, unit = match.groups()
value = int(value)
return value, unit
def get_last_day_of_month(year, month):
# Handle December case
if month == 12:
return 31
# Next month is January, so subtract a day from March 1st
next_month = datetime(year=year, month=month + 1, day=1)
last_day_of_month = (next_month - timedelta(days=1)).day
return last_day_of_month
def _duration_in_seconds(duration: str) -> int:
"""
Parameters:
- duration:
- "<number>s" - seconds
- "<number>m" - minutes
- "<number>h" - hours
- "<number>d" - days
- "<number>mo" - months
Returns time in seconds till when budget needs to be reset
"""
value, unit = _extract_from_regex(duration=duration)
if unit == "s":
return value
elif unit == "m":
return value * 60
elif unit == "h":
return value * 3600
elif unit == "d":
return value * 86400
elif unit == "mo":
now = time.time()
current_time = datetime.fromtimestamp(now)
if current_time.month == 12:
target_year = current_time.year + 1
target_month = 1
else:
target_year = current_time.year
target_month = current_time.month + value
# Determine the day to set for next month
target_day = current_time.day
last_day_of_target_month = get_last_day_of_month(target_year, target_month)
if target_day > last_day_of_target_month:
target_day = last_day_of_target_month
next_month = datetime(
year=target_year,
month=target_month,
day=target_day,
hour=current_time.hour,
minute=current_time.minute,
second=current_time.second,
microsecond=current_time.microsecond,
)
# Calculate the duration until the first day of the next month
duration_until_next_month = next_month - current_time
return int(duration_until_next_month.total_seconds())
else:
raise ValueError("Unsupported duration unit")
async def reset_budget(prisma_client: PrismaClient):
"""
Gets all the non-expired keys for a db, which need spend to be reset
@ -2527,7 +2448,7 @@ async def reset_budget(prisma_client: PrismaClient):
if keys_to_reset is not None and len(keys_to_reset) > 0:
for key in keys_to_reset:
key.spend = 0.0
duration_s = _duration_in_seconds(duration=key.budget_duration)
duration_s = duration_in_seconds(duration=key.budget_duration)
key.budget_reset_at = now + timedelta(seconds=duration_s)
await prisma_client.update_data(
@ -2543,7 +2464,7 @@ async def reset_budget(prisma_client: PrismaClient):
if users_to_reset is not None and len(users_to_reset) > 0:
for user in users_to_reset:
user.spend = 0.0
duration_s = _duration_in_seconds(duration=user.budget_duration)
duration_s = duration_in_seconds(duration=user.budget_duration)
user.budget_reset_at = now + timedelta(seconds=duration_s)
await prisma_client.update_data(
@ -2561,7 +2482,7 @@ async def reset_budget(prisma_client: PrismaClient):
if teams_to_reset is not None and len(teams_to_reset) > 0:
team_reset_requests = []
for team in teams_to_reset:
duration_s = _duration_in_seconds(duration=team.budget_duration)
duration_s = duration_in_seconds(duration=team.budget_duration)
reset_team_budget_request = ResetTeamBudgetRequest(
team_id=team.team_id,
spend=0.0,

View file

@ -17,7 +17,7 @@ import pytest
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, headers
from litellm.proxy.utils import (
_duration_in_seconds,
duration_in_seconds,
_extract_from_regex,
get_last_day_of_month,
)
@ -556,7 +556,7 @@ def test_extract_from_regex(duration, unit):
assert _unit == unit
def test_duration_in_seconds():
def testduration_in_seconds():
"""
Test if duration int is correctly calculated for different str
"""
@ -593,7 +593,7 @@ def test_duration_in_seconds():
duration_until_next_month = next_month - current_time
expected_duration = int(duration_until_next_month.total_seconds())
value = _duration_in_seconds(duration="1mo")
value = duration_in_seconds(duration="1mo")
assert value - expected_duration < 2