diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 2e131d2b2..e85f116f7 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -984,10 +984,6 @@ class LiteLLM_VerificationToken(LiteLLMBase):
 
     org_id: Optional[str] = None  # org id for a given key
 
-    # hidden params used for parallel request limiting, not required to create a token
-    user_id_rate_limits: Optional[dict] = None
-    team_id_rate_limits: Optional[dict] = None
-
     class Config:
         protected_namespaces = ()
 
diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index 0558cdf05..4ba7a2229 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -164,8 +164,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         # check if REQUEST ALLOWED for user_id
         user_id = user_api_key_dict.user_id
         if user_id is not None:
-            _user_id_rate_limits = user_api_key_dict.user_id_rate_limits
-
+            _user_id_rate_limits = await self.user_api_key_cache.async_get_cache(
+                key=user_id
+            )
             # get user tpm/rpm limits
             if _user_id_rate_limits is not None and isinstance(
                 _user_id_rate_limits, dict
@@ -196,13 +197,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         ## get team tpm/rpm limits
         team_id = user_api_key_dict.team_id
         if team_id is not None:
-            team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize)
-
-            if team_tpm_limit is None:
-                team_tpm_limit = sys.maxsize
-            team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize)
-            if team_rpm_limit is None:
-                team_rpm_limit = sys.maxsize
+            team_tpm_limit = user_api_key_dict.team_tpm_limit
+            team_rpm_limit = user_api_key_dict.team_rpm_limit
 
             if team_tpm_limit is None:
                 team_tpm_limit = sys.maxsize
diff --git a/litellm/proxy/hooks/tpm_rpm_limiter.py b/litellm/proxy/hooks/tpm_rpm_limiter.py
deleted file mode 100644
index 8951991d2..000000000
--- a/litellm/proxy/hooks/tpm_rpm_limiter.py
+++ /dev/null
@@ -1,379 +0,0 @@
-# What is this?
-## Checks TPM/RPM Limits for a key/user/team on the proxy
-## Works with Redis - if given
-
-from typing import Optional, Literal
-import litellm, traceback, sys
-from litellm.caching import DualCache, RedisCache
-from litellm.proxy._types import (
-    UserAPIKeyAuth,
-    LiteLLM_VerificationTokenView,
-    LiteLLM_UserTable,
-    LiteLLM_TeamTable,
-)
-from litellm.integrations.custom_logger import CustomLogger
-from fastapi import HTTPException
-from litellm._logging import verbose_proxy_logger
-from litellm import ModelResponse
-from datetime import datetime
-
-
-class _PROXY_MaxTPMRPMLimiter(CustomLogger):
-    user_api_key_cache = None
-
-    # Class variables or attributes
-    def __init__(self, internal_cache: Optional[DualCache]):
-        if internal_cache is None:
-            self.internal_cache = DualCache()
-        else:
-            self.internal_cache = internal_cache
-
-    def print_verbose(self, print_statement):
-        try:
-            verbose_proxy_logger.debug(print_statement)
-            if litellm.set_verbose:
-                print(print_statement)  # noqa
-        except:
-            pass
-
-    ## check if admin has set tpm/rpm limits for this key/user/team
-
-    def _check_limits_set(
-        self,
-        user_api_key_cache: DualCache,
-        key: Optional[str],
-        user_id: Optional[str],
-        team_id: Optional[str],
-    ) -> bool:
-        ## key
-        if key is not None:
-            key_val = user_api_key_cache.get_cache(key=key)
-            if isinstance(key_val, dict):
-                key_val = LiteLLM_VerificationTokenView(**key_val)
-
-            if isinstance(key_val, LiteLLM_VerificationTokenView):
-                user_api_key_tpm_limit = key_val.tpm_limit
-
-                user_api_key_rpm_limit = key_val.rpm_limit
-
-                if (
-                    user_api_key_tpm_limit is not None
-                    or user_api_key_rpm_limit is not None
-                ):
-                    return True
-
-        ## team
-        if team_id is not None:
-            team_val = user_api_key_cache.get_cache(key=team_id)
-            if isinstance(team_val, dict):
-                team_val = LiteLLM_TeamTable(**team_val)
-
-            if isinstance(team_val, LiteLLM_TeamTable):
-                team_tpm_limit = team_val.tpm_limit
-
-                team_rpm_limit = team_val.rpm_limit
-
-                if team_tpm_limit is not None or team_rpm_limit is not None:
-                    return True
-
-        ## user
-        if user_id is not None:
-            user_val = user_api_key_cache.get_cache(key=user_id)
-            if isinstance(user_val, dict):
-                user_val = LiteLLM_UserTable(**user_val)
-
-            if isinstance(user_val, LiteLLM_UserTable):
-                user_tpm_limit = user_val.tpm_limit
-
-                user_rpm_limit = user_val.rpm_limit
-
-                if user_tpm_limit is not None or user_rpm_limit is not None:
-                    return True
-        return False
-
-    async def check_key_in_limits(
-        self,
-        user_api_key_dict: UserAPIKeyAuth,
-        current_minute_dict: dict,
-        tpm_limit: int,
-        rpm_limit: int,
-        request_count_api_key: str,
-        type: Literal["key", "user", "team"],
-    ):
-
-        if type == "key" and user_api_key_dict.api_key is not None:
-            current = current_minute_dict["key"].get(user_api_key_dict.api_key, None)
-        elif type == "user" and user_api_key_dict.user_id is not None:
-            current = current_minute_dict["user"].get(user_api_key_dict.user_id, None)
-        elif type == "team" and user_api_key_dict.team_id is not None:
-            current = current_minute_dict["team"].get(user_api_key_dict.team_id, None)
-        else:
-            return
-        if current is None:
-            if tpm_limit == 0 or rpm_limit == 0:
-                # base case
-                raise HTTPException(
-                    status_code=429, detail="Max tpm/rpm limit reached."
-                )
-        elif current["current_tpm"] < tpm_limit and current["current_rpm"] < rpm_limit:
-            pass
-        else:
-            raise HTTPException(status_code=429, detail="Max tpm/rpm limit reached.")
-
-    async def async_pre_call_hook(
-        self,
-        user_api_key_dict: UserAPIKeyAuth,
-        cache: DualCache,
-        data: dict,
-        call_type: str,
-    ):
-        self.print_verbose(
-            f"Inside Max TPM/RPM Limiter Pre-Call Hook - {user_api_key_dict}"
-        )
-        api_key = user_api_key_dict.api_key
-        # check if REQUEST ALLOWED for user_id
-        user_id = user_api_key_dict.user_id
-        ## get team tpm/rpm limits
-        team_id = user_api_key_dict.team_id
-
-        self.user_api_key_cache = cache
-
-        _set_limits = self._check_limits_set(
-            user_api_key_cache=cache, key=api_key, user_id=user_id, team_id=team_id
-        )
-
-        self.print_verbose(f"_set_limits: {_set_limits}")
-
-        if _set_limits == False:
-            return
-
-        # ------------
-        # Setup values
-        # ------------
-
-        current_date = datetime.now().strftime("%Y-%m-%d")
-        current_hour = datetime.now().strftime("%H")
-        current_minute = datetime.now().strftime("%M")
-        precise_minute = f"{current_date}-{current_hour}-{current_minute}"
-        cache_key = "usage:{}".format(precise_minute)
-        current_minute_dict = await self.internal_cache.async_get_cache(
-            key=cache_key
-        )  # {"usage:{curr_minute}": {"key": {<api_key>: {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}}}}
-
-        if current_minute_dict is None:
-            current_minute_dict = {"key": {}, "user": {}, "team": {}}
-
-        if api_key is not None:
-            tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
-            if tpm_limit is None:
-                tpm_limit = sys.maxsize
-            rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
-            if rpm_limit is None:
-                rpm_limit = sys.maxsize
-            request_count_api_key = f"{api_key}::{precise_minute}::request_count"
-            await self.check_key_in_limits(
-                user_api_key_dict=user_api_key_dict,
-                current_minute_dict=current_minute_dict,
-                request_count_api_key=request_count_api_key,
-                tpm_limit=tpm_limit,
-                rpm_limit=rpm_limit,
-                type="key",
-            )
-
-        if user_id is not None:
-            _user_id_rate_limits = user_api_key_dict.user_id_rate_limits
-
-            # get user tpm/rpm limits
-            if _user_id_rate_limits is not None and isinstance(
-                _user_id_rate_limits, dict
-            ):
-                user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None)
-                user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None)
-                if user_tpm_limit is None:
-                    user_tpm_limit = sys.maxsize
-                if user_rpm_limit is None:
-                    user_rpm_limit = sys.maxsize
-
-                # now do the same tpm/rpm checks
-                request_count_api_key = f"{user_id}::{precise_minute}::request_count"
-                # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
-                await self.check_key_in_limits(
-                    user_api_key_dict=user_api_key_dict,
-                    current_minute_dict=current_minute_dict,
-                    request_count_api_key=request_count_api_key,
-                    tpm_limit=user_tpm_limit,
-                    rpm_limit=user_rpm_limit,
-                    type="user",
-                )
-
-        # TEAM RATE LIMITS
-        if team_id is not None:
-            team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize)
-
-            if team_tpm_limit is None:
-                team_tpm_limit = sys.maxsize
-            team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize)
-            if team_rpm_limit is None:
-                team_rpm_limit = sys.maxsize
-
-            if team_tpm_limit is None:
-                team_tpm_limit = sys.maxsize
-            if team_rpm_limit is None:
-                team_rpm_limit = sys.maxsize
-
-            # now do the same tpm/rpm checks
-            request_count_api_key = f"{team_id}::{precise_minute}::request_count"
-
-            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
-            await self.check_key_in_limits(
-                user_api_key_dict=user_api_key_dict,
-                current_minute_dict=current_minute_dict,
-                request_count_api_key=request_count_api_key,
-                tpm_limit=team_tpm_limit,
-                rpm_limit=team_rpm_limit,
-                type="team",
-            )
-
-        return
-
-    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
-        try:
-            self.print_verbose(f"INSIDE TPM RPM Limiter ASYNC SUCCESS LOGGING")
-
-            user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
-            user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
-                "user_api_key_user_id", None
-            )
-            user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
-                "user_api_key_team_id", None
-            )
-            _limits_set = self._check_limits_set(
-                user_api_key_cache=self.user_api_key_cache,
-                key=user_api_key,
-                user_id=user_api_key_user_id,
-                team_id=user_api_key_team_id,
-            )
-
-            if _limits_set == False:  # don't waste cache calls if no tpm/rpm limits set
-                return
-
-            # ------------
-            # Setup values
-            # ------------
-
-            current_date = datetime.now().strftime("%Y-%m-%d")
-            current_hour = datetime.now().strftime("%H")
-            current_minute = datetime.now().strftime("%M")
-            precise_minute = f"{current_date}-{current_hour}-{current_minute}"
-
-            total_tokens = 0
-
-            if isinstance(response_obj, ModelResponse):
-                total_tokens = response_obj.usage.total_tokens
-
-            """
-            - get value from redis
-            - increment requests + 1
-            - increment tpm + 1
-            - increment rpm + 1
-            - update value in-memory + redis
-            """
-            cache_key = "usage:{}".format(precise_minute)
-            if (
-                self.internal_cache.redis_cache is not None
-            ):  # get straight from redis if possible
-                current_minute_dict = (
-                    await self.internal_cache.redis_cache.async_get_cache(
-                        key=cache_key,
-                    )
-                )  # {"usage:{current_minute}": {"key": {}, "team": {}, "user": {}}}
-            else:
-                current_minute_dict = await self.internal_cache.async_get_cache(
-                    key=cache_key,
-                )
-
-            if current_minute_dict is None:
-                current_minute_dict = {"key": {}, "user": {}, "team": {}}
-
-            _cache_updated = False  # check if a cache update is required. prevent unnecessary rewrites.
-
-            # ------------
-            # Update usage - API Key
-            # ------------
-
-            if user_api_key is not None:
-                _cache_updated = True
-                ## API KEY ##
-                if user_api_key in current_minute_dict["key"]:
-                    current_key_usage = current_minute_dict["key"][user_api_key]
-                    new_val = {
-                        "current_tpm": current_key_usage["current_tpm"] + total_tokens,
-                        "current_rpm": current_key_usage["current_rpm"] + 1,
-                    }
-                else:
-                    new_val = {
-                        "current_tpm": total_tokens,
-                        "current_rpm": 1,
-                    }
-
-                current_minute_dict["key"][user_api_key] = new_val
-
-                self.print_verbose(
-                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
-                )
-
-            # ------------
-            # Update usage - User
-            # ------------
-            if user_api_key_user_id is not None:
-                _cache_updated = True
-                total_tokens = 0
-
-                if isinstance(response_obj, ModelResponse):
-                    total_tokens = response_obj.usage.total_tokens
-
-                if user_api_key_user_id in current_minute_dict["key"]:
-                    current_key_usage = current_minute_dict["key"][user_api_key_user_id]
-                    new_val = {
-                        "current_tpm": current_key_usage["current_tpm"] + total_tokens,
-                        "current_rpm": current_key_usage["current_rpm"] + 1,
-                    }
-                else:
-                    new_val = {
-                        "current_tpm": total_tokens,
-                        "current_rpm": 1,
-                    }
-
-                current_minute_dict["user"][user_api_key_user_id] = new_val
-
-            # ------------
-            # Update usage - Team
-            # ------------
-            if user_api_key_team_id is not None:
-                _cache_updated = True
-                total_tokens = 0
-
-                if isinstance(response_obj, ModelResponse):
-                    total_tokens = response_obj.usage.total_tokens
-
-                if user_api_key_team_id in current_minute_dict["key"]:
-                    current_key_usage = current_minute_dict["key"][user_api_key_team_id]
-                    new_val = {
-                        "current_tpm": current_key_usage["current_tpm"] + total_tokens,
-                        "current_rpm": current_key_usage["current_rpm"] + 1,
-                    }
-                else:
-                    new_val = {
-                        "current_tpm": total_tokens,
-                        "current_rpm": 1,
-                    }
-
-                current_minute_dict["team"][user_api_key_team_id] = new_val
-
-            if _cache_updated == True:
-                await self.internal_cache.async_set_cache(
-                    key=cache_key, value=current_minute_dict
-                )
-
-        except Exception as e:
-            self.print_verbose("{}\n{}".format(e, traceback.format_exc()))  # noqa
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 6f1a3e557..1bdb5edba 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -397,6 +397,7 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
 
 def get_custom_headers(
     *,
+    user_api_key_dict: UserAPIKeyAuth,
     model_id: Optional[str] = None,
     cache_key: Optional[str] = None,
     api_base: Optional[str] = None,
@@ -410,6 +411,8 @@ def get_custom_headers(
         "x-litellm-model-api-base": api_base,
         "x-litellm-version": version,
         "x-litellm-model-region": model_region,
+        "x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
+        "x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
     }
     try:
         return {
@@ -4059,6 +4062,7 @@ async def chat_completion(
         "stream" in data and data["stream"] == True
     ):  # use generate_responses to stream responses
         custom_headers = get_custom_headers(
+            user_api_key_dict=user_api_key_dict,
             model_id=model_id,
             cache_key=cache_key,
             api_base=api_base,
@@ -4078,6 +4082,7 @@
 
         fastapi_response.headers.update(
             get_custom_headers(
+                user_api_key_dict=user_api_key_dict,
                 model_id=model_id,
                 cache_key=cache_key,
                 api_base=api_base,
@@ -4298,6 +4303,7 @@ async def completion(
         "stream" in data and data["stream"] == True
     ):  # use generate_responses to stream responses
         custom_headers = get_custom_headers(
+            user_api_key_dict=user_api_key_dict,
            model_id=model_id,
            cache_key=cache_key,
            api_base=api_base,
@@ -4316,6 +4322,7 @@
         )
         fastapi_response.headers.update(
             get_custom_headers(
+                user_api_key_dict=user_api_key_dict,
                 model_id=model_id,
                 cache_key=cache_key,
                 api_base=api_base,
@@ -4565,6 +4572,7 @@ async def embeddings(
 
     fastapi_response.headers.update(
         get_custom_headers(
+            user_api_key_dict=user_api_key_dict,
            model_id=model_id,
            cache_key=cache_key,
            api_base=api_base,
@@ -4748,6 +4756,7 @@ async def image_generation(
 
    fastapi_response.headers.update(
        get_custom_headers(
+            user_api_key_dict=user_api_key_dict,
            model_id=model_id,
            cache_key=cache_key,
            api_base=api_base,
@@ -4949,6 +4958,7 @@ async def audio_transcriptions(
 
    fastapi_response.headers.update(
        get_custom_headers(
+            user_api_key_dict=user_api_key_dict,
            model_id=model_id,
            cache_key=cache_key,
            api_base=api_base,
@@ -5132,6 +5142,7 @@ async def moderations(
 
    fastapi_response.headers.update(
        get_custom_headers(
+            user_api_key_dict=user_api_key_dict,
            model_id=model_id,
            cache_key=cache_key,
            api_base=api_base,
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 2bca287e2..709ddbd3d 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -35,7 +35,6 @@ from litellm import (
 )
 from litellm.utils import ModelResponseIterator
 from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
-from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
 from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.proxy.db.base_client import CustomDB
@@ -81,9 +80,6 @@ class ProxyLogging:
         self.call_details["user_api_key_cache"] = user_api_key_cache
         self.internal_usage_cache = DualCache()
         self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler()
-        self.max_tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(
-            internal_cache=self.internal_usage_cache
-        )
         self.max_budget_limiter = _PROXY_MaxBudgetLimiter()
         self.cache_control_check = _PROXY_CacheControlCheck()
         self.alerting: Optional[List] = None
diff --git a/litellm/tests/test_max_tpm_rpm_limiter.py b/litellm/tests/test_max_tpm_rpm_limiter.py
index fbaf30c59..43489d5d9 100644
--- a/litellm/tests/test_max_tpm_rpm_limiter.py
+++ b/litellm/tests/test_max_tpm_rpm_limiter.py
@@ -1,162 +1,163 @@
+### REPLACED BY 'test_parallel_request_limiter.py' ###
 # What is this?
 ## Unit tests for the max tpm / rpm limiter hook for proxy
-import sys, os, asyncio, time, random
-from datetime import datetime
-import traceback
-from dotenv import load_dotenv
-from typing import Optional
+# import sys, os, asyncio, time, random
+# from datetime import datetime
+# import traceback
+# from dotenv import load_dotenv
+# from typing import Optional
 
-load_dotenv()
-import os
+# load_dotenv()
+# import os
 
-sys.path.insert(
-    0, os.path.abspath("../..")
-)  # Adds the parent directory to the system path
-import pytest
-import litellm
-from litellm import Router
-from litellm.proxy.utils import ProxyLogging, hash_token
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.caching import DualCache, RedisCache
-from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
-from datetime import datetime
+# sys.path.insert(
+#     0, os.path.abspath("../..")
+# )  # Adds the parent directory to the system path
+# import pytest
+# import litellm
+# from litellm import Router
+# from litellm.proxy.utils import ProxyLogging, hash_token
+# from litellm.proxy._types import UserAPIKeyAuth
+# from litellm.caching import DualCache, RedisCache
+# from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
+# from datetime import datetime
 
 
-@pytest.mark.asyncio
-async def test_pre_call_hook_rpm_limits():
-    """
-    Test if error raised on hitting rpm limits
-    """
-    litellm.set_verbose = True
-    _api_key = hash_token("sk-12345")
-    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, tpm_limit=9, rpm_limit=1)
-    local_cache = DualCache()
-    # redis_usage_cache = RedisCache()
+# @pytest.mark.asyncio
+# async def test_pre_call_hook_rpm_limits():
+#     """
+#     Test if error raised on hitting rpm limits
+#     """
+#     litellm.set_verbose = True
+#     _api_key = hash_token("sk-12345")
+#     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, tpm_limit=9, rpm_limit=1)
+#     local_cache = DualCache()
+#     # redis_usage_cache = RedisCache()
 
-    local_cache.set_cache(
-        key=_api_key, value={"api_key": _api_key, "tpm_limit": 9, "rpm_limit": 1}
-    )
+#     local_cache.set_cache(
+#         key=_api_key, value={"api_key": _api_key, "tpm_limit": 9, "rpm_limit": 1}
+#     )
 
-    tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=DualCache())
+#     tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=DualCache())
 
-    await tpm_rpm_limiter.async_pre_call_hook(
-        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
-    )
+#     await tpm_rpm_limiter.async_pre_call_hook(
+#         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+#     )
 
-    kwargs = {"litellm_params": {"metadata": {"user_api_key": _api_key}}}
+#     kwargs = {"litellm_params": {"metadata": {"user_api_key": _api_key}}}
 
-    await tpm_rpm_limiter.async_log_success_event(
-        kwargs=kwargs,
-        response_obj="",
-        start_time="",
-        end_time="",
-    )
+#     await tpm_rpm_limiter.async_log_success_event(
+#         kwargs=kwargs,
+#         response_obj="",
+#         start_time="",
+#         end_time="",
+#     )
 
-    ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
+#     ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
 
-    try:
-        await tpm_rpm_limiter.async_pre_call_hook(
-            user_api_key_dict=user_api_key_dict,
-            cache=local_cache,
-            data={},
-            call_type="",
-        )
+#     try:
+#         await tpm_rpm_limiter.async_pre_call_hook(
+#             user_api_key_dict=user_api_key_dict,
+#             cache=local_cache,
+#             data={},
+#             call_type="",
+#         )
 
-        pytest.fail(f"Expected call to fail")
-    except Exception as e:
-        assert e.status_code == 429
+#         pytest.fail(f"Expected call to fail")
+#     except Exception as e:
+#         assert e.status_code == 429
 
 
-@pytest.mark.asyncio
-async def test_pre_call_hook_team_rpm_limits(
-    _redis_usage_cache: Optional[RedisCache] = None,
-):
-    """
-    Test if error raised on hitting team rpm limits
-    """
-    litellm.set_verbose = True
-    _api_key = "sk-12345"
-    _team_id = "unique-team-id"
-    _user_api_key_dict = {
-        "api_key": _api_key,
-        "max_parallel_requests": 1,
-        "tpm_limit": 9,
-        "rpm_limit": 10,
-        "team_rpm_limit": 1,
-        "team_id": _team_id,
-    }
-    user_api_key_dict = UserAPIKeyAuth(**_user_api_key_dict)  # type: ignore
-    _api_key = hash_token(_api_key)
-    local_cache = DualCache()
-    local_cache.set_cache(key=_api_key, value=_user_api_key_dict)
-    internal_cache = DualCache(redis_cache=_redis_usage_cache)
-    tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=internal_cache)
-    await tpm_rpm_limiter.async_pre_call_hook(
-        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
-    )
+# @pytest.mark.asyncio
+# async def test_pre_call_hook_team_rpm_limits(
+#     _redis_usage_cache: Optional[RedisCache] = None,
+# ):
+#     """
+#     Test if error raised on hitting team rpm limits
+#     """
+#     litellm.set_verbose = True
+#     _api_key = "sk-12345"
+#     _team_id = "unique-team-id"
+#     _user_api_key_dict = {
+#         "api_key": _api_key,
+#         "max_parallel_requests": 1,
+#         "tpm_limit": 9,
+#         "rpm_limit": 10,
+#         "team_rpm_limit": 1,
+#         "team_id": _team_id,
+#     }
+#     user_api_key_dict = UserAPIKeyAuth(**_user_api_key_dict)  # type: ignore
+#     _api_key = hash_token(_api_key)
+#     local_cache = DualCache()
+#     local_cache.set_cache(key=_api_key, value=_user_api_key_dict)
+#     internal_cache = DualCache(redis_cache=_redis_usage_cache)
+#     tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=internal_cache)
+#     await tpm_rpm_limiter.async_pre_call_hook(
+#         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+#     )
 
-    kwargs = {
-        "litellm_params": {
-            "metadata": {"user_api_key": _api_key, "user_api_key_team_id": _team_id}
-        }
-    }
+#     kwargs = {
+#         "litellm_params": {
+#             "metadata": {"user_api_key": _api_key, "user_api_key_team_id": _team_id}
+#         }
+#     }
 
-    await tpm_rpm_limiter.async_log_success_event(
-        kwargs=kwargs,
-        response_obj="",
-        start_time="",
-        end_time="",
-    )
+#     await tpm_rpm_limiter.async_log_success_event(
+#         kwargs=kwargs,
+#         response_obj="",
+#         start_time="",
+#         end_time="",
+#     )
 
-    print(f"local_cache: {local_cache}")
+#     print(f"local_cache: {local_cache}")
 
-    ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
+#     ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
 
-    try:
-        await tpm_rpm_limiter.async_pre_call_hook(
-            user_api_key_dict=user_api_key_dict,
-            cache=local_cache,
-            data={},
-            call_type="",
-        )
+#     try:
+#         await tpm_rpm_limiter.async_pre_call_hook(
+#             user_api_key_dict=user_api_key_dict,
+#             cache=local_cache,
+#             data={},
+#             call_type="",
+#         )
 
-        pytest.fail(f"Expected call to fail")
-    except Exception as e:
-        assert e.status_code == 429  # type: ignore
+#         pytest.fail(f"Expected call to fail")
+#     except Exception as e:
+#         assert e.status_code == 429  # type: ignore
 
 
-@pytest.mark.asyncio
-async def test_namespace():
-    """
-    - test if default namespace set via `proxyconfig._init_cache`
-    - respected for tpm/rpm caching
-    """
-    from litellm.proxy.proxy_server import ProxyConfig
+# @pytest.mark.asyncio
+# async def test_namespace():
+#     """
+#     - test if default namespace set via `proxyconfig._init_cache`
+#     - respected for tpm/rpm caching
+#     """
+#     from litellm.proxy.proxy_server import ProxyConfig
 
-    redis_usage_cache: Optional[RedisCache] = None
-    cache_params = {"type": "redis", "namespace": "litellm_default"}
+#     redis_usage_cache: Optional[RedisCache] = None
+#     cache_params = {"type": "redis", "namespace": "litellm_default"}
 
-    ## INIT CACHE ##
-    proxy_config = ProxyConfig()
-    setattr(litellm.proxy.proxy_server, "proxy_config", proxy_config)
+#     ## INIT CACHE ##
+#     proxy_config = ProxyConfig()
+#     setattr(litellm.proxy.proxy_server, "proxy_config", proxy_config)
 
-    proxy_config._init_cache(cache_params=cache_params)
+#     proxy_config._init_cache(cache_params=cache_params)
 
-    redis_cache: Optional[RedisCache] = getattr(
-        litellm.proxy.proxy_server, "redis_usage_cache"
-    )
+#     redis_cache: Optional[RedisCache] = getattr(
+#         litellm.proxy.proxy_server, "redis_usage_cache"
+#     )
 
-    ## CHECK IF NAMESPACE SET ##
-    assert redis_cache.namespace == "litellm_default"
+#     ## CHECK IF NAMESPACE SET ##
+#     assert redis_cache.namespace == "litellm_default"
 
-    ## CHECK IF TPM/RPM RATE LIMITING WORKS ##
-    await test_pre_call_hook_team_rpm_limits(_redis_usage_cache=redis_cache)
-    current_date = datetime.now().strftime("%Y-%m-%d")
-    current_hour = datetime.now().strftime("%H")
-    current_minute = datetime.now().strftime("%M")
-    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+#     ## CHECK IF TPM/RPM RATE LIMITING WORKS ##
+#     await test_pre_call_hook_team_rpm_limits(_redis_usage_cache=redis_cache)
+#     current_date = datetime.now().strftime("%Y-%m-%d")
+#     current_hour = datetime.now().strftime("%H")
+#     current_minute = datetime.now().strftime("%M")
+#     precise_minute = f"{current_date}-{current_hour}-{current_minute}"
 
-    cache_key = "litellm_default:usage:{}".format(precise_minute)
-    value = await redis_cache.async_get_cache(key=cache_key)
-    assert value is not None
+#     cache_key = "litellm_default:usage:{}".format(precise_minute)
+#     value = await redis_cache.async_get_cache(key=cache_key)
+#     assert value is not None
diff --git a/litellm/tests/test_parallel_request_limiter.py b/litellm/tests/test_parallel_request_limiter.py
index 00da199d9..94652c2a6 100644
--- a/litellm/tests/test_parallel_request_limiter.py
+++ b/litellm/tests/test_parallel_request_limiter.py
@@ -229,17 +229,21 @@ async def test_pre_call_hook_user_tpm_limits():
     """
     Test if error raised on hitting tpm limits
     """
+    local_cache = DualCache()
     # create user with tpm/rpm limits
+    user_id = "test-user"
+    user_obj = {"tpm_limit": 9, "rpm_limit": 10}
+
+    local_cache.set_cache(key=user_id, value=user_obj)
     _api_key = "sk-12345"
     user_api_key_dict = UserAPIKeyAuth(
         api_key=_api_key,
-        user_id="ishaan",
-        user_id_rate_limits={"tpm_limit": 9, "rpm_limit": 10},
+        user_id=user_id,
     )
     res = dict(user_api_key_dict)
     print("dict user", res)
 
-    local_cache = DualCache()
+
     parallel_request_handler = MaxParallelRequestsHandler()
 
     await parallel_request_handler.async_pre_call_hook(
@@ -248,7 +252,7 @@
 
     kwargs = {
         "litellm_params": {
-            "metadata": {"user_api_key_user_id": "ishaan", "user_api_key": "gm"}
+            "metadata": {"user_api_key_user_id": user_id, "user_api_key": "gm"}
         }
     }
 
@@ -734,7 +738,7 @@ async def test_bad_router_call():
     request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
 
     assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.user_api_key_cache.get_cache(  # type: ignore
            key=request_count_api_key
        )["current_requests"]
        == 1
@@ -751,7 +755,7 @@
     except:
         pass
    assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.user_api_key_cache.get_cache(  # type: ignore
            key=request_count_api_key
        )["current_requests"]
        == 0
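
Taken together, the patch consolidates TPM/RPM enforcement into the parallel request limiter: instead of reading a `user_id_rate_limits` field stashed on the token, the hook now looks the user up in the `user_api_key_cache` by `user_id`. The consolidated flow can be exercised directly, mirroring the updated `test_pre_call_hook_user_tpm_limits` above. A minimal sketch, assuming `_PROXY_MaxParallelRequestsHandler` is imported straight from the hook module (the tests above use it under the alias `MaxParallelRequestsHandler`):

```python
import asyncio

from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.parallel_request_limiter import (
    _PROXY_MaxParallelRequestsHandler,
)


async def main():
    local_cache = DualCache()
    # Per-user limits now live in the user_api_key_cache, keyed by user_id,
    # instead of on the (removed) user_id_rate_limits token field.
    local_cache.set_cache(key="test-user", value={"tpm_limit": 9, "rpm_limit": 10})

    user_api_key_dict = UserAPIKeyAuth(api_key="sk-12345", user_id="test-user")
    handler = _PROXY_MaxParallelRequestsHandler()

    # The pre-call hook fetches the user's limits via
    # async_get_cache(key=user_id) and raises HTTP 429 once the
    # current minute's tpm/rpm usage exceeds them.
    await handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict,
        cache=local_cache,
        data={},
        call_type="",
    )


asyncio.run(main())
```

On the response path, `get_custom_headers(user_api_key_dict=...)` now also stamps the key's configured limits onto chat completion, completion, embedding, image, audio, and moderation responses as `x-litellm-key-tpm-limit` / `x-litellm-key-rpm-limit`, so clients can inspect their limits without a separate management call.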