diff --git a/litellm/caching.py b/litellm/caching.py index fe7f53744..9bb03b99a 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -432,6 +432,7 @@ class RedisCache(BaseCache): start_time=start_time, end_time=end_time, parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + event_metadata={"key": key}, ) ) except Exception as e: @@ -446,6 +447,7 @@ class RedisCache(BaseCache): start_time=start_time, end_time=end_time, parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + event_metadata={"key": key}, ) ) # NON blocking - notify users Redis is throwing an exception @@ -753,6 +755,7 @@ class RedisCache(BaseCache): start_time=start_time, end_time=end_time, parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + event_metadata={"key": key}, ) ) return response @@ -769,6 +772,7 @@ class RedisCache(BaseCache): start_time=start_time, end_time=end_time, parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + event_metadata={"key": key}, ) ) # NON blocking - notify users Redis is throwing an exception diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index 269844ce8..f5619d237 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -1,10 +1,17 @@ # What is this? ## Helper utilities import os -from typing import List, Literal, Optional, Tuple +from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union from litellm._logging import verbose_logger +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + def map_finish_reason( finish_reason: str, @@ -68,10 +75,12 @@ def get_litellm_metadata_from_kwargs(kwargs: dict): # Helper functions used for OTEL logging -def _get_parent_otel_span_from_kwargs(kwargs: Optional[dict] = None): +def _get_parent_otel_span_from_kwargs( + kwargs: Optional[dict] = None, +) -> Union[Span, None]: try: if kwargs is None: - return None + raise ValueError("kwargs is None") litellm_params = kwargs.get("litellm_params") _metadata = kwargs.get("metadata") or {} if "litellm_parent_otel_span" in _metadata: @@ -84,5 +93,9 @@ def _get_parent_otel_span_from_kwargs(kwargs: Optional[dict] = None): return litellm_params["metadata"]["litellm_parent_otel_span"] elif "litellm_parent_otel_span" in kwargs: return kwargs["litellm_parent_otel_span"] - except: + return None + except Exception as e: + verbose_logger.exception( + "Error in _get_parent_otel_span_from_kwargs: " + str(e) + ) return None diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 70aa3ef39..404b3ffc9 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -440,7 +440,7 @@ async def _cache_management_object( exclude_unset=True, exclude={"parent_otel_span": True} ) await proxy_logging_obj.internal_usage_cache.async_set_cache( - key=key, value=_value + key=key, value=_value, litellm_parent_otel_span=None ) @@ -493,7 +493,9 @@ async def _delete_cache_key_object( ## UPDATE REDIS CACHE ## if proxy_logging_obj is not None: - await proxy_logging_obj.internal_usage_cache.async_delete_cache(key=key) + await proxy_logging_obj.internal_usage_cache.dual_cache.async_delete_cache( + key=key + ) @log_to_opentelemetry @@ -522,12 +524,10 @@ async def get_team_object( ## CHECK REDIS CACHE ## if ( proxy_logging_obj is not None - and proxy_logging_obj.internal_usage_cache.redis_cache is not None + and proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache is not None ): - 
cached_team_obj = ( - await proxy_logging_obj.internal_usage_cache.redis_cache.async_get_cache( - key=key - ) + cached_team_obj = await proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache.async_get_cache( + key=key ) if cached_team_obj is None: @@ -595,12 +595,10 @@ async def get_key_object( ## CHECK REDIS CACHE ## if ( proxy_logging_obj is not None - and proxy_logging_obj.internal_usage_cache.redis_cache is not None + and proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache is not None ): - cached_team_obj = ( - await proxy_logging_obj.internal_usage_cache.redis_cache.async_get_cache( - key=key - ) + cached_team_obj = await proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache.async_get_cache( + key=key ) if cached_team_obj is None: diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index fd18fbac9..f5c45f3bd 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -1,7 +1,7 @@ import sys import traceback from datetime import datetime, timedelta -from typing import Literal, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional, Union from fastapi import HTTPException @@ -10,17 +10,28 @@ from litellm import ModelResponse from litellm._logging import verbose_proxy_logger from litellm.caching import DualCache from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.auth_utils import ( get_key_model_rpm_limit, get_key_model_tpm_limit, ) +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache + + Span = _Span + InternalUsageCache = _InternalUsageCache +else: + Span = Any + InternalUsageCache = Any + class _PROXY_MaxParallelRequestsHandler(CustomLogger): - # Class variables or attributes - def __init__(self, internal_usage_cache: DualCache): + def __init__(self, internal_usage_cache: InternalUsageCache): self.internal_usage_cache = internal_usage_cache def print_verbose(self, print_statement): @@ -44,7 +55,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): rate_limit_type: Literal["user", "customer", "team"], ): current = await self.internal_usage_cache.async_get_cache( - key=request_count_api_key + key=request_count_api_key, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10} if current is None: if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0: @@ -58,7 +70,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): "current_rpm": 0, } await self.internal_usage_cache.async_set_cache( - request_count_api_key, new_val + key=request_count_api_key, + value=new_val, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, ) elif ( int(current["current_requests"]) < max_parallel_requests @@ -72,7 +86,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): "current_rpm": current["current_rpm"], } await self.internal_usage_cache.async_set_cache( - request_count_api_key, new_val + key=request_count_api_key, + value=new_val, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, ) else: raise HTTPException( @@ -135,12 +151,14 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): # ------------ # Setup values # ------------ - + new_val: Optional[dict] = None if 
global_max_parallel_requests is not None: # get value from cache _key = "global_max_parallel_requests" current_global_requests = await self.internal_usage_cache.async_get_cache( - key=_key, local_only=True + key=_key, + local_only=True, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, ) # check if below limit if current_global_requests is None: @@ -153,7 +171,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): # if below -> increment else: await self.internal_usage_cache.async_increment_cache( - key=_key, value=1, local_only=True + key=_key, + value=1, + local_only=True, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, ) current_date = datetime.now().strftime("%Y-%m-%d") @@ -167,7 +188,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): # CHECK IF REQUEST ALLOWED for key current = await self.internal_usage_cache.async_get_cache( - key=request_count_api_key + key=request_count_api_key, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10} self.print_verbose(f"current: {current}") if ( @@ -187,7 +209,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): "current_rpm": 0, } await self.internal_usage_cache.async_set_cache( - request_count_api_key, new_val + key=request_count_api_key, + value=new_val, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, ) elif ( int(current["current_requests"]) < max_parallel_requests @@ -201,7 +225,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): "current_rpm": current["current_rpm"], } await self.internal_usage_cache.async_set_cache( - request_count_api_key, new_val + key=request_count_api_key, + value=new_val, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, ) else: return self.raise_rate_limit_error( @@ -219,7 +245,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): ) current = await self.internal_usage_cache.async_get_cache( - key=request_count_api_key + key=request_count_api_key, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10} tpm_limit_for_model = None @@ -242,7 +269,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): "current_rpm": 0, } await self.internal_usage_cache.async_set_cache( - request_count_api_key, new_val + key=request_count_api_key, + value=new_val, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, ) elif tpm_limit_for_model is not None or rpm_limit_for_model is not None: # Increase count for this token @@ -267,16 +296,19 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): ) else: await self.internal_usage_cache.async_set_cache( - request_count_api_key, new_val + key=request_count_api_key, + value=new_val, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, ) _remaining_tokens = None _remaining_requests = None # Add remaining tokens, requests to metadata - if tpm_limit_for_model is not None: - _remaining_tokens = tpm_limit_for_model - new_val["current_tpm"] - if rpm_limit_for_model is not None: - _remaining_requests = rpm_limit_for_model - new_val["current_rpm"] + if new_val: + if tpm_limit_for_model is not None: + _remaining_tokens = tpm_limit_for_model - new_val["current_tpm"] + if rpm_limit_for_model is not None: + _remaining_requests = rpm_limit_for_model - new_val["current_rpm"] _remaining_limits_data = { f"litellm-key-remaining-tokens-{_model}": _remaining_tokens, @@ -291,7 +323,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): user_id = 
user_api_key_dict.user_id if user_id is not None: _user_id_rate_limits = await self.internal_usage_cache.async_get_cache( - key=user_id + key=user_id, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, ) # get user tpm/rpm limits if _user_id_rate_limits is not None and isinstance( @@ -388,6 +421,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): get_model_group_from_litellm_kwargs, ) + litellm_parent_otel_span: Union[Span, None] = _get_parent_otel_span_from_kwargs( + kwargs=kwargs + ) try: self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING") global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get( @@ -416,7 +452,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): _key = "global_max_parallel_requests" # decrement await self.internal_usage_cache.async_increment_cache( - key=_key, value=-1, local_only=True + key=_key, + value=-1, + local_only=True, + litellm_parent_otel_span=litellm_parent_otel_span, ) current_date = datetime.now().strftime("%Y-%m-%d") @@ -427,7 +466,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): total_tokens = 0 if isinstance(response_obj, ModelResponse): - total_tokens = response_obj.usage.total_tokens + total_tokens = response_obj.usage.total_tokens # type: ignore # ------------ # Update usage - API Key @@ -439,7 +478,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): ) current = await self.internal_usage_cache.async_get_cache( - key=request_count_api_key + key=request_count_api_key, + litellm_parent_otel_span=litellm_parent_otel_span, ) or { "current_requests": 1, "current_tpm": total_tokens, @@ -456,7 +496,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" ) await self.internal_usage_cache.async_set_cache( - request_count_api_key, new_val, ttl=60 + request_count_api_key, + new_val, + ttl=60, + litellm_parent_otel_span=litellm_parent_otel_span, ) # store in cache for 1 min. 
# ------------ @@ -476,7 +519,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): ) current = await self.internal_usage_cache.async_get_cache( - key=request_count_api_key + key=request_count_api_key, + litellm_parent_otel_span=litellm_parent_otel_span, ) or { "current_requests": 1, "current_tpm": total_tokens, @@ -493,7 +537,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" ) await self.internal_usage_cache.async_set_cache( - request_count_api_key, new_val, ttl=60 + request_count_api_key, + new_val, + ttl=60, + litellm_parent_otel_span=litellm_parent_otel_span, ) # ------------ @@ -503,14 +550,15 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): total_tokens = 0 if isinstance(response_obj, ModelResponse): - total_tokens = response_obj.usage.total_tokens + total_tokens = response_obj.usage.total_tokens # type: ignore request_count_api_key = ( f"{user_api_key_user_id}::{precise_minute}::request_count" ) current = await self.internal_usage_cache.async_get_cache( - key=request_count_api_key + key=request_count_api_key, + litellm_parent_otel_span=litellm_parent_otel_span, ) or { "current_requests": 1, "current_tpm": total_tokens, @@ -527,7 +575,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" ) await self.internal_usage_cache.async_set_cache( - request_count_api_key, new_val, ttl=60 + request_count_api_key, + new_val, + ttl=60, + litellm_parent_otel_span=litellm_parent_otel_span, ) # store in cache for 1 min. # ------------ @@ -537,14 +588,15 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): total_tokens = 0 if isinstance(response_obj, ModelResponse): - total_tokens = response_obj.usage.total_tokens + total_tokens = response_obj.usage.total_tokens # type: ignore request_count_api_key = ( f"{user_api_key_team_id}::{precise_minute}::request_count" ) current = await self.internal_usage_cache.async_get_cache( - key=request_count_api_key + key=request_count_api_key, + litellm_parent_otel_span=litellm_parent_otel_span, ) or { "current_requests": 1, "current_tpm": total_tokens, @@ -561,7 +613,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" ) await self.internal_usage_cache.async_set_cache( - request_count_api_key, new_val, ttl=60 + request_count_api_key, + new_val, + ttl=60, + litellm_parent_otel_span=litellm_parent_otel_span, ) # store in cache for 1 min. # ------------ @@ -571,14 +626,15 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): total_tokens = 0 if isinstance(response_obj, ModelResponse): - total_tokens = response_obj.usage.total_tokens + total_tokens = response_obj.usage.total_tokens # type: ignore request_count_api_key = ( f"{user_api_key_end_user_id}::{precise_minute}::request_count" ) current = await self.internal_usage_cache.async_get_cache( - key=request_count_api_key + key=request_count_api_key, + litellm_parent_otel_span=litellm_parent_otel_span, ) or { "current_requests": 1, "current_tpm": total_tokens, @@ -595,7 +651,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" ) await self.internal_usage_cache.async_set_cache( - request_count_api_key, new_val, ttl=60 + request_count_api_key, + new_val, + ttl=60, + litellm_parent_otel_span=litellm_parent_otel_span, ) # store in cache for 1 min. 
except Exception as e: @@ -604,6 +663,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): try: self.print_verbose("Inside Max Parallel Request Failure Hook") + litellm_parent_otel_span: Union[Span, None] = ( + _get_parent_otel_span_from_kwargs(kwargs=kwargs) + ) _metadata = kwargs["litellm_params"].get("metadata", {}) or {} global_max_parallel_requests = _metadata.get( "global_max_parallel_requests", None @@ -626,12 +688,17 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): _key = "global_max_parallel_requests" current_global_requests = ( await self.internal_usage_cache.async_get_cache( - key=_key, local_only=True + key=_key, + local_only=True, + litellm_parent_otel_span=litellm_parent_otel_span, ) ) # decrement await self.internal_usage_cache.async_increment_cache( - key=_key, value=-1, local_only=True + key=_key, + value=-1, + local_only=True, + litellm_parent_otel_span=litellm_parent_otel_span, ) current_date = datetime.now().strftime("%Y-%m-%d") @@ -647,7 +714,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): # Update usage # ------------ current = await self.internal_usage_cache.async_get_cache( - key=request_count_api_key + key=request_count_api_key, + litellm_parent_otel_span=litellm_parent_otel_span, ) or { "current_requests": 1, "current_tpm": 0, @@ -662,7 +730,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): self.print_verbose(f"updated_value in failure call: {new_val}") await self.internal_usage_cache.async_set_cache( - request_count_api_key, new_val, ttl=60 + request_count_api_key, + new_val, + ttl=60, + litellm_parent_otel_span=litellm_parent_otel_span, ) # save in cache for up to 1 min. except Exception as e: verbose_proxy_logger.exception( diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 4601d4980..69e6b52a3 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -21,7 +21,7 @@ model_list: litellm_settings: cache: true - # callbacks: ["otel"] + callbacks: ["otel"] general_settings: diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 509a316e0..4f139db36 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -205,6 +205,83 @@ def log_to_opentelemetry(func): return wrapper +class InternalUsageCache: + def __init__(self, dual_cache: DualCache): + self.dual_cache: DualCache = dual_cache + + async def async_get_cache( + self, + key, + litellm_parent_otel_span: Union[Span, None], + local_only: bool = False, + **kwargs, + ) -> Any: + return await self.dual_cache.async_get_cache( + key=key, + local_only=local_only, + litellm_parent_otel_span=litellm_parent_otel_span, + **kwargs, + ) + + async def async_set_cache( + self, + key, + value, + litellm_parent_otel_span: Union[Span, None], + local_only: bool = False, + **kwargs, + ) -> None: + return await self.dual_cache.async_set_cache( + key=key, + value=value, + local_only=local_only, + litellm_parent_otel_span=litellm_parent_otel_span, + **kwargs, + ) + + async def async_increment_cache( + self, + key, + value: float, + litellm_parent_otel_span: Union[Span, None], + local_only: bool = False, + **kwargs, + ): + return await self.dual_cache.async_increment_cache( + key=key, + value=value, + local_only=local_only, + litellm_parent_otel_span=litellm_parent_otel_span, + **kwargs, + ) + + def set_cache( + self, + key, + value, + local_only: bool = False, + **kwargs, + ) -> None: + return self.dual_cache.set_cache( + key=key, + 
value=value, + local_only=local_only, + **kwargs, + ) + + def get_cache( + self, + key, + local_only: bool = False, + **kwargs, + ) -> Any: + return self.dual_cache.get_cache( + key=key, + local_only=local_only, + **kwargs, + ) + + ### LOGGING ### class ProxyLogging: """ @@ -222,9 +299,9 @@ class ProxyLogging: ## INITIALIZE LITELLM CALLBACKS ## self.call_details: dict = {} self.call_details["user_api_key_cache"] = user_api_key_cache - self.internal_usage_cache = DualCache( - default_in_memory_ttl=1 - ) # ping redis cache every 1s + self.internal_usage_cache: InternalUsageCache = InternalUsageCache( + dual_cache=DualCache(default_in_memory_ttl=1) # ping redis cache every 1s + ) self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler( self.internal_usage_cache ) @@ -238,7 +315,7 @@ class ProxyLogging: alerting_threshold=self.alerting_threshold, alerting=self.alerting, alert_types=self.alert_types, - internal_usage_cache=self.internal_usage_cache, + internal_usage_cache=self.internal_usage_cache.dual_cache, ) def update_values( @@ -283,7 +360,7 @@ class ProxyLogging: litellm.callbacks.append(self.slack_alerting_instance) # type: ignore if redis_cache is not None: - self.internal_usage_cache.redis_cache = redis_cache + self.internal_usage_cache.dual_cache.redis_cache = redis_cache def _init_litellm_callbacks(self, llm_router: Optional[litellm.Router] = None): self.service_logging_obj = ServiceLogging() @@ -298,7 +375,7 @@ class ProxyLogging: if isinstance(callback, str): callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore callback, - internal_usage_cache=self.internal_usage_cache, + internal_usage_cache=self.internal_usage_cache.dual_cache, llm_router=llm_router, ) if callback not in litellm.input_callback: @@ -347,6 +424,7 @@ class ProxyLogging: value=status, local_only=True, ttl=alerting_threshold, + litellm_parent_otel_span=None, ) async def process_pre_call_hook_response(self, response, data, call_type): diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index 3dfadd73a..847f17295 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -2045,7 +2045,7 @@ async def test_proxy_logging_setup(): from litellm.proxy.utils import ProxyLogging pl_obj = ProxyLogging(user_api_key_cache=DualCache()) - assert pl_obj.internal_usage_cache.always_read_redis is True + assert pl_obj.internal_usage_cache.dual_cache.always_read_redis is True @pytest.mark.skip(reason="local test. 
Requires sentinel setup.") diff --git a/litellm/tests/test_parallel_request_limiter.py b/litellm/tests/test_parallel_request_limiter.py index 5173bc9c0..93e86404e 100644 --- a/litellm/tests/test_parallel_request_limiter.py +++ b/litellm/tests/test_parallel_request_limiter.py @@ -28,7 +28,7 @@ from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.hooks.parallel_request_limiter import ( _PROXY_MaxParallelRequestsHandler as MaxParallelRequestsHandler, ) -from litellm.proxy.utils import ProxyLogging, hash_token +from litellm.proxy.utils import InternalUsageCache, ProxyLogging, hash_token ## On Request received ## On Request success @@ -48,7 +48,7 @@ async def test_global_max_parallel_requests(): user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=100) local_cache = DualCache() parallel_request_handler = MaxParallelRequestsHandler( - internal_usage_cache=local_cache + internal_usage_cache=InternalUsageCache(dual_cache=local_cache) ) for _ in range(3): @@ -78,7 +78,7 @@ async def test_pre_call_hook(): user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1) local_cache = DualCache() parallel_request_handler = MaxParallelRequestsHandler( - internal_usage_cache=local_cache + internal_usage_cache=InternalUsageCache(dual_cache=local_cache) ) await parallel_request_handler.async_pre_call_hook( @@ -115,7 +115,7 @@ async def test_pre_call_hook_rpm_limits(): ) local_cache = DualCache() parallel_request_handler = MaxParallelRequestsHandler( - internal_usage_cache=local_cache + internal_usage_cache=InternalUsageCache(dual_cache=local_cache) ) await parallel_request_handler.async_pre_call_hook( @@ -157,7 +157,7 @@ async def test_pre_call_hook_rpm_limits_retry_after(): ) local_cache = DualCache() parallel_request_handler = MaxParallelRequestsHandler( - internal_usage_cache=local_cache + internal_usage_cache=InternalUsageCache(dual_cache=local_cache) ) await parallel_request_handler.async_pre_call_hook( @@ -208,7 +208,7 @@ async def test_pre_call_hook_team_rpm_limits(): ) local_cache = DualCache() parallel_request_handler = MaxParallelRequestsHandler( - internal_usage_cache=local_cache + internal_usage_cache=InternalUsageCache(dual_cache=local_cache) ) await parallel_request_handler.async_pre_call_hook( @@ -256,7 +256,7 @@ async def test_pre_call_hook_tpm_limits(): ) local_cache = DualCache() parallel_request_handler = MaxParallelRequestsHandler( - internal_usage_cache=local_cache + internal_usage_cache=InternalUsageCache(dual_cache=local_cache) ) await parallel_request_handler.async_pre_call_hook( @@ -308,7 +308,7 @@ async def test_pre_call_hook_user_tpm_limits(): print("dict user", res) parallel_request_handler = MaxParallelRequestsHandler( - internal_usage_cache=local_cache + internal_usage_cache=InternalUsageCache(dual_cache=local_cache) ) await parallel_request_handler.async_pre_call_hook( @@ -353,7 +353,7 @@ async def test_success_call_hook(): user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1) local_cache = DualCache() parallel_request_handler = MaxParallelRequestsHandler( - internal_usage_cache=local_cache + internal_usage_cache=InternalUsageCache(dual_cache=local_cache) ) await parallel_request_handler.async_pre_call_hook( @@ -397,7 +397,7 @@ async def test_failure_call_hook(): user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1) local_cache = DualCache() parallel_request_handler = MaxParallelRequestsHandler( - internal_usage_cache=local_cache + 
internal_usage_cache=InternalUsageCache(dual_cache=local_cache) ) await parallel_request_handler.async_pre_call_hook( @@ -975,7 +975,7 @@ async def test_bad_router_tpm_limit_per_model(): print( "internal usage cache: ", - parallel_request_handler.internal_usage_cache.in_memory_cache.cache_dict, + parallel_request_handler.internal_usage_cache.dual_cache.in_memory_cache.cache_dict, ) assert ( @@ -1161,7 +1161,7 @@ async def test_pre_call_hook_tpm_limits_per_model(): print( "internal usage cache: ", - parallel_request_handler.internal_usage_cache.in_memory_cache.cache_dict, + parallel_request_handler.internal_usage_cache.dual_cache.in_memory_cache.cache_dict, ) assert ( diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index ed179d3e2..7357785e3 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -753,10 +753,10 @@ async def test_team_update_redis(): litellm.proxy.proxy_server, "proxy_logging_obj" ) - proxy_logging_obj.internal_usage_cache.redis_cache = RedisCache() + proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache = RedisCache() with patch.object( - proxy_logging_obj.internal_usage_cache.redis_cache, + proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache, "async_set_cache", new=AsyncMock(), ) as mock_client: @@ -782,10 +782,10 @@ async def test_get_team_redis(client_no_auth): litellm.proxy.proxy_server, "proxy_logging_obj" ) - proxy_logging_obj.internal_usage_cache.redis_cache = RedisCache() + proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache = RedisCache() with patch.object( - proxy_logging_obj.internal_usage_cache.redis_cache, + proxy_logging_obj.internal_usage_cache.dual_cache.redis_cache, "async_get_cache", new=AsyncMock(), ) as mock_client: