From 37462ea55c1a63350fbd97bafba98d23384af429 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 12:42:33 -0800 Subject: [PATCH 01/29] use 1 file for duration_in_seconds --- litellm/litellm_core_utils/duration_parser.py | 92 +++++++++++++++++++ .../internal_user_endpoints.py | 8 +- .../key_management_endpoints.py | 16 ++-- .../team_callback_endpoints.py | 2 +- .../management_endpoints/team_endpoints.py | 22 ++--- litellm/proxy/proxy_server.py | 2 +- litellm/proxy/utils.py | 87 +----------------- tests/local_testing/test_utils.py | 6 +- 8 files changed, 124 insertions(+), 111 deletions(-) create mode 100644 litellm/litellm_core_utils/duration_parser.py diff --git a/litellm/litellm_core_utils/duration_parser.py b/litellm/litellm_core_utils/duration_parser.py new file mode 100644 index 000000000..c8c6bea83 --- /dev/null +++ b/litellm/litellm_core_utils/duration_parser.py @@ -0,0 +1,92 @@ +""" +Helper utilities for parsing durations - 1s, 1d, 10d, 30d, 1mo, 2mo + +duration_in_seconds is used in diff parts of the code base, example +- Router - Provider budget routing +- Proxy - Key, Team Generation +""" + +import re +import time +from datetime import datetime, timedelta +from typing import Tuple + + +def _extract_from_regex(duration: str) -> Tuple[int, str]: + match = re.match(r"(\d+)(mo|[smhd]?)", duration) + + if not match: + raise ValueError("Invalid duration format") + + value, unit = match.groups() + value = int(value) + + return value, unit + + +def get_last_day_of_month(year, month): + # Handle December case + if month == 12: + return 31 + # Next month is January, so subtract a day from March 1st + next_month = datetime(year=year, month=month + 1, day=1) + last_day_of_month = (next_month - timedelta(days=1)).day + return last_day_of_month + + +def duration_in_seconds(duration: str) -> int: + """ + Parameters: + - duration: + - "s" - seconds + - "m" - minutes + - "h" - hours + - "d" - days + - "mo" - months + + Returns time in seconds till when 
budget needs to be reset + """ + value, unit = _extract_from_regex(duration=duration) + + if unit == "s": + return value + elif unit == "m": + return value * 60 + elif unit == "h": + return value * 3600 + elif unit == "d": + return value * 86400 + elif unit == "mo": + now = time.time() + current_time = datetime.fromtimestamp(now) + + if current_time.month == 12: + target_year = current_time.year + 1 + target_month = 1 + else: + target_year = current_time.year + target_month = current_time.month + value + + # Determine the day to set for next month + target_day = current_time.day + last_day_of_target_month = get_last_day_of_month(target_year, target_month) + + if target_day > last_day_of_target_month: + target_day = last_day_of_target_month + + next_month = datetime( + year=target_year, + month=target_month, + day=target_day, + hour=current_time.hour, + minute=current_time.minute, + second=current_time.second, + microsecond=current_time.microsecond, + ) + + # Calculate the duration until the first day of the next month + duration_until_next_month = next_month - current_time + return int(duration_until_next_month.total_seconds()) + + else: + raise ValueError(f"Unsupported duration unit, passed duration: {duration}") diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index c69e255f2..c41975f50 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -30,7 +30,7 @@ from litellm._logging import verbose_proxy_logger from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.management_endpoints.key_management_endpoints import ( - _duration_in_seconds, + duration_in_seconds, generate_key_helper_fn, ) from litellm.proxy.management_helpers.utils import ( @@ -516,7 +516,7 @@ async def user_update( is_internal_user = True if 
"budget_duration" in non_default_values: - duration_s = _duration_in_seconds( + duration_s = duration_in_seconds( duration=non_default_values["budget_duration"] ) user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) @@ -535,7 +535,7 @@ async def user_update( non_default_values["budget_duration"] = ( litellm.internal_user_budget_duration ) - duration_s = _duration_in_seconds( + duration_s = duration_in_seconds( duration=non_default_values["budget_duration"] ) user_reset_at = datetime.now(timezone.utc) + timedelta( @@ -725,8 +725,8 @@ async def delete_user( - user_ids: List[str] - The list of user id's to be deleted. """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, user_api_key_cache, diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 511e5a940..8353b5d91 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -34,8 +34,8 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks from litellm.proxy.management_helpers.utils import management_endpoint_wrapper from litellm.proxy.utils import ( - _duration_in_seconds, _hash_token_if_needed, + duration_in_seconds, handle_exception_on_proxy, ) from litellm.secret_managers.main import get_secret @@ -321,10 +321,10 @@ async def generate_key_fn( # noqa: PLR0915 ) # Compare durations elif key in ["budget_duration", "duration"]: - upperbound_duration = _duration_in_seconds( + upperbound_duration = duration_in_seconds( duration=upperbound_value ) - user_duration = _duration_in_seconds(duration=value) + user_duration = duration_in_seconds(duration=value) if user_duration > upperbound_duration: raise HTTPException( 
status_code=400, @@ -421,7 +421,7 @@ def prepare_key_update_data( if "duration" in non_default_values: duration = non_default_values.pop("duration") if duration and (isinstance(duration, str)) and len(duration) > 0: - duration_s = _duration_in_seconds(duration=duration) + duration_s = duration_in_seconds(duration=duration) expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s) non_default_values["expires"] = expires @@ -432,7 +432,7 @@ def prepare_key_update_data( and (isinstance(budget_duration, str)) and len(budget_duration) > 0 ): - duration_s = _duration_in_seconds(duration=budget_duration) + duration_s = duration_in_seconds(duration=budget_duration) key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) non_default_values["budget_reset_at"] = key_reset_at @@ -932,19 +932,19 @@ async def generate_key_helper_fn( # noqa: PLR0915 if duration is None: # allow tokens that never expire expires = None else: - duration_s = _duration_in_seconds(duration=duration) + duration_s = duration_in_seconds(duration=duration) expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s) if key_budget_duration is None: # one-time budget key_reset_at = None else: - duration_s = _duration_in_seconds(duration=key_budget_duration) + duration_s = duration_in_seconds(duration=key_budget_duration) key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) if budget_duration is None: # one-time budget reset_at = None else: - duration_s = _duration_in_seconds(duration=budget_duration) + duration_s = duration_in_seconds(duration=budget_duration) reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) aliases_json = json.dumps(aliases) diff --git a/litellm/proxy/management_endpoints/team_callback_endpoints.py b/litellm/proxy/management_endpoints/team_callback_endpoints.py index b60ea3acf..6c5fa80a2 100644 --- a/litellm/proxy/management_endpoints/team_callback_endpoints.py +++ 
b/litellm/proxy/management_endpoints/team_callback_endpoints.py @@ -90,8 +90,8 @@ async def add_team_callbacks( """ try: from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index dc1ec444d..575e71119 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -169,8 +169,8 @@ async def new_team( # noqa: PLR0915 ``` """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -289,7 +289,7 @@ async def new_team( # noqa: PLR0915 # If budget_duration is set, set `budget_reset_at` if complete_team_data.budget_duration is not None: - duration_s = _duration_in_seconds(duration=complete_team_data.budget_duration) + duration_s = duration_in_seconds(duration=complete_team_data.budget_duration) reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) complete_team_data.budget_reset_at = reset_at @@ -396,8 +396,8 @@ async def update_team( """ from litellm.proxy.auth.auth_checks import _cache_team_object from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, proxy_logging_obj, @@ -425,7 +425,7 @@ async def update_team( # Check budget_duration and budget_reset_at if data.budget_duration is not None: - duration_s = _duration_in_seconds(duration=data.budget_duration) + duration_s = duration_in_seconds(duration=data.budget_duration) reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) # set the budget_reset_at in DB @@ -709,8 +709,8 @@ async def team_member_delete( ``` """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, 
create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -829,8 +829,8 @@ async def team_member_update( Update team member budgets """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -965,8 +965,8 @@ async def delete_team( ``` """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -1054,8 +1054,8 @@ async def team_info( ``` """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -1203,8 +1203,8 @@ async def block_team( """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -1251,8 +1251,8 @@ async def unblock_team( ``` """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) @@ -1294,8 +1294,8 @@ async def list_team( - user_id: str - Optional. If passed will only return teams that the user_id is a member of. 
""" from litellm.proxy.proxy_server import ( - _duration_in_seconds, create_audit_log_for_update, + duration_in_seconds, litellm_proxy_admin_name, prisma_client, ) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 70bf5b523..4a867f46a 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -182,8 +182,8 @@ from litellm.proxy.management_endpoints.internal_user_endpoints import ( ) from litellm.proxy.management_endpoints.internal_user_endpoints import user_update from litellm.proxy.management_endpoints.key_management_endpoints import ( - _duration_in_seconds, delete_verification_token, + duration_in_seconds, generate_key_helper_fn, ) from litellm.proxy.management_endpoints.key_management_endpoints import ( diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 0f7d6c3e0..a75a09b29 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -26,6 +26,7 @@ from typing import ( overload, ) +from litellm.litellm_core_utils.duration_parser import duration_in_seconds from litellm.proxy._types import ProxyErrorTypes, ProxyException try: @@ -2429,86 +2430,6 @@ def _hash_token_if_needed(token: str) -> str: return token -def _extract_from_regex(duration: str) -> Tuple[int, str]: - match = re.match(r"(\d+)(mo|[smhd]?)", duration) - - if not match: - raise ValueError("Invalid duration format") - - value, unit = match.groups() - value = int(value) - - return value, unit - - -def get_last_day_of_month(year, month): - # Handle December case - if month == 12: - return 31 - # Next month is January, so subtract a day from March 1st - next_month = datetime(year=year, month=month + 1, day=1) - last_day_of_month = (next_month - timedelta(days=1)).day - return last_day_of_month - - -def _duration_in_seconds(duration: str) -> int: - """ - Parameters: - - duration: - - "s" - seconds - - "m" - minutes - - "h" - hours - - "d" - days - - "mo" - months - - Returns time in seconds till when budget needs to be reset - 
""" - value, unit = _extract_from_regex(duration=duration) - - if unit == "s": - return value - elif unit == "m": - return value * 60 - elif unit == "h": - return value * 3600 - elif unit == "d": - return value * 86400 - elif unit == "mo": - now = time.time() - current_time = datetime.fromtimestamp(now) - - if current_time.month == 12: - target_year = current_time.year + 1 - target_month = 1 - else: - target_year = current_time.year - target_month = current_time.month + value - - # Determine the day to set for next month - target_day = current_time.day - last_day_of_target_month = get_last_day_of_month(target_year, target_month) - - if target_day > last_day_of_target_month: - target_day = last_day_of_target_month - - next_month = datetime( - year=target_year, - month=target_month, - day=target_day, - hour=current_time.hour, - minute=current_time.minute, - second=current_time.second, - microsecond=current_time.microsecond, - ) - - # Calculate the duration until the first day of the next month - duration_until_next_month = next_month - current_time - return int(duration_until_next_month.total_seconds()) - - else: - raise ValueError("Unsupported duration unit") - - async def reset_budget(prisma_client: PrismaClient): """ Gets all the non-expired keys for a db, which need spend to be reset @@ -2527,7 +2448,7 @@ async def reset_budget(prisma_client: PrismaClient): if keys_to_reset is not None and len(keys_to_reset) > 0: for key in keys_to_reset: key.spend = 0.0 - duration_s = _duration_in_seconds(duration=key.budget_duration) + duration_s = duration_in_seconds(duration=key.budget_duration) key.budget_reset_at = now + timedelta(seconds=duration_s) await prisma_client.update_data( @@ -2543,7 +2464,7 @@ async def reset_budget(prisma_client: PrismaClient): if users_to_reset is not None and len(users_to_reset) > 0: for user in users_to_reset: user.spend = 0.0 - duration_s = _duration_in_seconds(duration=user.budget_duration) + duration_s = 
duration_in_seconds(duration=user.budget_duration) user.budget_reset_at = now + timedelta(seconds=duration_s) await prisma_client.update_data( @@ -2561,7 +2482,7 @@ async def reset_budget(prisma_client: PrismaClient): if teams_to_reset is not None and len(teams_to_reset) > 0: team_reset_requests = [] for team in teams_to_reset: - duration_s = _duration_in_seconds(duration=team.budget_duration) + duration_s = duration_in_seconds(duration=team.budget_duration) reset_team_budget_request = ResetTeamBudgetRequest( team_id=team.team_id, spend=0.0, diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py index cf1db27e8..70a7eff59 100644 --- a/tests/local_testing/test_utils.py +++ b/tests/local_testing/test_utils.py @@ -17,7 +17,7 @@ import pytest import litellm from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, headers from litellm.proxy.utils import ( - _duration_in_seconds, + duration_in_seconds, _extract_from_regex, get_last_day_of_month, ) @@ -556,7 +556,7 @@ def test_extract_from_regex(duration, unit): assert _unit == unit -def test_duration_in_seconds(): +def testduration_in_seconds(): """ Test if duration int is correctly calculated for different str """ @@ -593,7 +593,7 @@ def test_duration_in_seconds(): duration_until_next_month = next_month - current_time expected_duration = int(duration_until_next_month.total_seconds()) - value = _duration_in_seconds(duration="1mo") + value = duration_in_seconds(duration="1mo") assert value - expected_duration < 2 From 653d16e158ef8a26df0010e0695fe7d5fd90d12a Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 12:43:01 -0800 Subject: [PATCH 02/29] add to readme.md --- litellm/litellm_core_utils/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/litellm_core_utils/README.md b/litellm/litellm_core_utils/README.md index 9cd351453..649404129 100644 --- a/litellm/litellm_core_utils/README.md +++ b/litellm/litellm_core_utils/README.md @@ -8,4 +8,5 @@ Core 
files: - `exception_mapping_utils.py`: utils for mapping exceptions to openai-compatible error types. - `default_encoding.py`: code for loading the default encoding (tiktoken) - `get_llm_provider_logic.py`: code for inferring the LLM provider from a given model name. +- `duration_parser.py`: code for parsing durations - e.g. "1d", "1mo", "10s" From 2b9ff03cd051f2b33a9ee5684a18af131a5b6a5b Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 12:44:28 -0800 Subject: [PATCH 03/29] re use duration_in_seconds --- litellm/router_strategy/provider_budgets.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/litellm/router_strategy/provider_budgets.py b/litellm/router_strategy/provider_budgets.py index 23d8b6c39..ea26d2c0f 100644 --- a/litellm/router_strategy/provider_budgets.py +++ b/litellm/router_strategy/provider_budgets.py @@ -25,6 +25,7 @@ from litellm._logging import verbose_router_logger from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs +from litellm.litellm_core_utils.duration_parser import duration_in_seconds from litellm.router_utils.cooldown_callbacks import ( _get_prometheus_logger_from_callbacks, ) @@ -207,7 +208,7 @@ class ProviderBudgetLimiting(CustomLogger): ) spend_key = f"provider_spend:{custom_llm_provider}:{budget_config.time_period}" - ttl_seconds = self.get_ttl_seconds(budget_config.time_period) + ttl_seconds = duration_in_seconds(duration=budget_config.time_period) verbose_router_logger.debug( f"Incrementing spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}" ) @@ -242,15 +243,6 @@ class ProviderBudgetLimiting(CustomLogger): return None return custom_llm_provider - def get_ttl_seconds(self, time_period: str) -> int: - """ - Convert time period (e.g., '1d', '30d') to seconds for Redis TTL - """ - if time_period.endswith("d"): - days = int(time_period[:-1]) - 
return days * 24 * 60 * 60 - raise ValueError(f"Unsupported time period format: {time_period}") - def _track_provider_remaining_budget_prometheus( self, provider: str, spend: float, budget_limit: float ): From c88048ae5cff38b95486a74bd85dc6380d408744 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 12:46:49 -0800 Subject: [PATCH 04/29] fix importing _extract_from_regex, get_last_day_of_month --- tests/local_testing/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py index 70a7eff59..7c349a658 100644 --- a/tests/local_testing/test_utils.py +++ b/tests/local_testing/test_utils.py @@ -556,7 +556,7 @@ def test_extract_from_regex(duration, unit): assert _unit == unit -def testduration_in_seconds(): +def test_duration_in_seconds(): """ Test if duration int is correctly calculated for different str """ From cf76f308de780303559bb9240f14de44f429e4de Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 12:47:06 -0800 Subject: [PATCH 05/29] fix import --- litellm/proxy/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index a75a09b29..2a298af21 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -26,7 +26,11 @@ from typing import ( overload, ) -from litellm.litellm_core_utils.duration_parser import duration_in_seconds +from litellm.litellm_core_utils.duration_parser import ( + _extract_from_regex, + duration_in_seconds, + get_last_day_of_month, +) from litellm.proxy._types import ProxyErrorTypes, ProxyException try: From ac4ecce2bc5ac4a1586ff2f936202528cf6467d4 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 12:49:13 -0800 Subject: [PATCH 06/29] update provider budget routing --- .../docs/proxy/provider_budget_routing.md | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git 
a/docs/my-website/docs/proxy/provider_budget_routing.md b/docs/my-website/docs/proxy/provider_budget_routing.md index fea3f483c..293f9e9d8 100644 --- a/docs/my-website/docs/proxy/provider_budget_routing.md +++ b/docs/my-website/docs/proxy/provider_budget_routing.md @@ -22,14 +22,14 @@ router_settings: provider_budget_config: openai: budget_limit: 0.000000000001 # float of $ value budget for time period - time_period: 1d # can be 1d, 2d, 30d + time_period: 1d # can be 1d, 2d, 30d, 1mo, 2mo azure: budget_limit: 100 time_period: 1d anthropic: budget_limit: 100 time_period: 10d - vertexai: + vertex_ai: budget_limit: 100 time_period: 12d gemini: @@ -112,8 +112,11 @@ Expected response on failure - If all providers exceed budget, raises an error 3. **Supported Time Periods**: - - Format: "Xd" where X is number of days - - Examples: "1d" (1 day), "30d" (30 days) + - Seconds: "Xs" (e.g., "30s") + - Minutes: "Xm" (e.g., "10m") + - Hours: "Xh" (e.g., "24h") + - Days: "Xd" (e.g., "1d", "30d") + - Months: "Xmo" (e.g., "1mo", "2mo") 4. 
**Requirements**: - Redis required for tracking spend across instances @@ -136,7 +139,12 @@ The `provider_budget_config` is a dictionary where: - **Key**: Provider name (string) - Must be a valid [LiteLLM provider name](https://docs.litellm.ai/docs/providers) - **Value**: Budget configuration object with the following parameters: - `budget_limit`: Float value representing the budget in USD - - `time_period`: String in the format "Xd" where X is the number of days (e.g., "1d", "30d") + - `time_period`: Duration string in one of the following formats: + - Seconds: `"Xs"` (e.g., "30s") + - Minutes: `"Xm"` (e.g., "10m") + - Hours: `"Xh"` (e.g., "24h") + - Days: `"Xd"` (e.g., "1d", "30d") + - Months: `"Xmo"` (e.g., "1mo", "2mo") Example structure: ```yaml @@ -147,4 +155,10 @@ provider_budget_config: azure: budget_limit: 500.0 # $500 USD time_period: "30d" # 30 day period + anthropic: + budget_limit: 200.0 # $200 USD + time_period: "1mo" # 1 month period + gemini: + budget_limit: 50.0 # $50 USD + time_period: "24h" # 24 hour period ``` \ No newline at end of file From 94e2e292cd33073d31578deff6f692cc754f918d Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 13:27:56 -0800 Subject: [PATCH 07/29] fix - remove dup test --- .../test_router_provider_budgets.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/tests/local_testing/test_router_provider_budgets.py b/tests/local_testing/test_router_provider_budgets.py index 46b9ee29e..a6574ba4b 100644 --- a/tests/local_testing/test_router_provider_budgets.py +++ b/tests/local_testing/test_router_provider_budgets.py @@ -142,23 +142,6 @@ async def test_provider_budgets_e2e_test_expect_to_fail(): assert "Exceeded budget for provider" in str(exc_info.value) -def test_get_ttl_seconds(): - """ - Test the get_ttl_seconds helper method" - - """ - provider_budget = ProviderBudgetLimiting( - router_cache=DualCache(), provider_budget_config={} - ) - - assert provider_budget.get_ttl_seconds("1d") == 86400 # 1 
day in seconds - assert provider_budget.get_ttl_seconds("7d") == 604800 # 7 days in seconds - assert provider_budget.get_ttl_seconds("30d") == 2592000 # 30 days in seconds - - with pytest.raises(ValueError, match="Unsupported time period format"): - provider_budget.get_ttl_seconds("1h") - - def test_get_llm_provider_for_deployment(): """ Test the _get_llm_provider_for_deployment helper method From 84395e7a1910edf82cb4b4d43479511f976c2775 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 15:46:39 -0800 Subject: [PATCH 08/29] add support for using in multi instance environments --- .../docs/proxy/provider_budget_routing.md | 63 +++++++++++----- litellm/proxy/proxy_config.yaml | 25 +++++-- litellm/router_strategy/provider_budgets.py | 72 +++++++++++++++++++ 3 files changed, 138 insertions(+), 22 deletions(-) diff --git a/docs/my-website/docs/proxy/provider_budget_routing.md b/docs/my-website/docs/proxy/provider_budget_routing.md index 293f9e9d8..1cb75d667 100644 --- a/docs/my-website/docs/proxy/provider_budget_routing.md +++ b/docs/my-website/docs/proxy/provider_budget_routing.md @@ -16,25 +16,27 @@ model_list: api_key: os.environ/OPENAI_API_KEY router_settings: - redis_host: - redis_password: - redis_port: provider_budget_config: - openai: - budget_limit: 0.000000000001 # float of $ value budget for time period - time_period: 1d # can be 1d, 2d, 30d, 1mo, 2mo - azure: - budget_limit: 100 - time_period: 1d - anthropic: - budget_limit: 100 - time_period: 10d - vertex_ai: - budget_limit: 100 - time_period: 12d - gemini: - budget_limit: 100 - time_period: 12d + openai: + budget_limit: 0.000000000001 # float of $ value budget for time period + time_period: 1d # can be 1d, 2d, 30d, 1mo, 2mo + azure: + budget_limit: 100 + time_period: 1d + anthropic: + budget_limit: 100 + time_period: 10d + vertex_ai: + budget_limit: 100 + time_period: 12d + gemini: + budget_limit: 100 + time_period: 12d + + # OPTIONAL: Set Redis Host, Port, and Password if using multiple 
instance of LiteLLM + redis_host: os.environ/REDIS_HOST + redis_port: os.environ/REDIS_PORT + redis_password: os.environ/REDIS_PASSWORD general_settings: master_key: sk-1234 @@ -132,6 +134,31 @@ This metric indicates the remaining budget for a provider in dollars (USD) litellm_provider_remaining_budget_metric{api_provider="openai"} 10 ``` +## Multi-instance setup + +If you are using a multi-instance setup, you will need to set the Redis host, port, and password in the `proxy_config.yaml` file. Redis is used to sync the spend across LiteLLM instances. + +```yaml +model_list: + - model_name: gpt-3.5-turbo + litellm_params: + model: openai/gpt-3.5-turbo + api_key: os.environ/OPENAI_API_KEY + +router_settings: + provider_budget_config: + openai: + budget_limit: 0.000000000001 # float of $ value budget for time period + time_period: 1d # can be 1d, 2d, 30d, 1mo, 2mo + + # 👇 Add this: Set Redis Host, Port, and Password if using multiple instance of LiteLLM + redis_host: os.environ/REDIS_HOST + redis_port: os.environ/REDIS_PORT + redis_password: os.environ/REDIS_PASSWORD + +general_settings: + master_key: sk-1234 +``` ## Spec for provider_budget_config diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 956a17a75..f716585b3 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -2,8 +2,25 @@ model_list: - model_name: gpt-4o litellm_params: model: openai/gpt-4o - api_key: os.environ/OPENAI_API_KEY + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + - model_name: fake-anthropic-endpoint + litellm_params: + model: anthropic/fake + api_base: https://exampleanthropicendpoint-production.up.railway.app/ + +router_settings: + provider_budget_config: + openai: + budget_limit: 1 # float of $ value budget for time period + time_period: 1d # can be 1d, 2d, 30d + anthropic: + budget_limit: 5 + time_period: 1d + redis_host: os.environ/REDIS_HOST + redis_port: os.environ/REDIS_PORT + redis_password: 
os.environ/REDIS_PASSWORD + +litellm_settings: + callbacks: ["prometheus"] + + -default_vertex_config: - vertex_project: "adroit-crow-413218" - vertex_location: "us-central1" diff --git a/litellm/router_strategy/provider_budgets.py b/litellm/router_strategy/provider_budgets.py index ea26d2c0f..42e63b297 100644 --- a/litellm/router_strategy/provider_budgets.py +++ b/litellm/router_strategy/provider_budgets.py @@ -18,6 +18,7 @@ anthropic: ``` """ +import asyncio from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union import litellm @@ -44,10 +45,13 @@ if TYPE_CHECKING: else: Span = Any +DEFAULT_REDIS_SYNC_INTERVAL = 60 + class ProviderBudgetLimiting(CustomLogger): def __init__(self, router_cache: DualCache, provider_budget_config: dict): self.router_cache = router_cache + asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis()) # cast elements of provider_budget_config to ProviderBudgetInfo for provider, config in provider_budget_config.items(): @@ -222,6 +226,74 @@ class ProviderBudgetLimiting(CustomLogger): f"Incremented spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}" ) + async def periodic_sync_in_memory_spend_with_redis(self): + """ + Handler that triggers sync_in_memory_spend_with_redis every DEFAULT_REDIS_SYNC_INTERVAL seconds + + Required for multi-instance environment usage of provider budgets + """ + while True: + try: + await self._sync_in_memory_spend_with_redis() + await asyncio.sleep( + DEFAULT_REDIS_SYNC_INTERVAL + ) # Wait DEFAULT_REDIS_SYNC_INTERVAL seconds before the next sync + except Exception as e: + verbose_router_logger.error(f"Error in periodic sync task: {str(e)}") + await asyncio.sleep( + DEFAULT_REDIS_SYNC_INTERVAL + ) # Still wait DEFAULT_REDIS_SYNC_INTERVAL seconds on error before retrying + + async def _sync_in_memory_spend_with_redis(self): + """ + Ensures in-memory cache is updated with latest Redis values for all provider spends. + + Why Do we need this? 
+ - Redis is our source of truth for provider spend + - In-memory cache goes out of sync if it does not get updated with the values from Redis + + Why not just rely on DualCache? + - DualCache does not handle synchronization between in-memory and Redis + + In a multi-instance environment, each instance needs to periodically get the provider spend from Redis to ensure it is consistent across all instances. + """ + + try: + # No need to sync if Redis cache is not initialized + if self.router_cache.redis_cache is None: + return + + # Get all providers and their budget configs + cache_keys = [] + for provider, config in self.provider_budget_config.items(): + if config is None: + continue + cache_keys.append(f"provider_spend:{provider}:{config.time_period}") + + # Batch fetch current spend values from Redis + redis_values = await self.router_cache.redis_cache.async_batch_get_cache( + key_list=cache_keys + ) + + # Update in-memory cache with Redis values + if isinstance(redis_values, dict): # Check if redis_values is a dictionary + for key, value in redis_values.items(): + if value is not None: + self.router_cache.in_memory_cache.set_cache( + key=key, value=float(value) + ) + verbose_router_logger.debug( + f"Updated in-memory cache for {key}: {value}" + ) + + except Exception as e: + import traceback + + traceback.print_exc() + verbose_router_logger.error( + f"Error syncing in-memory cache with Redis: {str(e)}" + ) + def _get_budget_config_for_provider( self, provider: str ) -> Optional[ProviderBudgetInfo]: From 5f04c04cc5e1bd3ba85cfb1a766a6cb06796d612 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 16:20:41 -0800 Subject: [PATCH 09/29] test_in_memory_redis_sync_e2e --- .../test_router_provider_budgets.py | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/local_testing/test_router_provider_budgets.py b/tests/local_testing/test_router_provider_budgets.py index a6574ba4b..3ef8de64b 100644 --- 
a/tests/local_testing/test_router_provider_budgets.py +++ b/tests/local_testing/test_router_provider_budgets.py @@ -267,3 +267,72 @@ async def test_prometheus_metric_tracking(): # Verify the mock was called correctly mock_prometheus.track_provider_remaining_budget.assert_called_once() + + +@pytest.mark.asyncio +async def test_in_memory_redis_sync_e2e(): + """ + Test that the in-memory cache gets properly synced with Redis values through the periodic sync mechanism + + Critical test for using provider budgets in a multi-instance environment + """ + setattr(litellm.router_strategy.provider_budgets, "DEFAULT_REDIS_SYNC_INTERVAL", 2) + + provider_budget_config = { + "openai": ProviderBudgetInfo(time_period="1d", budget_limit=100), + } + + router = Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "openai/gpt-3.5-turbo", + }, + }, + ], + provider_budget_config=provider_budget_config, + redis_host=os.getenv("REDIS_HOST"), + redis_port=int(os.getenv("REDIS_PORT")), + redis_password=os.getenv("REDIS_PASSWORD"), + ) + + if router.cache is None: + raise ValueError("Router cache is not initialized") + if router.cache.redis_cache is None: + raise ValueError("Redis cache is not initialized") + + # Get the ProviderBudgetLimiting instance + spend_key = "provider_spend:openai:1d" + + # Set initial values + test_spend_1 = 50.0 + await router.cache.redis_cache.async_set_cache(key=spend_key, value=test_spend_1) + + # Make a completion call to trigger spend tracking + response = await router.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hello"}], + mock_response="Hello there!", + ) + + # Wait for periodic sync (should be less than DEFAULT_REDIS_SYNC_INTERVAL) + await asyncio.sleep(2.5) + + # Verify in-memory cache matches Redis + in_memory_spend = float(router.cache.in_memory_cache.get_cache(spend_key) or 0) + redis_spend = float(await router.cache.redis_cache.async_get_cache(spend_key) or 0) + assert ( + 
abs(in_memory_spend - redis_spend) < 0.01 + ) # Allow for small floating point differences + + # Update Redis with new value from a "different litellm proxy instance" + test_spend_2 = 75.0 + await router.cache.redis_cache.async_set_cache(key=spend_key, value=test_spend_2) + + # Wait for periodic sync + await asyncio.sleep(2.5) + + # Verify in-memory cache was updated + in_memory_spend = float(router.cache.in_memory_cache.get_cache(spend_key) or 0) + assert abs(in_memory_spend - test_spend_2) < 0.01 From 33a0744abe30f71baa43c4a690810c2240782768 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 16:24:13 -0800 Subject: [PATCH 10/29] test_in_memory_redis_sync_e2e --- .../test_router_provider_budgets.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/local_testing/test_router_provider_budgets.py b/tests/local_testing/test_router_provider_budgets.py index 3ef8de64b..415a93f4d 100644 --- a/tests/local_testing/test_router_provider_budgets.py +++ b/tests/local_testing/test_router_provider_budgets.py @@ -2,6 +2,7 @@ import sys, os, asyncio, time, random from datetime import datetime import traceback from dotenv import load_dotenv +from httpx import delete load_dotenv() import os, copy @@ -276,6 +277,11 @@ async def test_in_memory_redis_sync_e2e(): Critical test for using provider budgets in a multi-instance environment """ + original_sync_interval = getattr( + litellm.router_strategy.provider_budgets, "DEFAULT_REDIS_SYNC_INTERVAL" + ) + + # Modify for test setattr(litellm.router_strategy.provider_budgets, "DEFAULT_REDIS_SYNC_INTERVAL", 2) provider_budget_config = { @@ -336,3 +342,13 @@ async def test_in_memory_redis_sync_e2e(): # Verify in-memory cache was updated in_memory_spend = float(router.cache.in_memory_cache.get_cache(spend_key) or 0) assert abs(in_memory_spend - test_spend_2) < 0.01 + + # clean up key from router cache + await router.cache.async_delete_cache(spend_key) + + # Restore original value + setattr( + 
litellm.router_strategy.provider_budgets, + "DEFAULT_REDIS_SYNC_INTERVAL", + original_sync_interval, + ) From e5c718992216a7da4a9395b904cf830c44c5418c Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 16:48:36 -0800 Subject: [PATCH 11/29] fix test_in_memory_redis_sync_e2e --- tests/local_testing/test_router_provider_budgets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/local_testing/test_router_provider_budgets.py b/tests/local_testing/test_router_provider_budgets.py index 415a93f4d..3f67dd0cc 100644 --- a/tests/local_testing/test_router_provider_budgets.py +++ b/tests/local_testing/test_router_provider_budgets.py @@ -291,7 +291,7 @@ async def test_in_memory_redis_sync_e2e(): router = Router( model_list=[ { - "model_name": "gpt-3.5-turbo", + "model_name": "gpt-3.5-turbo-very-new", "litellm_params": { "model": "openai/gpt-3.5-turbo", }, @@ -317,7 +317,7 @@ async def test_in_memory_redis_sync_e2e(): # Make a completion call to trigger spend tracking response = await router.acompletion( - model="gpt-3.5-turbo", + model="gpt-3.5-turbo-very-new", messages=[{"role": "user", "content": "Hello"}], mock_response="Hello there!", ) From d86a7c3702f3b9c4a57080b20ef77bd16bfc2229 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 16:52:45 -0800 Subject: [PATCH 12/29] fix code quality check --- litellm/router_strategy/provider_budgets.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/litellm/router_strategy/provider_budgets.py b/litellm/router_strategy/provider_budgets.py index 42e63b297..65b1e3c97 100644 --- a/litellm/router_strategy/provider_budgets.py +++ b/litellm/router_strategy/provider_budgets.py @@ -279,7 +279,7 @@ class ProviderBudgetLimiting(CustomLogger): if isinstance(redis_values, dict): # Check if redis_values is a dictionary for key, value in redis_values.items(): if value is not None: - self.router_cache.in_memory_cache.set_cache( + await 
self.router_cache.in_memory_cache.async_set_cache( key=key, value=float(value) ) verbose_router_logger.debug( @@ -287,9 +287,6 @@ class ProviderBudgetLimiting(CustomLogger): ) except Exception as e: - import traceback - - traceback.print_exc() verbose_router_logger.error( f"Error syncing in-memory cache with Redis: {str(e)}" ) From a40b3bcbbd658b41d5c86105fbf15c05a998a557 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 18:07:56 -0800 Subject: [PATCH 13/29] fix test provider budgets --- litellm/router_strategy/provider_budgets.py | 55 ++++++++++++++++++++- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/litellm/router_strategy/provider_budgets.py b/litellm/router_strategy/provider_budgets.py index 65b1e3c97..002e2d1c4 100644 --- a/litellm/router_strategy/provider_budgets.py +++ b/litellm/router_strategy/provider_budgets.py @@ -51,6 +51,7 @@ DEFAULT_REDIS_SYNC_INTERVAL = 60 class ProviderBudgetLimiting(CustomLogger): def __init__(self, router_cache: DualCache, provider_budget_config: dict): self.router_cache = router_cache + self.last_synced_values = {} # To track last synced values for each key asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis()) # cast elements of provider_budget_config to ProviderBudgetInfo @@ -217,10 +218,9 @@ class ProviderBudgetLimiting(CustomLogger): f"Incrementing spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}" ) # Increment the spend in Redis and set TTL - await self.router_cache.async_increment_cache( + await self.router_cache.in_memory_cache.async_increment( key=spend_key, value=response_cost, - ttl=ttl_seconds, ) verbose_router_logger.debug( f"Incremented spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}" @@ -244,6 +244,55 @@ class ProviderBudgetLimiting(CustomLogger): DEFAULT_REDIS_SYNC_INTERVAL ) # Still wait 5 seconds on error before retrying + async def _push_in_memory_increments_to_redis(self): + """ + Sync in-memory spend to Redis. 
+ + Pushes all increments from in-memory counter to Redis and resets the counter. + """ + try: + if not self.router_cache.redis_cache: + return # Redis is not initialized + + # Build cache keys for all providers + cache_keys = [ + f"provider_spend:{provider}:{config.time_period}" + for provider, config in self.provider_budget_config.items() + if config is not None + ] + + # Fetch current in-memory values + current_values = ( + await self.router_cache.in_memory_cache.async_batch_get_cache( + keys=cache_keys + ) + ) + + for key, current_value in zip(cache_keys, current_values): + if current_value is None: + continue # Skip keys with no in-memory value + + # Get the last synced value (default to 0 if not synced before) + last_synced = self.last_synced_values.get(key, 0.0) + + # Calculate the delta to push to Redis + delta = float(current_value) - last_synced + if delta > 0: # Only push if there is a positive increment + await self.router_cache.redis_cache.async_increment( + key=key, value=delta + ) + verbose_router_logger.debug( + f"Pushed delta to Redis for {key}: {delta} (last synced: {last_synced}, current: {current_value})" + ) + + # Update last synced value + self.last_synced_values[key] = float(current_value) + + except Exception as e: + verbose_router_logger.error( + f"Error syncing in-memory cache with Redis: {str(e)}" + ) + async def _sync_in_memory_spend_with_redis(self): """ Ensures in-memory cache is updated with latest Redis values for all provider spends. 
@@ -263,6 +312,8 @@ class ProviderBudgetLimiting(CustomLogger): if self.router_cache.redis_cache is None: return + await self._push_in_memory_increments_to_redis() + # Get all providers and their budget configs cache_keys = [] for provider, config in self.provider_budget_config.items(): From 6f4fdc58c7bccacfd5a2564ce577cb2e449251db Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 18:09:47 -0800 Subject: [PATCH 14/29] working provider budget tests --- .../test_router_provider_budgets.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/local_testing/test_router_provider_budgets.py b/tests/local_testing/test_router_provider_budgets.py index 3f67dd0cc..bd970ea1b 100644 --- a/tests/local_testing/test_router_provider_budgets.py +++ b/tests/local_testing/test_router_provider_budgets.py @@ -2,7 +2,6 @@ import sys, os, asyncio, time, random from datetime import datetime import traceback from dotenv import load_dotenv -from httpx import delete load_dotenv() import os, copy @@ -35,6 +34,8 @@ async def test_provider_budgets_e2e_test(): - Next 3 requests all go to Azure """ + # Modify for test + setattr(litellm.router_strategy.provider_budgets, "DEFAULT_REDIS_SYNC_INTERVAL", 2) provider_budget_config: ProviderBudgetConfigType = { "openai": ProviderBudgetInfo(time_period="1d", budget_limit=0.000000000001), "azure": ProviderBudgetInfo(time_period="1d", budget_limit=100), @@ -72,7 +73,7 @@ async def test_provider_budgets_e2e_test(): ) print(response) - await asyncio.sleep(0.5) + await asyncio.sleep(2.5) for _ in range(3): response = await router.acompletion( @@ -95,7 +96,7 @@ async def test_provider_budgets_e2e_test_expect_to_fail(): - first request passes, all subsequent requests fail """ - + setattr(litellm.router_strategy.provider_budgets, "DEFAULT_REDIS_SYNC_INTERVAL", 2) # Note: We intentionally use a dictionary with string keys for budget_limit and time_period # we want to test that the router can handle type 
conversion, since the proxy config yaml passes these values as a dictionary provider_budget_config = { @@ -126,7 +127,7 @@ async def test_provider_budgets_e2e_test_expect_to_fail(): ) print(response) - await asyncio.sleep(0.5) + await asyncio.sleep(2.5) for _ in range(3): with pytest.raises(Exception) as exc_info: @@ -143,7 +144,8 @@ async def test_provider_budgets_e2e_test_expect_to_fail(): assert "Exceeded budget for provider" in str(exc_info.value) -def test_get_llm_provider_for_deployment(): +@pytest.mark.asyncio +async def test_get_llm_provider_for_deployment(): """ Test the _get_llm_provider_for_deployment helper method @@ -173,7 +175,8 @@ def test_get_llm_provider_for_deployment(): assert provider_budget._get_llm_provider_for_deployment(unknown_deployment) is None -def test_get_budget_config_for_provider(): +@pytest.mark.asyncio +async def test_get_budget_config_for_provider(): """ Test the _get_budget_config_for_provider helper method @@ -207,6 +210,7 @@ async def test_prometheus_metric_tracking(): """ Test that the Prometheus metric for provider budget is tracked correctly """ + setattr(litellm.router_strategy.provider_budgets, "DEFAULT_REDIS_SYNC_INTERVAL", 2) from unittest.mock import MagicMock from litellm.integrations.prometheus import PrometheusLogger @@ -264,7 +268,7 @@ async def test_prometheus_metric_tracking(): except Exception as e: print("error", e) - await asyncio.sleep(0.5) + await asyncio.sleep(2.5) # Verify the mock was called correctly mock_prometheus.track_provider_remaining_budget.assert_called_once() From face50edadb20c5b28dfbd92119ae217a9717438 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 18:16:20 -0800 Subject: [PATCH 15/29] add fixture for provider budget routing --- .../test_router_provider_budgets.py | 103 ++++-------------- 1 file changed, 19 insertions(+), 84 deletions(-) diff --git a/tests/local_testing/test_router_provider_budgets.py b/tests/local_testing/test_router_provider_budgets.py index 
bd970ea1b..87dc46a44 100644 --- a/tests/local_testing/test_router_provider_budgets.py +++ b/tests/local_testing/test_router_provider_budgets.py @@ -25,6 +25,25 @@ import litellm verbose_router_logger.setLevel(logging.DEBUG) +@pytest.fixture(autouse=True) +async def cleanup_redis(): + """Cleanup Redis cache before each test""" + try: + import redis + + redis_client = redis.Redis( + host=os.getenv("REDIS_HOST"), + port=int(os.getenv("REDIS_PORT")), + password=os.getenv("REDIS_PASSWORD"), + ) + # Delete all provider spend keys + for key in redis_client.scan_iter("provider_spend:*"): + redis_client.delete(key) + except Exception as e: + print(f"Error cleaning up Redis: {str(e)}") + yield + + @pytest.mark.asyncio async def test_provider_budgets_e2e_test(): """ @@ -272,87 +291,3 @@ async def test_prometheus_metric_tracking(): # Verify the mock was called correctly mock_prometheus.track_provider_remaining_budget.assert_called_once() - - -@pytest.mark.asyncio -async def test_in_memory_redis_sync_e2e(): - """ - Test that the in-memory cache gets properly synced with Redis values through the periodic sync mechanism - - Critical test for using provider budgets in a multi-instance environment - """ - original_sync_interval = getattr( - litellm.router_strategy.provider_budgets, "DEFAULT_REDIS_SYNC_INTERVAL" - ) - - # Modify for test - setattr(litellm.router_strategy.provider_budgets, "DEFAULT_REDIS_SYNC_INTERVAL", 2) - - provider_budget_config = { - "openai": ProviderBudgetInfo(time_period="1d", budget_limit=100), - } - - router = Router( - model_list=[ - { - "model_name": "gpt-3.5-turbo-very-new", - "litellm_params": { - "model": "openai/gpt-3.5-turbo", - }, - }, - ], - provider_budget_config=provider_budget_config, - redis_host=os.getenv("REDIS_HOST"), - redis_port=int(os.getenv("REDIS_PORT")), - redis_password=os.getenv("REDIS_PASSWORD"), - ) - - if router.cache is None: - raise ValueError("Router cache is not initialized") - if router.cache.redis_cache is None: - raise 
ValueError("Redis cache is not initialized") - - # Get the ProviderBudgetLimiting instance - spend_key = "provider_spend:openai:1d" - - # Set initial values - test_spend_1 = 50.0 - await router.cache.redis_cache.async_set_cache(key=spend_key, value=test_spend_1) - - # Make a completion call to trigger spend tracking - response = await router.acompletion( - model="gpt-3.5-turbo-very-new", - messages=[{"role": "user", "content": "Hello"}], - mock_response="Hello there!", - ) - - # Wait for periodic sync (should be less than DEFAULT_REDIS_SYNC_INTERVAL) - await asyncio.sleep(2.5) - - # Verify in-memory cache matches Redis - in_memory_spend = float(router.cache.in_memory_cache.get_cache(spend_key) or 0) - redis_spend = float(await router.cache.redis_cache.async_get_cache(spend_key) or 0) - assert ( - abs(in_memory_spend - redis_spend) < 0.01 - ) # Allow for small floating point differences - - # Update Redis with new value from a "different litellm proxy instance" - test_spend_2 = 75.0 - await router.cache.redis_cache.async_set_cache(key=spend_key, value=test_spend_2) - - # Wait for periodic sync - await asyncio.sleep(2.5) - - # Verify in-memory cache was updated - in_memory_spend = float(router.cache.in_memory_cache.get_cache(spend_key) or 0) - assert abs(in_memory_spend - test_spend_2) < 0.01 - - # clean up key from router cache - await router.cache.async_delete_cache(spend_key) - - # Restore original value - setattr( - litellm.router_strategy.provider_budgets, - "DEFAULT_REDIS_SYNC_INTERVAL", - original_sync_interval, - ) From 6db00270c13e6af75630331067714c782a407ded Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 18:20:56 -0800 Subject: [PATCH 16/29] fix router testing for provider budgets --- tests/local_testing/test_router_provider_budgets.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/local_testing/test_router_provider_budgets.py b/tests/local_testing/test_router_provider_budgets.py index 
87dc46a44..5fa2e08ee 100644 --- a/tests/local_testing/test_router_provider_budgets.py +++ b/tests/local_testing/test_router_provider_budgets.py @@ -25,23 +25,25 @@ import litellm verbose_router_logger.setLevel(logging.DEBUG) -@pytest.fixture(autouse=True) -async def cleanup_redis(): +def cleanup_redis(): """Cleanup Redis cache before each test""" try: import redis + print("cleaning up redis..") + redis_client = redis.Redis( host=os.getenv("REDIS_HOST"), port=int(os.getenv("REDIS_PORT")), password=os.getenv("REDIS_PASSWORD"), ) + print("scan iter result", redis_client.scan_iter("provider_spend:*")) # Delete all provider spend keys for key in redis_client.scan_iter("provider_spend:*"): + print("deleting key", key) redis_client.delete(key) except Exception as e: print(f"Error cleaning up Redis: {str(e)}") - yield @pytest.mark.asyncio @@ -53,6 +55,7 @@ async def test_provider_budgets_e2e_test(): - Next 3 requests all go to Azure """ + cleanup_redis() # Modify for test setattr(litellm.router_strategy.provider_budgets, "DEFAULT_REDIS_SYNC_INTERVAL", 2) provider_budget_config: ProviderBudgetConfigType = { @@ -115,6 +118,7 @@ async def test_provider_budgets_e2e_test_expect_to_fail(): - first request passes, all subsequent requests fail """ + cleanup_redis() setattr(litellm.router_strategy.provider_budgets, "DEFAULT_REDIS_SYNC_INTERVAL", 2) # Note: We intentionally use a dictionary with string keys for budget_limit and time_period # we want to test that the router can handle type conversion, since the proxy config yaml passes these values as a dictionary @@ -169,6 +173,7 @@ async def test_get_llm_provider_for_deployment(): Test the _get_llm_provider_for_deployment helper method """ + cleanup_redis() provider_budget = ProviderBudgetLimiting( router_cache=DualCache(), provider_budget_config={} ) @@ -200,6 +205,7 @@ async def test_get_budget_config_for_provider(): Test the _get_budget_config_for_provider helper method """ + cleanup_redis() config = { "openai": 
ProviderBudgetInfo(time_period="1d", budget_limit=100), "anthropic": ProviderBudgetInfo(time_period="7d", budget_limit=500), @@ -229,6 +235,7 @@ async def test_prometheus_metric_tracking(): """ Test that the Prometheus metric for provider budget is tracked correctly """ + cleanup_redis() setattr(litellm.router_strategy.provider_budgets, "DEFAULT_REDIS_SYNC_INTERVAL", 2) from unittest.mock import MagicMock from litellm.integrations.prometheus import PrometheusLogger From a061f0e39cb86984328db1d580d1bcca3e32375e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 23 Nov 2024 18:25:28 -0800 Subject: [PATCH 17/29] add comments on provider budget routing --- litellm/router_strategy/provider_budgets.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/litellm/router_strategy/provider_budgets.py b/litellm/router_strategy/provider_budgets.py index 002e2d1c4..55ac18606 100644 --- a/litellm/router_strategy/provider_budgets.py +++ b/litellm/router_strategy/provider_budgets.py @@ -246,9 +246,12 @@ class ProviderBudgetLimiting(CustomLogger): async def _push_in_memory_increments_to_redis(self): """ - Sync in-memory spend to Redis. + This is a latency / speed optimization. - Pushes all increments from in-memory counter to Redis and resets the counter. 
+ How this works: + - Collect all provider spend increments in `router_cache.in_memory_cache`, done in async_log_success_event + - Push all increments to Redis in this function + - Reset the in-memory `last_synced_values` """ try: if not self.router_cache.redis_cache: @@ -274,12 +277,12 @@ class ProviderBudgetLimiting(CustomLogger): # Get the last synced value (default to 0 if not synced before) last_synced = self.last_synced_values.get(key, 0.0) - # Calculate the delta to push to Redis delta = float(current_value) - last_synced if delta > 0: # Only push if there is a positive increment await self.router_cache.redis_cache.async_increment( - key=key, value=delta + key=key, + value=delta, ) verbose_router_logger.debug( f"Pushed delta to Redis for {key}: {delta} (last synced: {last_synced}, current: {current_value})" @@ -299,10 +302,8 @@ class ProviderBudgetLimiting(CustomLogger): Why Do we need this? - Redis is our source of truth for provider spend - - In-memory cache goes out of sync if it does not get updated with the values from Redis + - Optimization to hit ~100ms latency. Performance was impacted when redis was used for read/write per request - Why not just rely on DualCache ? - - DualCache does not handle synchronization between in-memory and Redis In a multi-instance evironment, each instance needs to periodically get the provider spend from Redis to ensure it is consistent across all instances. 
""" @@ -312,8 +313,10 @@ class ProviderBudgetLimiting(CustomLogger): if self.router_cache.redis_cache is None: return + # Push all provider spend increments to Redis await self._push_in_memory_increments_to_redis() + # Handle Reading all current provider spend from Redis in Memory # Get all providers and their budget configs cache_keys = [] for provider, config in self.provider_budget_config.items(): From 8f74da64386ec5033893ae8a2dd872854d15da3a Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sun, 24 Nov 2024 09:38:47 -0800 Subject: [PATCH 18/29] use RedisPipelineIncrementOperation --- litellm/types/caching.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/litellm/types/caching.py b/litellm/types/caching.py index 7fca4c041..644c6e8be 100644 --- a/litellm/types/caching.py +++ b/litellm/types/caching.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Literal +from typing import Literal, TypedDict class LiteLLMCacheType(str, Enum): @@ -23,3 +23,13 @@ CachingSupportedCallTypes = Literal[ "arerank", "rerank", ] + + +class RedisPipelineIncrementOperation(TypedDict): + """ + TypeDict for 1 Redis Pipeline Increment Operation + """ + + key: str + increment_value: float + ttl_seconds: int From 5dd8726685bef9d47be0756258fed5c0bff18912 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sun, 24 Nov 2024 09:40:26 -0800 Subject: [PATCH 19/29] add redis async_increment_pipeline --- litellm/caching/redis_cache.py | 90 ++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/litellm/caching/redis_cache.py b/litellm/caching/redis_cache.py index e15a3f83d..58ef51557 100644 --- a/litellm/caching/redis_cache.py +++ b/litellm/caching/redis_cache.py @@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Tuple import litellm from litellm._logging import print_verbose, verbose_logger from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs +from litellm.types.caching import 
RedisPipelineIncrementOperation from litellm.types.services import ServiceLoggerPayload, ServiceTypes from litellm.types.utils import all_litellm_params @@ -890,3 +891,92 @@ class RedisCache(BaseCache): def delete_cache(self, key): self.redis_client.delete(key) + + async def _pipeline_increment_helper( + self, + pipe: pipeline, + increment_list: List[RedisPipelineIncrementOperation], + ) -> Optional[List[float]]: + """Helper function for pipeline increment operations""" + # Iterate through each increment operation and add commands to pipeline + for increment_op in increment_list: + cache_key = self.check_and_fix_namespace(key=increment_op["key"]) + print_verbose( + f"Increment ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {increment_op['increment_value']}\nttl={increment_op['ttl_seconds']}" + ) + pipe.incrbyfloat(cache_key, increment_op["increment_value"]) + if increment_op["ttl_seconds"] is not None: + _td = timedelta(seconds=increment_op["ttl_seconds"]) + pipe.expire(cache_key, _td) + # Execute the pipeline and return results + results = await pipe.execute() + print_verbose(f"Increment ASYNC Redis Cache PIPELINE: results: {results}") + return results + + async def async_increment_pipeline( + self, increment_list: List[RedisPipelineIncrementOperation], **kwargs + ) -> Optional[List[float]]: + """ + Use Redis Pipelines for bulk increment operations + Args: + increment_list: List of RedisPipelineIncrementOperation dicts containing: + - key: str + - increment_value: float + - ttl_seconds: int + """ + # don't waste a network request if there's nothing to increment + if len(increment_list) == 0: + return + + from redis.asyncio import Redis + + _redis_client: Redis = self.init_async_client() # type: ignore + start_time = time.time() + + print_verbose( + f"Increment Async Redis Cache Pipeline: increment list: {increment_list}" + ) + + try: + async with _redis_client as redis_client: + async with redis_client.pipeline(transaction=True) as pipe: + results = await 
self._pipeline_increment_helper( + pipe, increment_list + ) + + print_verbose(f"pipeline increment results: {results}") + + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="async_increment_pipeline", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) + return results + except Exception as e: + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + call_type="async_increment_pipeline", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) + verbose_logger.error( + "LiteLLM Redis Caching: async increment_pipeline() - Got exception from REDIS %s", + str(e), + ) + raise e From 24ab97948621b8dd785960b2f13a1eb86fcc4de7 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sun, 24 Nov 2024 09:42:25 -0800 Subject: [PATCH 20/29] use redis async_increment_pipeline --- litellm/router_strategy/provider_budgets.py | 51 ++++++--------------- 1 file changed, 15 insertions(+), 36 deletions(-) diff --git a/litellm/router_strategy/provider_budgets.py b/litellm/router_strategy/provider_budgets.py index 55ac18606..c0686104c 100644 --- a/litellm/router_strategy/provider_budgets.py +++ b/litellm/router_strategy/provider_budgets.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union import litellm from litellm._logging import verbose_router_logger from litellm.caching.caching import DualCache +from litellm.caching.redis_cache import RedisPipelineIncrementOperation from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.core_helpers import 
_get_parent_otel_span_from_kwargs from litellm.litellm_core_utils.duration_parser import duration_in_seconds @@ -51,7 +52,7 @@ DEFAULT_REDIS_SYNC_INTERVAL = 60 class ProviderBudgetLimiting(CustomLogger): def __init__(self, router_cache: DualCache, provider_budget_config: dict): self.router_cache = router_cache - self.last_synced_values = {} # To track last synced values for each key + self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = [] asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis()) # cast elements of provider_budget_config to ProviderBudgetInfo @@ -214,14 +215,17 @@ class ProviderBudgetLimiting(CustomLogger): spend_key = f"provider_spend:{custom_llm_provider}:{budget_config.time_period}" ttl_seconds = duration_in_seconds(duration=budget_config.time_period) - verbose_router_logger.debug( - f"Incrementing spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}" + + # Create RedisPipelineIncrementOperation object + increment_op = RedisPipelineIncrementOperation( + key=spend_key, increment_value=response_cost, ttl_seconds=ttl_seconds ) - # Increment the spend in Redis and set TTL + await self.router_cache.in_memory_cache.async_increment( key=spend_key, value=response_cost, ) + self.redis_increment_operation_queue.append(increment_op) verbose_router_logger.debug( f"Incremented spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}" ) @@ -256,40 +260,15 @@ class ProviderBudgetLimiting(CustomLogger): try: if not self.router_cache.redis_cache: return # Redis is not initialized - - # Build cache keys for all providers - cache_keys = [ - f"provider_spend:{provider}:{config.time_period}" - for provider, config in self.provider_budget_config.items() - if config is not None - ] - - # Fetch current in-memory values - current_values = ( - await self.router_cache.in_memory_cache.async_batch_get_cache( - keys=cache_keys - ) + verbose_router_logger.debug( + "Pushing Redis Increment Pipeline for queue: %s", + 
self.redis_increment_operation_queue, + ) + await self.router_cache.redis_cache.async_increment_pipeline( + increment_list=self.redis_increment_operation_queue, ) - for key, current_value in zip(cache_keys, current_values): - if current_value is None: - continue # Skip keys with no in-memory value - - # Get the last synced value (default to 0 if not synced before) - last_synced = self.last_synced_values.get(key, 0.0) - # Calculate the delta to push to Redis - delta = float(current_value) - last_synced - if delta > 0: # Only push if there is a positive increment - await self.router_cache.redis_cache.async_increment( - key=key, - value=delta, - ) - verbose_router_logger.debug( - f"Pushed delta to Redis for {key}: {delta} (last synced: {last_synced}, current: {current_value})" - ) - - # Update last synced value - self.last_synced_values[key] = float(current_value) + self.redis_increment_operation_queue = [] except Exception as e: verbose_router_logger.error( From 87e30cd5628f0921166c611d32171db12e30724b Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sun, 24 Nov 2024 09:42:40 -0800 Subject: [PATCH 21/29] use lower value for testing --- litellm/proxy/proxy_config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index f716585b3..c40b56eeb 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -11,7 +11,7 @@ model_list: router_settings: provider_budget_config: openai: - budget_limit: 1 # float of $ value budget for time period + budget_limit: 0.2 # float of $ value budget for time period time_period: 1d # can be 1d, 2d, 30d anthropic: budget_limit: 5 From c4937dffe2cb0b190c7e95aac63bc73f94673501 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sun, 24 Nov 2024 09:45:33 -0800 Subject: [PATCH 22/29] use redis async_increment_pipeline --- litellm/router_strategy/provider_budgets.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git 
a/litellm/router_strategy/provider_budgets.py b/litellm/router_strategy/provider_budgets.py index c0686104c..2b34f01eb 100644 --- a/litellm/router_strategy/provider_budgets.py +++ b/litellm/router_strategy/provider_budgets.py @@ -264,9 +264,12 @@ class ProviderBudgetLimiting(CustomLogger): "Pushing Redis Increment Pipeline for queue: %s", self.redis_increment_operation_queue, ) - await self.router_cache.redis_cache.async_increment_pipeline( - increment_list=self.redis_increment_operation_queue, - ) + if len(self.redis_increment_operation_queue) > 0: + asyncio.create_task( + self.router_cache.redis_cache.async_increment_pipeline( + increment_list=self.redis_increment_operation_queue, + ) + ) self.redis_increment_operation_queue = [] From be25706736c82c27e45aec25ccf8297ea5d5b01e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sun, 24 Nov 2024 10:22:00 -0800 Subject: [PATCH 23/29] use consistent key name for increment op --- litellm/caching/redis_cache.py | 6 +++--- litellm/types/caching.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/litellm/caching/redis_cache.py b/litellm/caching/redis_cache.py index 58ef51557..1cabbd4f9 100644 --- a/litellm/caching/redis_cache.py +++ b/litellm/caching/redis_cache.py @@ -902,11 +902,11 @@ class RedisCache(BaseCache): for increment_op in increment_list: cache_key = self.check_and_fix_namespace(key=increment_op["key"]) print_verbose( - f"Increment ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {increment_op['increment_value']}\nttl={increment_op['ttl_seconds']}" + f"Increment ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {increment_op['increment_value']}\nttl={increment_op['ttl']}" ) pipe.incrbyfloat(cache_key, increment_op["increment_value"]) - if increment_op["ttl_seconds"] is not None: - _td = timedelta(seconds=increment_op["ttl_seconds"]) + if increment_op["ttl"] is not None: + _td = timedelta(seconds=increment_op["ttl"]) pipe.expire(cache_key, _td) # Execute the pipeline and return results 
results = await pipe.execute() diff --git a/litellm/types/caching.py b/litellm/types/caching.py index 644c6e8be..a6f9de308 100644 --- a/litellm/types/caching.py +++ b/litellm/types/caching.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Literal, TypedDict +from typing import Literal, Optional, TypedDict class LiteLLMCacheType(str, Enum): @@ -32,4 +32,4 @@ class RedisPipelineIncrementOperation(TypedDict): key: str increment_value: float - ttl_seconds: int + ttl: Optional[int] From 8aa8f2e4ab70471ccff0e998ccf1862daaef6563 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sun, 24 Nov 2024 15:47:56 -0800 Subject: [PATCH 24/29] add handling for budget windows --- litellm/proxy/proxy_config.yaml | 6 +- litellm/router_strategy/provider_budgets.py | 110 ++++++++++++++++---- 2 files changed, 89 insertions(+), 27 deletions(-) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index c40b56eeb..13fb1bcbe 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -11,7 +11,7 @@ model_list: router_settings: provider_budget_config: openai: - budget_limit: 0.2 # float of $ value budget for time period + budget_limit: 0.3 # float of $ value budget for time period time_period: 1d # can be 1d, 2d, 30d anthropic: budget_limit: 5 @@ -21,6 +21,4 @@ router_settings: redis_password: os.environ/REDIS_PASSWORD litellm_settings: - callbacks: ["prometheus"] - - + callbacks: ["prometheus"] \ No newline at end of file diff --git a/litellm/router_strategy/provider_budgets.py b/litellm/router_strategy/provider_budgets.py index 2b34f01eb..730447e7e 100644 --- a/litellm/router_strategy/provider_budgets.py +++ b/litellm/router_strategy/provider_budgets.py @@ -19,6 +19,7 @@ anthropic: """ import asyncio +from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union import litellm @@ -46,7 +47,7 @@ if TYPE_CHECKING: else: Span = Any -DEFAULT_REDIS_SYNC_INTERVAL = 60 
+DEFAULT_REDIS_SYNC_INTERVAL = 1 class ProviderBudgetLimiting(CustomLogger): @@ -179,19 +180,55 @@ class ProviderBudgetLimiting(CustomLogger): return potential_deployments + async def _get_or_set_budget_start_time( + self, start_time_key: str, current_time: float, ttl_seconds: int + ) -> float: + """ + Get existing budget start time or set a new one + """ + budget_start = await self.router_cache.async_get_cache(start_time_key) + if budget_start is None: + await self.router_cache.async_set_cache( + key=start_time_key, value=current_time, ttl=ttl_seconds + ) + return current_time + return float(budget_start) + + async def _handle_new_budget_window( + self, + spend_key: str, + start_time_key: str, + current_time: float, + response_cost: float, + ttl_seconds: int, + ) -> float: + """Handle start of new budget window by resetting spend and start time""" + await self.router_cache.async_set_cache( + key=spend_key, value=response_cost, ttl=ttl_seconds + ) + await self.router_cache.async_set_cache( + key=start_time_key, value=current_time, ttl=ttl_seconds + ) + return current_time + + async def _increment_spend_in_current_window( + self, spend_key: str, response_cost: float, ttl: int + ): + """Increment spend within existing budget window""" + await self.router_cache.in_memory_cache.async_increment( + key=spend_key, + value=response_cost, + ttl=ttl, + ) + increment_op = RedisPipelineIncrementOperation( + key=spend_key, + increment_value=response_cost, + ttl=ttl, + ) + self.redis_increment_operation_queue.append(increment_op) + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): - """ - Increment provider spend in DualCache (InMemory + Redis) - - Handles saving current provider spend to Redis. - - Spend is stored as: - provider_spend:{provider}:{time_period} - ex. provider_spend:openai:1d - ex. provider_spend:anthropic:7d - - The time period is tracked for time_periods set in the provider budget config. 
- """ + """Original method now uses helper functions""" verbose_router_logger.debug("in ProviderBudgetLimiting.async_log_success_event") standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( "standard_logging_object", None @@ -214,20 +251,47 @@ class ProviderBudgetLimiting(CustomLogger): ) spend_key = f"provider_spend:{custom_llm_provider}:{budget_config.time_period}" - ttl_seconds = duration_in_seconds(duration=budget_config.time_period) + start_time_key = f"provider_budget_start_time:{custom_llm_provider}" - # Create RedisPipelineIncrementOperation object - increment_op = RedisPipelineIncrementOperation( - key=spend_key, increment_value=response_cost, ttl_seconds=ttl_seconds + current_time = datetime.now(timezone.utc).timestamp() + ttl_seconds = duration_in_seconds(budget_config.time_period) + + budget_start = await self._get_or_set_budget_start_time( + start_time_key=start_time_key, + current_time=current_time, + ttl_seconds=ttl_seconds, ) - await self.router_cache.in_memory_cache.async_increment( - key=spend_key, - value=response_cost, - ) - self.redis_increment_operation_queue.append(increment_op) + if budget_start is None: + # First spend for this provider + budget_start = await self._handle_new_budget_window( + spend_key=spend_key, + start_time_key=start_time_key, + current_time=current_time, + response_cost=response_cost, + ttl_seconds=ttl_seconds, + ) + elif (current_time - budget_start) > ttl_seconds: + # Budget window expired - reset everything + verbose_router_logger.debug("Budget window expired - resetting everything") + budget_start = await self._handle_new_budget_window( + spend_key=spend_key, + start_time_key=start_time_key, + current_time=current_time, + response_cost=response_cost, + ttl_seconds=ttl_seconds, + ) + else: + # Within existing window - increment spend + remaining_time = ttl_seconds - (current_time - budget_start) + ttl_for_increment = int(remaining_time) + + await self._increment_spend_in_current_window( + 
spend_key=spend_key, response_cost=response_cost, ttl=ttl_for_increment + ) + verbose_router_logger.debug( - f"Incremented spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}" + f"Incremented spend for {spend_key} by {response_cost}" ) async def periodic_sync_in_memory_spend_with_redis(self): From ac5763843432e04aea10594fd3766f37104157f5 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sun, 24 Nov 2024 15:50:57 -0800 Subject: [PATCH 25/29] fix typing async_increment_pipeline --- litellm/caching/redis_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/caching/redis_cache.py b/litellm/caching/redis_cache.py index 1cabbd4f9..ba5c3a695 100644 --- a/litellm/caching/redis_cache.py +++ b/litellm/caching/redis_cache.py @@ -926,7 +926,7 @@ class RedisCache(BaseCache): """ # don't waste a network request if there's nothing to increment if len(increment_list) == 0: - return + return None from redis.asyncio import Redis From 2fb9b245a1e18ec7d7a6702de56f51adec092fec Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sun, 24 Nov 2024 15:54:28 -0800 Subject: [PATCH 26/29] fix set attr --- tests/local_testing/test_router_provider_budgets.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/local_testing/test_router_provider_budgets.py b/tests/local_testing/test_router_provider_budgets.py index 5fa2e08ee..c16fc5d5c 100644 --- a/tests/local_testing/test_router_provider_budgets.py +++ b/tests/local_testing/test_router_provider_budgets.py @@ -57,7 +57,6 @@ async def test_provider_budgets_e2e_test(): """ cleanup_redis() # Modify for test - setattr(litellm.router_strategy.provider_budgets, "DEFAULT_REDIS_SYNC_INTERVAL", 2) provider_budget_config: ProviderBudgetConfigType = { "openai": ProviderBudgetInfo(time_period="1d", budget_limit=0.000000000001), "azure": ProviderBudgetInfo(time_period="1d", budget_limit=100), @@ -119,7 +118,7 @@ async def test_provider_budgets_e2e_test_expect_to_fail(): """ cleanup_redis() - 
setattr(litellm.router_strategy.provider_budgets, "DEFAULT_REDIS_SYNC_INTERVAL", 2) + # Note: We intentionally use a dictionary with string keys for budget_limit and time_period # we want to test that the router can handle type conversion, since the proxy config yaml passes these values as a dictionary provider_budget_config = { @@ -236,7 +235,6 @@ async def test_prometheus_metric_tracking(): Test that the Prometheus metric for provider budget is tracked correctly """ cleanup_redis() - setattr(litellm.router_strategy.provider_budgets, "DEFAULT_REDIS_SYNC_INTERVAL", 2) from unittest.mock import MagicMock from litellm.integrations.prometheus import PrometheusLogger From d27b5274778cb52c5110374b746b3c3ad37e2b31 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sun, 24 Nov 2024 16:17:07 -0800 Subject: [PATCH 27/29] add clear doc strings --- litellm/router_strategy/provider_budgets.py | 55 ++++++++++++++------- 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/litellm/router_strategy/provider_budgets.py b/litellm/router_strategy/provider_budgets.py index 730447e7e..f4dc1ba94 100644 --- a/litellm/router_strategy/provider_budgets.py +++ b/litellm/router_strategy/provider_budgets.py @@ -184,7 +184,10 @@ class ProviderBudgetLimiting(CustomLogger): self, start_time_key: str, current_time: float, ttl_seconds: int ) -> float: """ - Get existing budget start time or set a new one + Checks if the key = `provider_budget_start_time:{provider}` exists in cache. + + If it does, return the value. + If it does not, set the key to `current_time` and return the value. 
""" budget_start = await self.router_cache.async_get_cache(start_time_key) if budget_start is None: @@ -202,7 +205,18 @@ class ProviderBudgetLimiting(CustomLogger): response_cost: float, ttl_seconds: int, ) -> float: - """Handle start of new budget window by resetting spend and start time""" + """ + Handle start of new budget window by resetting spend and start time + + Enters this when: + - The budget does not exist in cache, so we need to set it + - The budget window has expired, so we need to reset everything + + Does 2 things: + - stores key: `provider_spend:{provider}:1d`, value: response_cost + - stores key: `provider_budget_start_time:{provider}`, value: current_time. + This stores the start time of the new budget window + """ await self.router_cache.async_set_cache( key=spend_key, value=response_cost, ttl=ttl_seconds ) @@ -214,7 +228,14 @@ class ProviderBudgetLimiting(CustomLogger): async def _increment_spend_in_current_window( self, spend_key: str, response_cost: float, ttl: int ): - """Increment spend within existing budget window""" + """ + Increment spend within existing budget window + + Runs once the budget start time exists in Redis Cache (on the 2nd and subsequent requests to the same provider) + + - Increments the spend in memory cache (so spend instantly updated in memory) + - Queues the increment operation to Redis Pipeline (using batched pipeline to optimize performance. 
Using Redis for multi instance environment of LiteLLM) + """ await self.router_cache.in_memory_cache.async_increment( key=spend_key, value=response_cost, @@ -305,25 +326,25 @@ class ProviderBudgetLimiting(CustomLogger): await self._sync_in_memory_spend_with_redis() await asyncio.sleep( DEFAULT_REDIS_SYNC_INTERVAL - ) # Wait for 5 seconds before next sync + ) # Wait for DEFAULT_REDIS_SYNC_INTERVAL seconds before next sync except Exception as e: verbose_router_logger.error(f"Error in periodic sync task: {str(e)}") await asyncio.sleep( DEFAULT_REDIS_SYNC_INTERVAL - ) # Still wait 5 seconds on error before retrying + ) # Still wait DEFAULT_REDIS_SYNC_INTERVAL seconds on error before retrying async def _push_in_memory_increments_to_redis(self): """ - This is a latency / speed optimization. - How this works: - - Collect all provider spend increments in `router_cache.in_memory_cache`, done in async_log_success_event - - Push all increments to Redis in this function - - Reset the in-memory `last_synced_values` + - async_log_success_event collects all provider spend increments in `redis_increment_operation_queue` + - This function pushes all increments to Redis in a batched pipeline to optimize performance + + Only runs if Redis is initialized """ try: if not self.router_cache.redis_cache: return # Redis is not initialized + verbose_router_logger.debug( "Pushing Redis Increment Pipeline for queue: %s", self.redis_increment_operation_queue, @@ -347,11 +368,12 @@ class ProviderBudgetLimiting(CustomLogger): Ensures in-memory cache is updated with latest Redis values for all provider spends. Why Do we need this? - - Redis is our source of truth for provider spend - - Optimization to hit ~100ms latency. Performance was impacted when redis was used for read/write per request + - Optimization to hit sub 100ms latency. 
Performance was impacted when redis was used for read/write per request + - Use provider budgets in multi-instance environment, we use Redis to sync spend across all instances - - In a multi-instance evironment, each instance needs to periodically get the provider spend from Redis to ensure it is consistent across all instances. + What this does: + 1. Push all provider spend increments to Redis + 2. Fetch all current provider spend from Redis to update in-memory cache """ try: @@ -359,11 +381,10 @@ class ProviderBudgetLimiting(CustomLogger): if self.router_cache.redis_cache is None: return - # Push all provider spend increments to Redis + # 1. Push all provider spend increments to Redis await self._push_in_memory_increments_to_redis() - # Handle Reading all current provider spend from Redis in Memory - # Get all providers and their budget configs + # 2. Fetch all current provider spend from Redis to update in-memory cache cache_keys = [] for provider, config in self.provider_budget_config.items(): if config is None: From 4ff941eeba8b0743baeb5e7476099eba9488ac89 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sun, 24 Nov 2024 16:22:32 -0800 Subject: [PATCH 28/29] unit testing for provider budgets --- .../test_router_provider_budgets.py | 182 +++++++++++++++++- 1 file changed, 181 insertions(+), 1 deletion(-) diff --git a/tests/local_testing/test_router_provider_budgets.py b/tests/local_testing/test_router_provider_budgets.py index c16fc5d5c..430550632 100644 --- a/tests/local_testing/test_router_provider_budgets.py +++ b/tests/local_testing/test_router_provider_budgets.py @@ -17,7 +17,7 @@ from litellm.types.router import ( ProviderBudgetConfigType, ProviderBudgetInfo, ) -from litellm.caching.caching import DualCache +from litellm.caching.caching import DualCache, RedisCache import logging from litellm._logging import verbose_router_logger import litellm @@ -296,3 +296,183 @@ async def test_prometheus_metric_tracking(): # Verify the mock was called correctly 
mock_prometheus.track_provider_remaining_budget.assert_called_once() + + +@pytest.mark.asyncio +async def test_handle_new_budget_window(): + """ + Test _handle_new_budget_window helper method + + Current + """ + cleanup_redis() + provider_budget = ProviderBudgetLimiting( + router_cache=DualCache(), provider_budget_config={} + ) + + spend_key = "provider_spend:openai:7d" + start_time_key = "provider_budget_start_time:openai" + current_time = 1000.0 + response_cost = 0.5 + ttl_seconds = 86400 # 1 day + + # Test handling new budget window + new_start_time = await provider_budget._handle_new_budget_window( + spend_key=spend_key, + start_time_key=start_time_key, + current_time=current_time, + response_cost=response_cost, + ttl_seconds=ttl_seconds, + ) + + assert new_start_time == current_time + + # Verify the spend was set correctly + spend = await provider_budget.router_cache.async_get_cache(spend_key) + print("spend in cache for key", spend_key, "is", spend) + assert float(spend) == response_cost + + # Verify start time was set correctly + start_time = await provider_budget.router_cache.async_get_cache(start_time_key) + print("start time in cache for key", start_time_key, "is", start_time) + assert float(start_time) == current_time + + +@pytest.mark.asyncio +async def test_get_or_set_budget_start_time(): + """ + Test _get_or_set_budget_start_time helper method + + scenario 1: no existing start time in cache, should return current time + scenario 2: existing start time in cache, should return existing start time + """ + cleanup_redis() + provider_budget = ProviderBudgetLimiting( + router_cache=DualCache(), provider_budget_config={} + ) + + start_time_key = "test_start_time" + current_time = 1000.0 + ttl_seconds = 86400 # 1 day + + # When there is no existing start time, we should set it to the current time + start_time = await provider_budget._get_or_set_budget_start_time( + start_time_key=start_time_key, + current_time=current_time, + ttl_seconds=ttl_seconds, + ) + 
print("budget start time when no existing start time is in cache", start_time) + assert start_time == current_time + + # When there is an existing start time, we should return it even if the current time is later + new_current_time = 2000.0 + existing_start_time = await provider_budget._get_or_set_budget_start_time( + start_time_key=start_time_key, + current_time=new_current_time, + ttl_seconds=ttl_seconds, + ) + print( + "budget start time when existing start time is in cache, but current time is later", + existing_start_time, + ) + assert existing_start_time == current_time # Should return the original start time + + +@pytest.mark.asyncio +async def test_increment_spend_in_current_window(): + """ + Test _increment_spend_in_current_window helper method + + Expected behavior: + - Increment the spend in memory cache + - Queue the increment operation to Redis + """ + cleanup_redis() + provider_budget = ProviderBudgetLimiting( + router_cache=DualCache(), provider_budget_config={} + ) + + spend_key = "provider_spend:openai:1d" + response_cost = 0.5 + ttl = 86400 # 1 day + + # Set initial spend + await provider_budget.router_cache.async_set_cache( + key=spend_key, value=1.0, ttl=ttl + ) + + # Test incrementing spend + await provider_budget._increment_spend_in_current_window( + spend_key=spend_key, + response_cost=response_cost, + ttl=ttl, + ) + + # Verify the spend was incremented correctly in memory + spend = await provider_budget.router_cache.async_get_cache(spend_key) + assert float(spend) == 1.5 + + # Verify the increment operation was queued for Redis + print( + "redis_increment_operation_queue", + provider_budget.redis_increment_operation_queue, + ) + assert len(provider_budget.redis_increment_operation_queue) == 1 + queued_op = provider_budget.redis_increment_operation_queue[0] + assert queued_op["key"] == spend_key + assert queued_op["increment_value"] == response_cost + assert queued_op["ttl"] == ttl + + +@pytest.mark.asyncio +async def 
test_sync_in_memory_spend_with_redis(): + """ + Test _sync_in_memory_spend_with_redis helper method + + Expected behavior: + - Push all provider spend increments to Redis + - Fetch all current provider spend from Redis to update in-memory cache + """ + cleanup_redis() + provider_budget_config = { + "openai": ProviderBudgetInfo(time_period="1d", budget_limit=100), + "anthropic": ProviderBudgetInfo(time_period="1d", budget_limit=200), + } + + provider_budget = ProviderBudgetLimiting( + router_cache=DualCache( + redis_cache=RedisCache( + host=os.getenv("REDIS_HOST"), + port=int(os.getenv("REDIS_PORT")), + password=os.getenv("REDIS_PASSWORD"), + ) + ), + provider_budget_config=provider_budget_config, + ) + + # Set some values in Redis + spend_key_openai = "provider_spend:openai:1d" + spend_key_anthropic = "provider_spend:anthropic:1d" + + await provider_budget.router_cache.redis_cache.async_set_cache( + key=spend_key_openai, value=50.0 + ) + await provider_budget.router_cache.redis_cache.async_set_cache( + key=spend_key_anthropic, value=75.0 + ) + + # Test syncing with Redis + await provider_budget._sync_in_memory_spend_with_redis() + + # Verify in-memory cache was updated + openai_spend = await provider_budget.router_cache.in_memory_cache.async_get_cache( + spend_key_openai + ) + anthropic_spend = ( + await provider_budget.router_cache.in_memory_cache.async_get_cache( + spend_key_anthropic + ) + ) + + assert float(openai_spend) == 50.0 + assert float(anthropic_spend) == 75.0 From f80f4b0f9ea1796b4e3c0150e306c8e446ca3abb Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sun, 24 Nov 2024 16:31:47 -0800 Subject: [PATCH 29/29] test_redis_increment_pipeline --- tests/local_testing/test_caching.py | 45 +++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py index 222013a86..08da89172 100644 --- a/tests/local_testing/test_caching.py +++ b/tests/local_testing/test_caching.py @@ 
-2433,3 +2433,48 @@ async def test_dual_cache_caching_batch_get_cache(): await dc.async_batch_get_cache(keys=["test_key1", "test_key2"]) assert mock_async_get_cache.call_count == 1 + + +@pytest.mark.asyncio +async def test_redis_increment_pipeline(): + """Test Redis increment pipeline functionality""" + try: + from litellm.caching.redis_cache import RedisCache + + litellm.set_verbose = True + redis_cache = RedisCache( + host=os.environ["REDIS_HOST"], + port=os.environ["REDIS_PORT"], + password=os.environ["REDIS_PASSWORD"], + ) + + # Create test increment operations + increment_list = [ + {"key": "test_key1", "increment_value": 1.5, "ttl": 60}, + {"key": "test_key1", "increment_value": 1.1, "ttl": 58}, + {"key": "test_key1", "increment_value": 0.4, "ttl": 55}, + {"key": "test_key2", "increment_value": 2.5, "ttl": 60}, + ] + + # Test pipeline increment + results = await redis_cache.async_increment_pipeline(increment_list) + + # Verify results + assert len(results) == 8 # 4 increment operations + 4 expire operations + + # Verify the values were actually set in Redis + value1 = await redis_cache.async_get_cache("test_key1") + print("result in cache for key=test_key1", value1) + value2 = await redis_cache.async_get_cache("test_key2") + print("result in cache for key=test_key2", value2) + + assert float(value1) == 3.0 + assert float(value2) == 2.5 + + # Clean up + await redis_cache.async_delete_cache("test_key1") + await redis_cache.async_delete_cache("test_key2") + + except Exception as e: + print(f"Error occurred: {str(e)}") + raise e