diff --git a/docs/my-website/docs/proxy/provider_budget_routing.md b/docs/my-website/docs/proxy/provider_budget_routing.md index 293f9e9d8..1cb75d667 100644 --- a/docs/my-website/docs/proxy/provider_budget_routing.md +++ b/docs/my-website/docs/proxy/provider_budget_routing.md @@ -16,25 +16,27 @@ model_list: api_key: os.environ/OPENAI_API_KEY router_settings: - redis_host: - redis_password: - redis_port: provider_budget_config: - openai: - budget_limit: 0.000000000001 # float of $ value budget for time period - time_period: 1d # can be 1d, 2d, 30d, 1mo, 2mo - azure: - budget_limit: 100 - time_period: 1d - anthropic: - budget_limit: 100 - time_period: 10d - vertex_ai: - budget_limit: 100 - time_period: 12d - gemini: - budget_limit: 100 - time_period: 12d + openai: + budget_limit: 0.000000000001 # float of $ value budget for time period + time_period: 1d # can be 1d, 2d, 30d, 1mo, 2mo + azure: + budget_limit: 100 + time_period: 1d + anthropic: + budget_limit: 100 + time_period: 10d + vertex_ai: + budget_limit: 100 + time_period: 12d + gemini: + budget_limit: 100 + time_period: 12d + + # OPTIONAL: Set Redis Host, Port, and Password if using multiple instances of LiteLLM + redis_host: os.environ/REDIS_HOST + redis_port: os.environ/REDIS_PORT + redis_password: os.environ/REDIS_PASSWORD general_settings: master_key: sk-1234 @@ -132,6 +134,31 @@ This metric indicates the remaining budget for a provider in dollars (USD) litellm_provider_remaining_budget_metric{api_provider="openai"} 10 ``` +## Multi-instance setup + +If you are using a multi-instance setup, you will need to set the Redis host, port, and password in the `proxy_config.yaml` file. Redis is used to sync spend across LiteLLM instances. + +```yaml +model_list: + - model_name: gpt-3.5-turbo + litellm_params: + model: openai/gpt-3.5-turbo + api_key: os.environ/OPENAI_API_KEY + +router_settings: + provider_budget_config: + openai: + budget_limit: 0.000000000001 # float of $ value budget for time period + time_period: 1d # can be 1d, 2d, 30d, 1mo, 2mo + + # 👇 Add this: Set Redis Host, Port, and Password if using multiple instances of LiteLLM + redis_host: os.environ/REDIS_HOST + redis_port: os.environ/REDIS_PORT + redis_password: os.environ/REDIS_PASSWORD + +general_settings: + master_key: sk-1234 +``` ## Spec for provider_budget_config diff --git a/litellm/caching/redis_cache.py b/litellm/caching/redis_cache.py index e15a3f83d..ba5c3a695 100644 --- a/litellm/caching/redis_cache.py +++ b/litellm/caching/redis_cache.py @@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Tuple import litellm from litellm._logging import print_verbose, verbose_logger from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs +from litellm.types.caching import RedisPipelineIncrementOperation from litellm.types.services import ServiceLoggerPayload, ServiceTypes from litellm.types.utils import all_litellm_params @@ -890,3 +891,92 @@ class RedisCache(BaseCache): def delete_cache(self, key): self.redis_client.delete(key) + + async def _pipeline_increment_helper( + self, + pipe: pipeline, + increment_list: List[RedisPipelineIncrementOperation], + ) -> Optional[List[float]]: + """Helper function for pipeline increment operations""" + # Iterate through each increment operation and add commands to pipeline + for increment_op in increment_list: + cache_key = self.check_and_fix_namespace(key=increment_op["key"]) + print_verbose( + f"Increment ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue 
{increment_op['increment_value']}\nttl={increment_op['ttl']}" + ) + pipe.incrbyfloat(cache_key, increment_op["increment_value"]) + if increment_op["ttl"] is not None: + _td = timedelta(seconds=increment_op["ttl"]) + pipe.expire(cache_key, _td) + # Execute the pipeline and return results + results = await pipe.execute() + print_verbose(f"Increment ASYNC Redis Cache PIPELINE: results: {results}") + return results + + async def async_increment_pipeline( + self, increment_list: List[RedisPipelineIncrementOperation], **kwargs + ) -> Optional[List[float]]: + """ + Use Redis Pipelines for bulk increment operations + Args: + increment_list: List of RedisPipelineIncrementOperation dicts containing: + - key: str + - increment_value: float + - ttl: Optional[int] (seconds) + """ + # don't waste a network request if there's nothing to increment + if len(increment_list) == 0: + return None + + from redis.asyncio import Redis + + _redis_client: Redis = self.init_async_client() # type: ignore + start_time = time.time() + + print_verbose( + f"Increment Async Redis Cache Pipeline: increment list: {increment_list}" + ) + + try: + async with _redis_client as redis_client: + async with redis_client.pipeline(transaction=True) as pipe: + results = await self._pipeline_increment_helper( + pipe, increment_list + ) + + print_verbose(f"pipeline increment results: {results}") + + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="async_increment_pipeline", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) + return results + except Exception as e: + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + call_type="async_increment_pipeline", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) + verbose_logger.error( + "LiteLLM Redis Caching: async increment_pipeline() - Got exception from REDIS %s", + str(e), + ) + raise e diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 956a17a75..13fb1bcbe 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -2,8 +2,23 @@ model_list: - model_name: gpt-4o litellm_params: model: openai/gpt-4o - api_key: os.environ/OPENAI_API_KEY + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + - model_name: fake-anthropic-endpoint + litellm_params: + model: anthropic/fake + api_base: https://exampleanthropicendpoint-production.up.railway.app/ -default_vertex_config: - vertex_project: "adroit-crow-413218" - vertex_location: "us-central1" +router_settings: + provider_budget_config: + openai: + budget_limit: 0.3 # float of $ value budget for time period + time_period: 1d # can be 1d, 2d, 30d + anthropic: + budget_limit: 5 + time_period: 1d + redis_host: os.environ/REDIS_HOST + redis_port: os.environ/REDIS_PORT + redis_password: os.environ/REDIS_PASSWORD + +litellm_settings: + callbacks: ["prometheus"] \ No newline at end of file diff --git a/litellm/router_strategy/provider_budgets.py b/litellm/router_strategy/provider_budgets.py index ea26d2c0f..f4dc1ba94 100644 --- a/litellm/router_strategy/provider_budgets.py +++ b/litellm/router_strategy/provider_budgets.py @@ -18,11 +18,14 
@@ anthropic: ``` """ +import asyncio +from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union import litellm from litellm._logging import verbose_router_logger from litellm.caching.caching import DualCache +from litellm.caching.redis_cache import RedisPipelineIncrementOperation from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs from litellm.litellm_core_utils.duration_parser import duration_in_seconds @@ -44,10 +47,14 @@ if TYPE_CHECKING: else: Span = Any +DEFAULT_REDIS_SYNC_INTERVAL = 1 + class ProviderBudgetLimiting(CustomLogger): def __init__(self, router_cache: DualCache, provider_budget_config: dict): self.router_cache = router_cache + self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = [] + asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis()) # cast elements of provider_budget_config to ProviderBudgetInfo for provider, config in provider_budget_config.items(): @@ -173,19 +180,76 @@ class ProviderBudgetLimiting(CustomLogger): return potential_deployments + async def _get_or_set_budget_start_time( + self, start_time_key: str, current_time: float, ttl_seconds: int + ) -> float: + """ + Checks if the key = `provider_budget_start_time:{provider}` exists in cache. + + If it does, return the value. + If it does not, set the key to `current_time` and return `current_time`. + """ + budget_start = await self.router_cache.async_get_cache(start_time_key) + if budget_start is None: + await self.router_cache.async_set_cache( + key=start_time_key, value=current_time, ttl=ttl_seconds + ) + return current_time + return float(budget_start) + + async def _handle_new_budget_window( + self, + spend_key: str, + start_time_key: str, + current_time: float, + response_cost: float, + ttl_seconds: int, + ) -> float: + """ + Handle the start of a new budget window by resetting spend and start time + + This runs when: + - The budget does not exist in cache, so we need to set it + - The budget window has expired, so we need to reset everything + + Does two things: + - stores key: `provider_spend:{provider}:1d`, value: response_cost + - stores key: `provider_budget_start_time:{provider}`, value: current_time. + This stores the start time of the new budget window + """ + await self.router_cache.async_set_cache( + key=spend_key, value=response_cost, ttl=ttl_seconds + ) + await self.router_cache.async_set_cache( + key=start_time_key, value=current_time, ttl=ttl_seconds + ) + return current_time + + async def _increment_spend_in_current_window( + self, spend_key: str, response_cost: float, ttl: int + ): + """ + Increment spend within the existing budget window + + Runs once the budget start time exists in Redis Cache (on the 2nd and subsequent requests to the same provider) + + - Increments the spend in the in-memory cache (so spend is instantly updated in memory) + - Queues the increment operation for the Redis pipeline (increments are batched to optimize performance; 
Redis is used to sync spend across multiple LiteLLM instances) + """ + await self.router_cache.in_memory_cache.async_increment( + key=spend_key, + value=response_cost, + ttl=ttl, + ) + increment_op = RedisPipelineIncrementOperation( + key=spend_key, + increment_value=response_cost, + ttl=ttl, + ) + self.redis_increment_operation_queue.append(increment_op) + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): - """ - Increment provider spend in DualCache (InMemory + Redis) - - Handles saving current provider spend to Redis. - - Spend is stored as: - provider_spend:{provider}:{time_period} - ex. provider_spend:openai:1d - ex. provider_spend:anthropic:7d - - The time period is tracked for time_periods set in the provider budget config. - """ + """Increment provider spend in the in-memory cache and queue the increment for the periodic Redis sync""" verbose_router_logger.debug("in ProviderBudgetLimiting.async_log_success_event") standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( "standard_logging_object", None @@ -208,20 +272,146 @@ class ProviderBudgetLimiting(CustomLogger): ) spend_key = f"provider_spend:{custom_llm_provider}:{budget_config.time_period}" - ttl_seconds = duration_in_seconds(duration=budget_config.time_period) + start_time_key = f"provider_budget_start_time:{custom_llm_provider}" + + current_time = datetime.now(timezone.utc).timestamp() + ttl_seconds = duration_in_seconds(budget_config.time_period) + + budget_start = await self._get_or_set_budget_start_time( + start_time_key=start_time_key, + current_time=current_time, + ttl_seconds=ttl_seconds, + ) + + if budget_start is None: + # First spend for this provider + budget_start = await self._handle_new_budget_window( + spend_key=spend_key, + start_time_key=start_time_key, + current_time=current_time, + response_cost=response_cost, + ttl_seconds=ttl_seconds, + ) + elif (current_time - budget_start) > ttl_seconds: + # Budget window expired - reset everything + verbose_router_logger.debug("Budget window expired - resetting everything") + budget_start = await self._handle_new_budget_window( + spend_key=spend_key, + start_time_key=start_time_key, + current_time=current_time, + response_cost=response_cost, + ttl_seconds=ttl_seconds, + ) + else: + # Within existing window - increment spend + remaining_time = ttl_seconds - (current_time - budget_start) + ttl_for_increment = int(remaining_time) + + await self._increment_spend_in_current_window( + spend_key=spend_key, response_cost=response_cost, ttl=ttl_for_increment + ) + verbose_router_logger.debug( - f"Incrementing spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}" - ) - # Increment the spend in Redis and set TTL - await self.router_cache.async_increment_cache( - key=spend_key, - value=response_cost, - ttl=ttl_seconds, - ) - verbose_router_logger.debug( - f"Incremented spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}" + f"Incremented spend for {spend_key} by {response_cost}" ) + async def periodic_sync_in_memory_spend_with_redis(self): + """ + Handler that triggers _sync_in_memory_spend_with_redis every DEFAULT_REDIS_SYNC_INTERVAL seconds + + Required when provider budgets are used in a multi-instance environment + """ + while True: + try: + await self._sync_in_memory_spend_with_redis() + await asyncio.sleep( + DEFAULT_REDIS_SYNC_INTERVAL + ) # Wait for DEFAULT_REDIS_SYNC_INTERVAL seconds before next sync + except Exception as e: + verbose_router_logger.error(f"Error in periodic sync task: {str(e)}") + await asyncio.sleep( + DEFAULT_REDIS_SYNC_INTERVAL + ) # Still wait 
DEFAULT_REDIS_SYNC_INTERVAL seconds on error before retrying + + async def _push_in_memory_increments_to_redis(self): + """ + How this works: + - async_log_success_event collects all provider spend increments in `redis_increment_operation_queue` + - This function pushes all increments to Redis in a batched pipeline to optimize performance + + Only runs if Redis is initialized + """ + try: + if not self.router_cache.redis_cache: + return # Redis is not initialized + + verbose_router_logger.debug( + "Pushing Redis Increment Pipeline for queue: %s", + self.redis_increment_operation_queue, + ) + if len(self.redis_increment_operation_queue) > 0: + asyncio.create_task( + self.router_cache.redis_cache.async_increment_pipeline( + increment_list=self.redis_increment_operation_queue, + ) + ) + + self.redis_increment_operation_queue = [] + + except Exception as e: + verbose_router_logger.error( + f"Error syncing in-memory cache with Redis: {str(e)}" + ) + + async def _sync_in_memory_spend_with_redis(self): + """ + Ensures in-memory cache is updated with latest Redis values for all provider spends. + + Why do we need this? + - Optimization to hit sub-100ms latency. Performance was impacted when Redis was read from and written to on every request + - To use provider budgets in a multi-instance environment, we use Redis to sync spend across all instances + + What this does: + 1. Push all provider spend increments to Redis + 2. Fetch all current provider spend from Redis to update the in-memory cache + """ + + try: + # No need to sync if Redis cache is not initialized + if self.router_cache.redis_cache is None: + return + + # 1. Push all provider spend increments to Redis + await self._push_in_memory_increments_to_redis() + + # 2. Fetch all current provider spend from Redis to update the in-memory cache + cache_keys = [] + for provider, config in self.provider_budget_config.items(): + if config is None: + continue + cache_keys.append(f"provider_spend:{provider}:{config.time_period}") + + # Batch fetch current spend values from Redis + redis_values = await self.router_cache.redis_cache.async_batch_get_cache( + key_list=cache_keys + ) + + # Update in-memory cache with Redis values + if isinstance(redis_values, dict): # Check if redis_values is a dictionary + for key, value in redis_values.items(): + if value is not None: + await self.router_cache.in_memory_cache.async_set_cache( + key=key, value=float(value) + ) + verbose_router_logger.debug( + f"Updated in-memory cache for {key}: {value}" + ) + + except Exception as e: + verbose_router_logger.error( + f"Error syncing in-memory cache with Redis: {str(e)}" + ) + def _get_budget_config_for_provider( self, provider: str ) -> Optional[ProviderBudgetInfo]: diff --git a/litellm/types/caching.py b/litellm/types/caching.py index 7fca4c041..a6f9de308 100644 --- a/litellm/types/caching.py +++ b/litellm/types/caching.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Literal +from typing import Literal, Optional, TypedDict class LiteLLMCacheType(str, Enum): @@ -23,3 +23,13 @@ CachingSupportedCallTypes = Literal[ "arerank", "rerank", ] + + +class RedisPipelineIncrementOperation(TypedDict): + """ + TypedDict for a single Redis pipeline increment operation + """ + + key: str + increment_value: float + ttl: Optional[int] diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py index 222013a86..08da89172 100644 --- a/tests/local_testing/test_caching.py +++ b/tests/local_testing/test_caching.py @@ -2433,3 +2433,48 @@ async def 
test_dual_cache_caching_batch_get_cache(): await dc.async_batch_get_cache(keys=["test_key1", "test_key2"]) assert mock_async_get_cache.call_count == 1 + + +@pytest.mark.asyncio +async def test_redis_increment_pipeline(): + """Test Redis increment pipeline functionality""" + try: + from litellm.caching.redis_cache import RedisCache + + litellm.set_verbose = True + redis_cache = RedisCache( + host=os.environ["REDIS_HOST"], + port=os.environ["REDIS_PORT"], + password=os.environ["REDIS_PASSWORD"], + ) + + # Create test increment operations + increment_list = [ + {"key": "test_key1", "increment_value": 1.5, "ttl": 60}, + {"key": "test_key1", "increment_value": 1.1, "ttl": 58}, + {"key": "test_key1", "increment_value": 0.4, "ttl": 55}, + {"key": "test_key2", "increment_value": 2.5, "ttl": 60}, + ] + + # Test pipeline increment + results = await redis_cache.async_increment_pipeline(increment_list) + + # Verify results + assert len(results) == 8 # 4 increment operations + 4 expire operations + + # Verify the values were actually set in Redis + value1 = await redis_cache.async_get_cache("test_key1") + print("result in cache for key=test_key1", value1) + value2 = await redis_cache.async_get_cache("test_key2") + print("result in cache for key=test_key2", value2) + + assert float(value1) == 3.0 + assert float(value2) == 2.5 + + # Clean up + await redis_cache.async_delete_cache("test_key1") + await redis_cache.async_delete_cache("test_key2") + + except Exception as e: + print(f"Error occurred: {str(e)}") + raise e diff --git a/tests/local_testing/test_router_provider_budgets.py b/tests/local_testing/test_router_provider_budgets.py index a6574ba4b..430550632 100644 --- a/tests/local_testing/test_router_provider_budgets.py +++ b/tests/local_testing/test_router_provider_budgets.py @@ -17,7 +17,7 @@ from litellm.types.router import ( ProviderBudgetConfigType, ProviderBudgetInfo, ) -from litellm.caching.caching import DualCache +from litellm.caching.caching import DualCache, RedisCache import logging from litellm._logging import verbose_router_logger import litellm @@ -25,6 +25,27 @@ import litellm verbose_router_logger.setLevel(logging.DEBUG) +def cleanup_redis(): + """Cleanup Redis cache before each test""" + try: + import redis + + print("cleaning up redis..") + + redis_client = redis.Redis( + host=os.getenv("REDIS_HOST"), + port=int(os.getenv("REDIS_PORT")), + password=os.getenv("REDIS_PASSWORD"), + ) + print("scan iter result", redis_client.scan_iter("provider_spend:*")) + # Delete all provider spend keys + for key in redis_client.scan_iter("provider_spend:*"): + print("deleting key", key) + redis_client.delete(key) + except Exception as e: + print(f"Error cleaning up Redis: {str(e)}") + + @pytest.mark.asyncio async def test_provider_budgets_e2e_test(): """ @@ -34,6 +55,8 @@ async def test_provider_budgets_e2e_test(): - Next 3 requests all go to Azure """ + cleanup_redis() + # Modify for test provider_budget_config: ProviderBudgetConfigType = { "openai": ProviderBudgetInfo(time_period="1d", budget_limit=0.000000000001), "azure": ProviderBudgetInfo(time_period="1d", budget_limit=100), @@ -71,7 +94,7 @@ async def test_provider_budgets_e2e_test(): ) print(response) - await asyncio.sleep(0.5) + await asyncio.sleep(2.5) for _ in range(3): response = await router.acompletion( @@ -94,6 +117,7 @@ async def test_provider_budgets_e2e_test_expect_to_fail(): - first request passes, all subsequent requests fail """ + cleanup_redis() # Note: We intentionally use a dictionary with string keys for budget_limit and 
time_period # we want to test that the router can handle type conversion, since the proxy config yaml passes these values as a dictionary @@ -125,7 +149,7 @@ ) print(response) - await asyncio.sleep(0.5) + await asyncio.sleep(2.5) for _ in range(3): with pytest.raises(Exception) as exc_info: @@ -142,11 +166,13 @@ assert "Exceeded budget for provider" in str(exc_info.value) -def test_get_llm_provider_for_deployment(): +@pytest.mark.asyncio +async def test_get_llm_provider_for_deployment(): """ Test the _get_llm_provider_for_deployment helper method """ + cleanup_redis() provider_budget = ProviderBudgetLimiting( router_cache=DualCache(), provider_budget_config={} ) @@ -172,11 +198,13 @@ assert provider_budget._get_llm_provider_for_deployment(unknown_deployment) is None -def test_get_budget_config_for_provider(): +@pytest.mark.asyncio +async def test_get_budget_config_for_provider(): """ Test the _get_budget_config_for_provider helper method """ + cleanup_redis() config = { "openai": ProviderBudgetInfo(time_period="1d", budget_limit=100), "anthropic": ProviderBudgetInfo(time_period="7d", budget_limit=500), @@ -206,6 +234,7 @@ async def test_prometheus_metric_tracking(): """ Test that the Prometheus metric for provider budget is tracked correctly """ + cleanup_redis() from unittest.mock import MagicMock from litellm.integrations.prometheus import PrometheusLogger @@ -263,7 +292,187 @@ except Exception as e: print("error", e) - await asyncio.sleep(0.5) + await asyncio.sleep(2.5) # Verify the mock was called correctly mock_prometheus.track_provider_remaining_budget.assert_called_once() + + +@pytest.mark.asyncio +async def test_handle_new_budget_window(): + """ + Test _handle_new_budget_window helper method + + Expected behavior: sets the spend to response_cost and the budget start time to current_time + """ + cleanup_redis() + provider_budget = ProviderBudgetLimiting( + router_cache=DualCache(), provider_budget_config={} + ) + + spend_key = "provider_spend:openai:7d" + start_time_key = "provider_budget_start_time:openai" + current_time = 1000.0 + response_cost = 0.5 + ttl_seconds = 86400 # 1 day + + # Test handling new budget window + new_start_time = await provider_budget._handle_new_budget_window( + spend_key=spend_key, + start_time_key=start_time_key, + current_time=current_time, + response_cost=response_cost, + ttl_seconds=ttl_seconds, + ) + + assert new_start_time == current_time + + # Verify the spend was set correctly + spend = await provider_budget.router_cache.async_get_cache(spend_key) + print("spend in cache for key", spend_key, "is", spend) + assert float(spend) == response_cost + + # Verify start time was set correctly + start_time = await provider_budget.router_cache.async_get_cache(start_time_key) + print("start time in cache for key", start_time_key, "is", start_time) + assert float(start_time) == current_time + + +@pytest.mark.asyncio +async def test_get_or_set_budget_start_time(): + """ + Test _get_or_set_budget_start_time helper method + + scenario 1: no existing start time in cache, should return current time + scenario 2: existing start time in cache, should return existing start time + """ + cleanup_redis() + provider_budget = ProviderBudgetLimiting( + router_cache=DualCache(), provider_budget_config={} + ) + + start_time_key = "test_start_time" + current_time = 1000.0 + ttl_seconds = 86400 # 1 day + + # When there is no existing start time, we should set 
it to the current time + start_time = await provider_budget._get_or_set_budget_start_time( + start_time_key=start_time_key, + current_time=current_time, + ttl_seconds=ttl_seconds, + ) + print("budget start time when no existing start time is in cache", start_time) + assert start_time == current_time + + # When there is an existing start time, we should return it even if the current time is later + new_current_time = 2000.0 + existing_start_time = await provider_budget._get_or_set_budget_start_time( + start_time_key=start_time_key, + current_time=new_current_time, + ttl_seconds=ttl_seconds, + ) + print( + "budget start time when existing start time is in cache, but current time is later", + existing_start_time, + ) + assert existing_start_time == current_time # Should return the original start time + + +@pytest.mark.asyncio +async def test_increment_spend_in_current_window(): + """ + Test _increment_spend_in_current_window helper method + + Expected behavior: + - Increment the spend in memory cache + - Queue the increment operation to Redis + """ + cleanup_redis() + provider_budget = ProviderBudgetLimiting( + router_cache=DualCache(), provider_budget_config={} + ) + + spend_key = "provider_spend:openai:1d" + response_cost = 0.5 + ttl = 86400 # 1 day + + # Set initial spend + await provider_budget.router_cache.async_set_cache( + key=spend_key, value=1.0, ttl=ttl + ) + + # Test incrementing spend + await provider_budget._increment_spend_in_current_window( + spend_key=spend_key, + response_cost=response_cost, + ttl=ttl, + ) + + # Verify the spend was incremented correctly in memory + spend = await provider_budget.router_cache.async_get_cache(spend_key) + assert float(spend) == 1.5 + + # Verify the increment operation was queued for Redis + print( + "redis_increment_operation_queue", + provider_budget.redis_increment_operation_queue, + ) + assert len(provider_budget.redis_increment_operation_queue) == 1 + queued_op = provider_budget.redis_increment_operation_queue[0] + assert queued_op["key"] == spend_key + assert queued_op["increment_value"] == response_cost + assert queued_op["ttl"] == ttl + + +@pytest.mark.asyncio +async def test_sync_in_memory_spend_with_redis(): + """ + Test _sync_in_memory_spend_with_redis helper method + + Expected behavior: + - Push all provider spend increments to Redis + - Fetch all current provider spend from Redis to update in-memory cache + """ + cleanup_redis() + provider_budget_config = { + "openai": ProviderBudgetInfo(time_period="1d", budget_limit=100), + "anthropic": ProviderBudgetInfo(time_period="1d", budget_limit=200), + } + + provider_budget = ProviderBudgetLimiting( + router_cache=DualCache( + redis_cache=RedisCache( + host=os.getenv("REDIS_HOST"), + port=int(os.getenv("REDIS_PORT")), + password=os.getenv("REDIS_PASSWORD"), + ) + ), + provider_budget_config=provider_budget_config, + ) + + # Set some values in Redis + spend_key_openai = "provider_spend:openai:1d" + spend_key_anthropic = "provider_spend:anthropic:1d" + + await provider_budget.router_cache.redis_cache.async_set_cache( + key=spend_key_openai, value=50.0 + ) + await provider_budget.router_cache.redis_cache.async_set_cache( + key=spend_key_anthropic, value=75.0 + ) + + # Test syncing with Redis + await provider_budget._sync_in_memory_spend_with_redis() + + # Verify in-memory cache was updated + openai_spend = await provider_budget.router_cache.in_memory_cache.async_get_cache( + spend_key_openai + ) + anthropic_spend = ( + await provider_budget.router_cache.in_memory_cache.async_get_cache( + 
spend_key_anthropic + ) + ) + + assert float(openai_spend) == 50.0 + assert float(anthropic_spend) == 75.0
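
---

For reviewers who want to exercise the new flow end-to-end, here is a minimal Python sketch distilled from the tests above. It assumes the `Router` constructor accepts `provider_budget_config` plus the `redis_host`/`redis_port`/`redis_password` kwargs (as the YAML config and e2e tests suggest), and that `REDIS_HOST`, `REDIS_PORT`, `REDIS_PASSWORD`, and `OPENAI_API_KEY` are set in the environment; the intentionally tiny OpenAI budget is illustrative, not a recommended value.

```python
import asyncio
import os

from litellm import Router
from litellm.types.router import ProviderBudgetConfigType, ProviderBudgetInfo


async def main():
    # Illustrative config: a budget so small the first request exhausts it.
    provider_budget_config: ProviderBudgetConfigType = {
        "openai": ProviderBudgetInfo(time_period="1d", budget_limit=0.000000000001),
    }

    # Build the router inside the running event loop: ProviderBudgetLimiting
    # schedules its periodic Redis sync with asyncio.create_task() in __init__.
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "openai/gpt-3.5-turbo",
                    "api_key": os.environ["OPENAI_API_KEY"],
                },
            },
        ],
        provider_budget_config=provider_budget_config,
        # Redis connection, so spend syncs across LiteLLM instances
        redis_host=os.environ["REDIS_HOST"],
        redis_port=int(os.environ["REDIS_PORT"]),
        redis_password=os.environ["REDIS_PASSWORD"],
    )

    # First request goes through; its cost is queued for the Redis pipeline.
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
    )
    print(response)

    # Wait past DEFAULT_REDIS_SYNC_INTERVAL (1s) so the background task pushes
    # the queued increment to Redis and refreshes the in-memory spend.
    await asyncio.sleep(2.5)

    # The budget is now exceeded, so the provider's deployments are filtered out.
    try:
        await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hi again"}],
        )
    except Exception as e:
        print(e)  # expect "Exceeded budget for provider" in the message


if __name__ == "__main__":
    asyncio.run(main())
```

The 2.5s sleep mirrors the updated tests: request-path reads hit the in-memory cache only, so a freshly recorded spend becomes visible to routing decisions only after the background sync has pushed it to Redis and pulled it back.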