(QOL improvement) Provider budget routing - allow using 1s, 1d, 1mo, 2mo etc (#6885)

* use 1 file for duration_in_seconds

* add to readme.md

* re use duration_in_seconds

* fix importing _extract_from_regex, get_last_day_of_month

* fix import

* update provider budget routing

* fix - remove dup test
This commit is contained in:
Ishaan Jaff 2024-11-23 16:59:46 -08:00 committed by GitHub
parent e69678a9b3
commit 34bfebe470
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 149 additions and 142 deletions

View file

@ -26,6 +26,11 @@ from typing import (
overload,
)
from litellm.litellm_core_utils.duration_parser import (
_extract_from_regex,
duration_in_seconds,
get_last_day_of_month,
)
from litellm.proxy._types import ProxyErrorTypes, ProxyException
try:
@ -2429,86 +2434,6 @@ def _hash_token_if_needed(token: str) -> str:
return token
def _extract_from_regex(duration: str) -> Tuple[int, str]:
match = re.match(r"(\d+)(mo|[smhd]?)", duration)
if not match:
raise ValueError("Invalid duration format")
value, unit = match.groups()
value = int(value)
return value, unit
def get_last_day_of_month(year, month):
# Handle December case
if month == 12:
return 31
# Next month is January, so subtract a day from March 1st
next_month = datetime(year=year, month=month + 1, day=1)
last_day_of_month = (next_month - timedelta(days=1)).day
return last_day_of_month
def _duration_in_seconds(duration: str) -> int:
"""
Parameters:
- duration:
- "<number>s" - seconds
- "<number>m" - minutes
- "<number>h" - hours
- "<number>d" - days
- "<number>mo" - months
Returns time in seconds till when budget needs to be reset
"""
value, unit = _extract_from_regex(duration=duration)
if unit == "s":
return value
elif unit == "m":
return value * 60
elif unit == "h":
return value * 3600
elif unit == "d":
return value * 86400
elif unit == "mo":
now = time.time()
current_time = datetime.fromtimestamp(now)
if current_time.month == 12:
target_year = current_time.year + 1
target_month = 1
else:
target_year = current_time.year
target_month = current_time.month + value
# Determine the day to set for next month
target_day = current_time.day
last_day_of_target_month = get_last_day_of_month(target_year, target_month)
if target_day > last_day_of_target_month:
target_day = last_day_of_target_month
next_month = datetime(
year=target_year,
month=target_month,
day=target_day,
hour=current_time.hour,
minute=current_time.minute,
second=current_time.second,
microsecond=current_time.microsecond,
)
# Calculate the duration until the first day of the next month
duration_until_next_month = next_month - current_time
return int(duration_until_next_month.total_seconds())
else:
raise ValueError("Unsupported duration unit")
async def reset_budget(prisma_client: PrismaClient):
"""
Gets all the non-expired keys for a db, which need spend to be reset
@ -2527,7 +2452,7 @@ async def reset_budget(prisma_client: PrismaClient):
if keys_to_reset is not None and len(keys_to_reset) > 0:
for key in keys_to_reset:
key.spend = 0.0
duration_s = _duration_in_seconds(duration=key.budget_duration)
duration_s = duration_in_seconds(duration=key.budget_duration)
key.budget_reset_at = now + timedelta(seconds=duration_s)
await prisma_client.update_data(
@ -2543,7 +2468,7 @@ async def reset_budget(prisma_client: PrismaClient):
if users_to_reset is not None and len(users_to_reset) > 0:
for user in users_to_reset:
user.spend = 0.0
duration_s = _duration_in_seconds(duration=user.budget_duration)
duration_s = duration_in_seconds(duration=user.budget_duration)
user.budget_reset_at = now + timedelta(seconds=duration_s)
await prisma_client.update_data(
@ -2561,7 +2486,7 @@ async def reset_budget(prisma_client: PrismaClient):
if teams_to_reset is not None and len(teams_to_reset) > 0:
team_reset_requests = []
for team in teams_to_reset:
duration_s = _duration_in_seconds(duration=team.budget_duration)
duration_s = duration_in_seconds(duration=team.budget_duration)
reset_team_budget_request = ResetTeamBudgetRequest(
team_id=team.team_id,
spend=0.0,