From bd956265796981ffeab57e549348d2d4a252c38b Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 30 Mar 2024 19:07:04 -0700 Subject: [PATCH 01/20] (fix) improve async perf --- litellm/utils.py | 41 ++++++++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/litellm/utils.py b/litellm/utils.py index 3ec882f0f..2a765bc41 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1434,9 +1434,7 @@ class Logging: model = self.model kwargs = self.model_call_details - input = kwargs.get( - "messages", kwargs.get("input", None) - ) + input = kwargs.get("messages", kwargs.get("input", None)) type = ( "embed" @@ -1444,7 +1442,7 @@ class Logging: else "llm" ) - # this only logs streaming once, complete_streaming_response exists i.e when stream ends + # this only logs streaming once, complete_streaming_response exists i.e when stream ends if self.stream: if "complete_streaming_response" not in kwargs: break @@ -1458,7 +1456,7 @@ class Logging: model=model, input=input, user_id=kwargs.get("user", None), - #user_props=self.model_call_details.get("user_props", None), + # user_props=self.model_call_details.get("user_props", None), extra=kwargs.get("optional_params", {}), response_obj=result, start_time=start_time, @@ -2064,8 +2062,6 @@ class Logging: else "llm" ) - - lunaryLogger.log_event( type=_type, event="error", @@ -2509,6 +2505,33 @@ def client(original_function): @wraps(original_function) def wrapper(*args, **kwargs): + # DO NOT MOVE THIS. It always needs to run first + # Check if this is an async function. If so only execute the async function + if ( + kwargs.get("acompletion", False) == True + or kwargs.get("aembedding", False) == True + or kwargs.get("aimg_generation", False) == True + or kwargs.get("amoderation", False) == True + or kwargs.get("atext_completion", False) == True + or kwargs.get("atranscription", False) == True + ): + # MODEL CALL + result = original_function(*args, **kwargs) + if "stream" in kwargs and kwargs["stream"] == True: + if ( + "complete_response" in kwargs + and kwargs["complete_response"] == True + ): + chunks = [] + for idx, chunk in enumerate(result): + chunks.append(chunk) + return litellm.stream_chunk_builder( + chunks, messages=kwargs.get("messages", None) + ) + else: + return result + return result + # Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print print_args_passed_to_litellm(original_function, args, kwargs) start_time = datetime.datetime.now() @@ -6178,9 +6201,9 @@ def validate_environment(model: Optional[str] = None) -> dict: def set_callbacks(callback_list, function_id=None): - + global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger - + try: for callback in callback_list: print_verbose(f"callback: {callback}") From c365de122a2af998c8d11c9473bf163bb3ddd737 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 30 Mar 2024 19:33:40 -0700 Subject: [PATCH 02/20] check num retries in async wrapper --- litellm/utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/litellm/utils.py b/litellm/utils.py index 2a765bc41..600d80599 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2515,6 +2515,16 @@ def client(original_function): or kwargs.get("atext_completion", False) == True or kwargs.get("atranscription", False) == True ): + # [OPTIONAL] CHECK MAX RETRIES / REQUEST + if litellm.num_retries_per_request is not None: + # check if previous_models passed in as ['litellm_params']['metadata]['previous_models'] + previous_models = kwargs.get("metadata", {}).get( + "previous_models", None + ) + if previous_models is not None: + if litellm.num_retries_per_request <= len(previous_models): + raise Exception(f"Max retries per request hit!") + # MODEL CALL result = original_function(*args, **kwargs) if "stream" in kwargs and kwargs["stream"] == True: From e62e83c42a8e78208418e9a1ba4484d644bd3f5e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 30 Mar 2024 11:30:18 -0700 Subject: [PATCH 03/20] (ui) dont let prediction block spend view --- .../src/components/view_key_table.tsx | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/ui/litellm-dashboard/src/components/view_key_table.tsx b/ui/litellm-dashboard/src/components/view_key_table.tsx index 4269bfc02..bf3affe14 100644 --- a/ui/litellm-dashboard/src/components/view_key_table.tsx +++ b/ui/litellm-dashboard/src/components/view_key_table.tsx @@ -84,15 +84,20 @@ const ViewKeyTable: React.FC = ({ setSpendData(response); // predict spend based on response - const predictedSpend = await PredictedSpendLogsCall(accessToken, response); - console.log("Response2:", predictedSpend); + try { + const predictedSpend = await PredictedSpendLogsCall(accessToken, response); + console.log("Response2:", predictedSpend); - // append predictedSpend to data - const combinedData = [...response, ...predictedSpend.response]; - setSpendData(combinedData); - setPredictedSpendString(predictedSpend.predicted_spend) + // append predictedSpend to data + const combinedData = [...response, ...predictedSpend.response]; + setSpendData(combinedData); + setPredictedSpendString(predictedSpend.predicted_spend) - console.log("Combined Data:", combinedData); + console.log("Combined Data:", combinedData); + } catch (error) { + console.error("There was an error fetching the predicted data", error); + } + // setPredictedSpend(predictedSpend); } catch (error) { From 15b5cb161256058cb606fcb311aaa4c74460f131 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 30 Mar 2024 11:34:20 -0700 Subject: [PATCH 04/20] (fix) check size of data to predict --- enterprise/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/enterprise/utils.py b/enterprise/utils.py index d762cd56c..c575a74f6 100644 --- a/enterprise/utils.py +++ b/enterprise/utils.py @@ -251,6 +251,11 @@ def _forecast_daily_cost(data: list): import requests from datetime import datetime, timedelta + if len(data) == 0: + return { + "response": [], + "predicted_spend": "Current Spend = $0, Predicted = $0", + } first_entry = data[0] last_entry = data[-1] From ea76c546ff63f6fad7715ec802ed47b30a611e48 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 30 Mar 2024 20:08:17 -0700 Subject: [PATCH 05/20] docs(deploy.md): fix docs for litlelm-database docker run example --- docs/my-website/docs/proxy/deploy.md | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/my-website/docs/proxy/deploy.md b/docs/my-website/docs/proxy/deploy.md index e78c128bb..00606332e 100644 --- a/docs/my-website/docs/proxy/deploy.md +++ b/docs/my-website/docs/proxy/deploy.md @@ -250,10 +250,13 @@ docker pull docker pull ghcr.io/berriai/litellm-database:main-latest ``` ```shell -docker run --name litellm-proxy \ --e DATABASE_URL=postgresql://:@:/ \ --p 4000:4000 \ -ghcr.io/berriai/litellm-database:main-latest +docker run \ + -v $(pwd)/litellm_config.yaml:/app/config.yaml \ + -e AZURE_API_KEY=d6*********** \ + -e AZURE_API_BASE=https://openai-***********/ \ + -p 4000:4000 \ + ghcr.io/berriai/litellm-database:main-latest \ + --config /app/config.yaml --detailed_debug ``` Your OpenAI proxy server is now running on `http://0.0.0.0:4000`. From 0c77f75ce93060e9dae16ca63c210086d16990cf Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 30 Mar 2024 20:01:36 -0700 Subject: [PATCH 06/20] fix(tpm_rpm_limiter.py): enable redis caching for tpm/rpm checks on keys/user/teams allows tpm/rpm checks to work across instances https://github.com/BerriAI/litellm/issues/2730 --- litellm/caching.py | 9 +- litellm/proxy/_new_secret_config.yaml | 5 +- litellm/proxy/hooks/tpm_rpm_limiter.py | 380 +++++++++++++++++++++++++ litellm/proxy/proxy_server.py | 23 +- litellm/proxy/utils.py | 17 +- 5 files changed, 423 insertions(+), 11 deletions(-) create mode 100644 litellm/proxy/hooks/tpm_rpm_limiter.py diff --git a/litellm/caching.py b/litellm/caching.py index 921ae1b21..eea687edf 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -796,6 +796,8 @@ class DualCache(BaseCache): self, in_memory_cache: Optional[InMemoryCache] = None, redis_cache: Optional[RedisCache] = None, + default_in_memory_ttl: Optional[float] = None, + default_redis_ttl: Optional[float] = None, ) -> None: super().__init__() # If in_memory_cache is not provided, use the default InMemoryCache @@ -803,11 +805,17 @@ class DualCache(BaseCache): # If redis_cache is not provided, use the default RedisCache self.redis_cache = redis_cache + self.default_in_memory_ttl = default_in_memory_ttl + self.default_redis_ttl = default_redis_ttl + def set_cache(self, key, value, local_only: bool = False, **kwargs): # Update both Redis and in-memory cache try: print_verbose(f"set cache: key: {key}; value: {value}") if self.in_memory_cache is not None: + if "ttl" not in kwargs and self.default_in_memory_ttl is not None: + kwargs["ttl"] = self.default_in_memory_ttl + self.in_memory_cache.set_cache(key, value, **kwargs) if self.redis_cache is not None and local_only == False: @@ -823,7 +831,6 @@ class DualCache(BaseCache): if self.in_memory_cache is not None: in_memory_result = self.in_memory_cache.get_cache(key, **kwargs) - print_verbose(f"in_memory_result: {in_memory_result}") if in_memory_result is not None: result = in_memory_result diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index d218fddb1..dde07bb87 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -6,8 +6,9 @@ model_list: api_base: https://exampleopenaiendpoint-production.up.railway.app/ litellm_settings: - max_budget: 600020 - budget_duration: 30d + cache: true +# max_budget: 600020 +# budget_duration: 30d general_settings: master_key: sk-1234 diff --git a/litellm/proxy/hooks/tpm_rpm_limiter.py b/litellm/proxy/hooks/tpm_rpm_limiter.py new file mode 100644 index 000000000..db1d1759a --- /dev/null +++ b/litellm/proxy/hooks/tpm_rpm_limiter.py @@ -0,0 +1,380 @@ +# What is this? +## Checks TPM/RPM Limits for a key/user/team on the proxy +## Works with Redis - if given + +from typing import Optional, Literal +import litellm, traceback, sys +from litellm.caching import DualCache, RedisCache +from litellm.proxy._types import ( + UserAPIKeyAuth, + LiteLLM_VerificationTokenView, + LiteLLM_UserTable, + LiteLLM_TeamTable, +) +from litellm.integrations.custom_logger import CustomLogger +from fastapi import HTTPException +from litellm._logging import verbose_proxy_logger +from litellm import ModelResponse +from datetime import datetime + + +class _PROXY_MaxTPMRPMLimiter(CustomLogger): + user_api_key_cache = None + + # Class variables or attributes + def __init__(self, redis_usage_cache: Optional[RedisCache]): + self.redis_usage_cache = redis_usage_cache + self.internal_cache = DualCache( + redis_cache=redis_usage_cache, + default_in_memory_ttl=10, + default_redis_ttl=60, + ) + + def print_verbose(self, print_statement): + try: + verbose_proxy_logger.debug(print_statement) + if litellm.set_verbose: + print(print_statement) # noqa + except: + pass + + ## check if admin has set tpm/rpm limits for this key/user/team + + def _check_limits_set( + self, + user_api_key_cache: DualCache, + key: Optional[str], + user_id: Optional[str], + team_id: Optional[str], + ) -> bool: + ## key + if key is not None: + key_val = user_api_key_cache.get_cache(key=key) + if isinstance(key_val, dict): + key_val = LiteLLM_VerificationTokenView(**key_val) + + if isinstance(key_val, LiteLLM_VerificationTokenView): + user_api_key_tpm_limit = key_val.tpm_limit + + user_api_key_rpm_limit = key_val.rpm_limit + + if ( + user_api_key_tpm_limit is not None + or user_api_key_rpm_limit is not None + ): + return True + + ## team + if team_id is not None: + team_val = user_api_key_cache.get_cache(key=team_id) + if isinstance(team_val, dict): + team_val = LiteLLM_TeamTable(**team_val) + + if isinstance(team_val, LiteLLM_TeamTable): + team_tpm_limit = team_val.tpm_limit + + team_rpm_limit = team_val.rpm_limit + + if team_tpm_limit is not None or team_rpm_limit is not None: + return True + + ## user + if user_id is not None: + user_val = user_api_key_cache.get_cache(key=user_id) + if isinstance(user_val, dict): + user_val = LiteLLM_UserTable(**user_val) + + if isinstance(user_val, LiteLLM_UserTable): + user_tpm_limit = user_val.tpm_limit + + user_rpm_limit = user_val.rpm_limit + + if user_tpm_limit is not None or user_rpm_limit is not None: + return True + return False + + async def check_key_in_limits( + self, + user_api_key_dict: UserAPIKeyAuth, + current_minute_dict: dict, + tpm_limit: int, + rpm_limit: int, + request_count_api_key: str, + type: Literal["key", "user", "team"], + ): + if type == "key" and user_api_key_dict.api_key is not None: + current = current_minute_dict["key"].get(user_api_key_dict.api_key, None) + elif type == "user" and user_api_key_dict.user_id is not None: + current = current_minute_dict["user"].get(user_api_key_dict.user_id, None) + elif type == "team" and user_api_key_dict.team_id is not None: + current = current_minute_dict["team"].get(user_api_key_dict.team_id, None) + else: + return + + if current is None: + if tpm_limit == 0 or rpm_limit == 0: + # base case + raise HTTPException( + status_code=429, detail="Max tpm/rpm limit reached." + ) + elif current["current_tpm"] < tpm_limit and current["current_rpm"] < rpm_limit: + pass + else: + raise HTTPException(status_code=429, detail="Max tpm/rpm limit reached.") + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: str, + ): + self.print_verbose( + f"Inside Max TPM/RPM Limiter Pre-Call Hook - {user_api_key_dict}" + ) + api_key = user_api_key_dict.api_key + # check if REQUEST ALLOWED for user_id + user_id = user_api_key_dict.user_id + ## get team tpm/rpm limits + team_id = user_api_key_dict.team_id + + _set_limits = self._check_limits_set( + user_api_key_cache=cache, key=api_key, user_id=user_id, team_id=team_id + ) + + if _set_limits == False: + return + + # ------------ + # Setup values + # ------------ + + self.user_api_key_cache = cache + + current_date = datetime.now().strftime("%Y-%m-%d") + current_hour = datetime.now().strftime("%H") + current_minute = datetime.now().strftime("%M") + precise_minute = f"{current_date}-{current_hour}-{current_minute}" + cache_key = "usage:{}".format(precise_minute) + current_minute_dict = await self.internal_cache.async_get_cache( + key=cache_key + ) # {"usage:{curr_minute}": {"key": {: {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}}}} + + if current_minute_dict is None: + current_minute_dict = {"key": {}, "user": {}, "team": {}} + + if api_key is not None: + tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize) + if tpm_limit is None: + tpm_limit = sys.maxsize + rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize) + if rpm_limit is None: + rpm_limit = sys.maxsize + request_count_api_key = f"{api_key}::{precise_minute}::request_count" + await self.check_key_in_limits( + user_api_key_dict=user_api_key_dict, + current_minute_dict=current_minute_dict, + request_count_api_key=request_count_api_key, + tpm_limit=tpm_limit, + rpm_limit=rpm_limit, + type="key", + ) + + if user_id is not None: + _user_id_rate_limits = user_api_key_dict.user_id_rate_limits + + # get user tpm/rpm limits + if _user_id_rate_limits is not None and isinstance( + _user_id_rate_limits, dict + ): + user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None) + user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None) + if user_tpm_limit is None: + user_tpm_limit = sys.maxsize + if user_rpm_limit is None: + user_rpm_limit = sys.maxsize + + # now do the same tpm/rpm checks + request_count_api_key = f"{user_id}::{precise_minute}::request_count" + # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") + await self.check_key_in_limits( + user_api_key_dict=user_api_key_dict, + current_minute_dict=current_minute_dict, + request_count_api_key=request_count_api_key, + tpm_limit=user_tpm_limit, + rpm_limit=user_rpm_limit, + type="user", + ) + + # TEAM RATE LIMITS + if team_id is not None: + team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize) + + if team_tpm_limit is None: + team_tpm_limit = sys.maxsize + team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize) + if team_rpm_limit is None: + team_rpm_limit = sys.maxsize + + if team_tpm_limit is None: + team_tpm_limit = sys.maxsize + if team_rpm_limit is None: + team_rpm_limit = sys.maxsize + + # now do the same tpm/rpm checks + request_count_api_key = f"{team_id}::{precise_minute}::request_count" + + # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") + await self.check_key_in_limits( + user_api_key_dict=user_api_key_dict, + current_minute_dict=current_minute_dict, + request_count_api_key=request_count_api_key, + tpm_limit=team_tpm_limit, + rpm_limit=team_rpm_limit, + type="team", + ) + + return + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + self.print_verbose(f"INSIDE TPM RPM Limiter ASYNC SUCCESS LOGGING") + + user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"] + user_api_key_user_id = kwargs["litellm_params"]["metadata"].get( + "user_api_key_user_id", None + ) + user_api_key_team_id = kwargs["litellm_params"]["metadata"].get( + "user_api_key_team_id", None + ) + + _limits_set = self._check_limits_set( + user_api_key_cache=self.user_api_key_cache, + key=user_api_key, + user_id=user_api_key_user_id, + team_id=user_api_key_team_id, + ) + + if _limits_set == False: # don't waste cache calls if no tpm/rpm limits set + return + + # ------------ + # Setup values + # ------------ + + current_date = datetime.now().strftime("%Y-%m-%d") + current_hour = datetime.now().strftime("%H") + current_minute = datetime.now().strftime("%M") + precise_minute = f"{current_date}-{current_hour}-{current_minute}" + + total_tokens = 0 + + if isinstance(response_obj, ModelResponse): + total_tokens = response_obj.usage.total_tokens + + """ + - get value from redis + - increment requests + 1 + - increment tpm + 1 + - increment rpm + 1 + - update value in-memory + redis + """ + cache_key = "usage:{}".format(precise_minute) + if ( + self.internal_cache.redis_cache is not None + ): # get straight from redis if possible + current_minute_dict = ( + await self.internal_cache.redis_cache.async_get_cache( + key=cache_key, + ) + ) # {"usage:{current_minute}": {"key": {}, "team": {}, "user": {}}} + else: + current_minute_dict = await self.internal_cache.async_get_cache( + key=cache_key, + ) + + if current_minute_dict is None: + current_minute_dict = {"key": {}, "user": {}, "team": {}} + + _cache_updated = False # check if a cache update is required. prevent unnecessary rewrites. + + # ------------ + # Update usage - API Key + # ------------ + + if user_api_key is not None: + _cache_updated = True + ## API KEY ## + if user_api_key in current_minute_dict["key"]: + current_key_usage = current_minute_dict["key"][user_api_key] + new_val = { + "current_tpm": current_key_usage["current_tpm"] + total_tokens, + "current_rpm": current_key_usage["current_rpm"] + 1, + } + else: + new_val = { + "current_tpm": total_tokens, + "current_rpm": 1, + } + + current_minute_dict["key"][user_api_key] = new_val + + self.print_verbose( + f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" + ) + + # ------------ + # Update usage - User + # ------------ + if user_api_key_user_id is not None: + _cache_updated = True + total_tokens = 0 + + if isinstance(response_obj, ModelResponse): + total_tokens = response_obj.usage.total_tokens + + if user_api_key_user_id in current_minute_dict["key"]: + current_key_usage = current_minute_dict["key"][user_api_key_user_id] + new_val = { + "current_tpm": current_key_usage["current_tpm"] + total_tokens, + "current_rpm": current_key_usage["current_rpm"] + 1, + } + else: + new_val = { + "current_tpm": total_tokens, + "current_rpm": 1, + } + + current_minute_dict["user"][user_api_key_user_id] = new_val + + # ------------ + # Update usage - Team + # ------------ + if user_api_key_team_id is not None: + _cache_updated = True + total_tokens = 0 + + if isinstance(response_obj, ModelResponse): + total_tokens = response_obj.usage.total_tokens + + if user_api_key_team_id in current_minute_dict["key"]: + current_key_usage = current_minute_dict["key"][user_api_key_team_id] + new_val = { + "current_tpm": current_key_usage["current_tpm"] + total_tokens, + "current_rpm": current_key_usage["current_rpm"] + 1, + } + else: + new_val = { + "current_tpm": total_tokens, + "current_rpm": 1, + } + + current_minute_dict["team"][user_api_key_team_id] = new_val + + if _cache_updated == True: + await self.internal_cache.async_set_cache( + key=cache_key, value=current_minute_dict + ) + + except Exception as e: + self.print_verbose(e) # noqa diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index ea6adb99a..f819791b7 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -102,7 +102,7 @@ from litellm.proxy.secret_managers.google_kms import load_google_kms from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager import pydantic from litellm.proxy._types import * -from litellm.caching import DualCache +from litellm.caching import DualCache, RedisCache from litellm.proxy.health_check import perform_health_check from litellm._logging import verbose_router_logger, verbose_proxy_logger from litellm.proxy.auth.handle_jwt import JWTHandler @@ -281,6 +281,9 @@ otel_logging = False prisma_client: Optional[PrismaClient] = None custom_db_client: Optional[DBClient] = None user_api_key_cache = DualCache() +redis_usage_cache: Optional[RedisCache] = ( + None # redis cache used for tracking spend, tpm/rpm limits +) user_custom_auth = None user_custom_key_generate = None use_background_health_checks = None @@ -299,7 +302,9 @@ disable_spend_logs = False jwt_handler = JWTHandler() prompt_injection_detection_obj: Optional[_OPTIONAL_PromptInjectionDetection] = None ### INITIALIZE GLOBAL LOGGING OBJECT ### -proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) +proxy_logging_obj = ProxyLogging( + user_api_key_cache=user_api_key_cache, redis_usage_cache=redis_usage_cache +) ### REDIS QUEUE ### async_result = None celery_app_conn = None @@ -909,6 +914,10 @@ async def user_api_key_auth( models=valid_token.team_models, ) + user_api_key_cache.set_cache( + key=valid_token.team_id, value=_team_obj + ) # save team table in cache - used for tpm/rpm limiting - tpm_rpm_limiter.py + _end_user_object = None if "user" in request_data: _id = "end_user_id:{}".format(request_data["user"]) @@ -1905,7 +1914,7 @@ class ProxyConfig: """ Load config values into proxy global state """ - global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj + global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache # Load existing config config = await self.get_config(config_file_path=config_file_path) @@ -1967,6 +1976,7 @@ class ProxyConfig: "password": cache_password, } ) + # Assuming cache_type, cache_host, cache_port, and cache_password are strings print( # noqa f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}" @@ -1991,7 +2001,14 @@ class ProxyConfig: cache_params[key] = litellm.get_secret(value) ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = `, _redis.py checks for REDIS specific environment variables + litellm.cache = Cache(**cache_params) + + if litellm.cache is not None and isinstance( + litellm.cache.cache, RedisCache + ): + ## INIT PROXY REDIS USAGE CLIENT ## + redis_usage_cache = litellm.cache.cache print( # noqa f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}" ) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index f70d67aac..7d3e30f86 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -12,13 +12,14 @@ from litellm.proxy._types import ( LiteLLM_TeamTable, Member, ) -from litellm.caching import DualCache +from litellm.caching import DualCache, RedisCache from litellm.llms.custom_httpx.httpx_handler import HTTPHandler from litellm.proxy.hooks.parallel_request_limiter import ( _PROXY_MaxParallelRequestsHandler, ) from litellm import ModelResponse, EmbeddingResponse, ImageResponse from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter +from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck from litellm.integrations.custom_logger import CustomLogger from litellm.proxy.db.base_client import CustomDB @@ -46,16 +47,21 @@ class ProxyLogging: - support the max parallel request integration """ - def __init__(self, user_api_key_cache: DualCache): + def __init__( + self, user_api_key_cache: DualCache, redis_usage_cache: Optional[RedisCache] + ): ## INITIALIZE LITELLM CALLBACKS ## self.call_details: dict = {} self.call_details["user_api_key_cache"] = user_api_key_cache - self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler() + # self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler() + self.max_tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter( + redis_usage_cache=redis_usage_cache + ) self.max_budget_limiter = _PROXY_MaxBudgetLimiter() self.cache_control_check = _PROXY_CacheControlCheck() self.alerting: Optional[List] = None self.alerting_threshold: float = 300 # default to 5 min. threshold - pass + self.redis_usage_cache = redis_usage_cache def update_values( self, alerting: Optional[List], alerting_threshold: Optional[float] @@ -66,7 +72,8 @@ class ProxyLogging: def _init_litellm_callbacks(self): print_verbose(f"INITIALIZING LITELLM CALLBACKS!") - litellm.callbacks.append(self.max_parallel_request_limiter) + # litellm.callbacks.append(self.max_parallel_request_limiter) + litellm.callbacks.append(self.max_tpm_rpm_limiter) litellm.callbacks.append(self.max_budget_limiter) litellm.callbacks.append(self.cache_control_check) litellm.success_callback.append(self.response_taking_too_long_callback) From 583e334bd229ba4fee04ef201a2744f95b0d5758 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 30 Mar 2024 20:10:56 -0700 Subject: [PATCH 07/20] fix(utils.py): set redis_usage_cache to none by default --- litellm/proxy/utils.py | 4 +++- litellm/tests/test_blocked_user_list.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 7d3e30f86..f75a3ff56 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -48,7 +48,9 @@ class ProxyLogging: """ def __init__( - self, user_api_key_cache: DualCache, redis_usage_cache: Optional[RedisCache] + self, + user_api_key_cache: DualCache, + redis_usage_cache: Optional[RedisCache] = None, ): ## INITIALIZE LITELLM CALLBACKS ## self.call_details: dict = {} diff --git a/litellm/tests/test_blocked_user_list.py b/litellm/tests/test_blocked_user_list.py index d3f9f6a1a..533727505 100644 --- a/litellm/tests/test_blocked_user_list.py +++ b/litellm/tests/test_blocked_user_list.py @@ -61,7 +61,7 @@ from litellm.proxy.utils import DBClient from starlette.datastructures import URL from litellm.caching import DualCache -proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) +proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache(), redis_usage_cache=None) @pytest.fixture From 17cabf013c6cc3f3fdb69920125469006d8e396b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 30 Mar 2024 20:20:29 -0700 Subject: [PATCH 08/20] fix(caching.py): respect redis namespace for all redis get/set requests --- litellm/caching.py | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index eea687edf..ded347c10 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -100,7 +100,13 @@ class RedisCache(BaseCache): # if users don't provider one, use the default litellm cache def __init__( - self, host=None, port=None, password=None, redis_flush_size=100, **kwargs + self, + host=None, + port=None, + password=None, + redis_flush_size=100, + namespace: Optional[str] = None, + **kwargs, ): from ._redis import get_redis_client, get_redis_connection_pool @@ -116,9 +122,10 @@ class RedisCache(BaseCache): self.redis_client = get_redis_client(**redis_kwargs) self.redis_kwargs = redis_kwargs self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs) - + # redis namespaces + self.namespace = namespace # for high traffic, we store the redis results in memory and then batch write to redis - self.redis_batch_writing_buffer = [] + self.redis_batch_writing_buffer: list = [] self.redis_flush_size = redis_flush_size self.redis_version = "Unknown" try: @@ -133,11 +140,21 @@ class RedisCache(BaseCache): connection_pool=self.async_redis_conn_pool, **self.redis_kwargs ) + def check_and_fix_namespace(self, key: str) -> str: + """ + Make sure each key starts with the given namespace + """ + if self.namespace is not None and not key.startswith(self.namespace): + key = self.namespace + ":" + key + + return key + def set_cache(self, key, value, **kwargs): ttl = kwargs.get("ttl", None) print_verbose( f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}" ) + key = self.check_and_fix_namespace(key=key) try: self.redis_client.set(name=key, value=str(value), ex=ttl) except Exception as e: @@ -158,6 +175,7 @@ class RedisCache(BaseCache): async def async_set_cache(self, key, value, **kwargs): _redis_client = self.init_async_client() + key = self.check_and_fix_namespace(key=key) async with _redis_client as redis_client: ttl = kwargs.get("ttl", None) print_verbose( @@ -187,6 +205,7 @@ class RedisCache(BaseCache): async with redis_client.pipeline(transaction=True) as pipe: # Iterate through each key-value pair in the cache_list and set them in the pipeline. for cache_key, cache_value in cache_list: + cache_key = self.check_and_fix_namespace(key=cache_key) print_verbose( f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}" ) @@ -213,6 +232,7 @@ class RedisCache(BaseCache): print_verbose( f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}", ) + key = self.check_and_fix_namespace(key=key) self.redis_batch_writing_buffer.append((key, value)) if len(self.redis_batch_writing_buffer) >= self.redis_flush_size: await self.flush_cache_buffer() @@ -242,6 +262,7 @@ class RedisCache(BaseCache): def get_cache(self, key, **kwargs): try: + key = self.check_and_fix_namespace(key=key) print_verbose(f"Get Redis Cache: key: {key}") cached_response = self.redis_client.get(key) print_verbose( @@ -255,6 +276,7 @@ class RedisCache(BaseCache): async def async_get_cache(self, key, **kwargs): _redis_client = self.init_async_client() + key = self.check_and_fix_namespace(key=key) async with _redis_client as redis_client: try: print_verbose(f"Get Async Redis Cache: key: {key}") @@ -281,6 +303,7 @@ class RedisCache(BaseCache): async with redis_client.pipeline(transaction=True) as pipe: # Queue the get operations in the pipeline for all keys. for cache_key in key_list: + cache_key = self.check_and_fix_namespace(key=cache_key) pipe.get(cache_key) # Queue GET command in pipeline # Execute the pipeline and await the results. @@ -1015,6 +1038,9 @@ class Cache: self.redis_flush_size = redis_flush_size self.ttl = ttl + if self.namespace is not None and isinstance(self.cache, RedisCache): + self.cache.namespace = self.namespace + def get_cache_key(self, *args, **kwargs): """ Get the cache key for the given arguments. From aebb0e489c1b3ae81bbb2ffd0b9a73d71c2cb357 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 30 Mar 2024 20:22:48 -0700 Subject: [PATCH 09/20] test: fix test --- litellm/tests/test_custom_callback_input.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py index 4296f188d..afaf7a54c 100644 --- a/litellm/tests/test_custom_callback_input.py +++ b/litellm/tests/test_custom_callback_input.py @@ -253,7 +253,12 @@ class CompletionCustomHandler( assert isinstance(end_time, datetime) ## RESPONSE OBJECT assert isinstance( - response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse) + response_obj, + ( + litellm.ModelResponse, + litellm.EmbeddingResponse, + litellm.TextCompletionResponse, + ), ) ## KWARGS assert isinstance(kwargs["model"], str) From d4dd6d0cdc816654e1111626f95b7615b0592615 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 30 Mar 2024 20:51:59 -0700 Subject: [PATCH 10/20] fix(proxy/utils.py): uncomment max parallel request limit check --- litellm/proxy/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index f75a3ff56..6e1f0521f 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -55,7 +55,7 @@ class ProxyLogging: ## INITIALIZE LITELLM CALLBACKS ## self.call_details: dict = {} self.call_details["user_api_key_cache"] = user_api_key_cache - # self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler() + self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler() self.max_tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter( redis_usage_cache=redis_usage_cache ) @@ -74,7 +74,7 @@ class ProxyLogging: def _init_litellm_callbacks(self): print_verbose(f"INITIALIZING LITELLM CALLBACKS!") - # litellm.callbacks.append(self.max_parallel_request_limiter) + litellm.callbacks.append(self.max_parallel_request_limiter) litellm.callbacks.append(self.max_tpm_rpm_limiter) litellm.callbacks.append(self.max_budget_limiter) litellm.callbacks.append(self.cache_control_check) From a1365f6035371acf515b80ecbf38d8fc8ed4f90c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 30 Mar 2024 21:40:43 -0700 Subject: [PATCH 11/20] test: cleanup --- tests/test_users.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_users.py b/tests/test_users.py index a644f8761..a73d7163e 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -167,6 +167,7 @@ async def chat_completion_streaming(session, key, model="gpt-4"): continue +@pytest.mark.skip(reason="Global proxy now tracked via `/global/spend/logs`") @pytest.mark.asyncio async def test_global_proxy_budget_update(): """ From 5800a5095a1ec839e69195bc89e1b01ad07f7141 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 30 Mar 2024 21:41:14 -0700 Subject: [PATCH 12/20] refactor(main.py): trigger new build --- litellm/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/main.py b/litellm/main.py index 1fcf0d5d3..02d3777af 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -12,6 +12,7 @@ from typing import Any, Literal, Union, BinaryIO from functools import partial import dotenv, traceback, random, asyncio, time, contextvars from copy import deepcopy + import httpx import litellm from ._logging import verbose_logger From ea2356fd95baf4e1f12e00477fdb392dc285ec93 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 30 Mar 2024 22:10:21 -0700 Subject: [PATCH 13/20] =?UTF-8?q?bump:=20version=201.34.17=20=E2=86=92=201?= =?UTF-8?q?.34.18?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8e328412d..15c81cad4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.34.17" +version = "1.34.18" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT" @@ -80,7 +80,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.34.17" +version = "1.34.18" version_files = [ "pyproject.toml:^version" ] From a2c7455c3d820e624c6ab7b8c4b4e44349c7fe05 Mon Sep 17 00:00:00 2001 From: DaxServer Date: Sun, 31 Mar 2024 19:35:37 +0200 Subject: [PATCH 14/20] docs: Update references to Ollama repository url Updated references to the Ollama repository URL from https://github.com/jmorganca/ollama to https://github.com/ollama/ollama. --- docs/my-website/docs/providers/ollama.md | 4 ++-- docs/my-website/docs/proxy/quick_start.md | 3 +-- docs/my-website/docs/proxy_server.md | 3 +-- docs/my-website/docs/simple_proxy_old_doc.md | 4 +--- litellm/llms/ollama.py | 4 ++-- litellm/llms/ollama_chat.py | 4 ++-- litellm/llms/prompt_templates/factory.py | 2 +- 7 files changed, 10 insertions(+), 14 deletions(-) diff --git a/docs/my-website/docs/providers/ollama.md b/docs/my-website/docs/providers/ollama.md index ec2a231e1..1c913c08c 100644 --- a/docs/my-website/docs/providers/ollama.md +++ b/docs/my-website/docs/providers/ollama.md @@ -1,5 +1,5 @@ # Ollama -LiteLLM supports all models from [Ollama](https://github.com/jmorganca/ollama) +LiteLLM supports all models from [Ollama](https://github.com/ollama/ollama) Open In Colab @@ -97,7 +97,7 @@ response = completion( print(response) ``` ## Ollama Models -Ollama supported models: https://github.com/jmorganca/ollama +Ollama supported models: https://github.com/ollama/ollama | Model Name | Function Call | |----------------------|----------------------------------------------------------------------------------- diff --git a/docs/my-website/docs/proxy/quick_start.md b/docs/my-website/docs/proxy/quick_start.md index 8c7d1c066..a7ca4743b 100644 --- a/docs/my-website/docs/proxy/quick_start.md +++ b/docs/my-website/docs/proxy/quick_start.md @@ -438,7 +438,7 @@ In the [config.py](https://continue.dev/docs/reference/Models/openai) set this a ), ``` -Credits [@vividfog](https://github.com/jmorganca/ollama/issues/305#issuecomment-1751848077) for this tutorial. +Credits [@vividfog](https://github.com/ollama/ollama/issues/305#issuecomment-1751848077) for this tutorial. @@ -551,4 +551,3 @@ No Logs ```shell export LITELLM_LOG=None ``` - diff --git a/docs/my-website/docs/proxy_server.md b/docs/my-website/docs/proxy_server.md index 9c335f2a2..87d30e16d 100644 --- a/docs/my-website/docs/proxy_server.md +++ b/docs/my-website/docs/proxy_server.md @@ -435,7 +435,7 @@ In the [config.py](https://continue.dev/docs/reference/Models/openai) set this a ), ``` -Credits [@vividfog](https://github.com/jmorganca/ollama/issues/305#issuecomment-1751848077) for this tutorial. +Credits [@vividfog](https://github.com/ollama/ollama/issues/305#issuecomment-1751848077) for this tutorial. @@ -815,4 +815,3 @@ Thread Stats Avg Stdev Max +/- Stdev - [Community Discord 💭](https://discord.gg/wuPM9dRgDw) - Our numbers 📞 +1 (770) 8783-106 / ‭+1 (412) 618-6238‬ - Our emails ✉️ ishaan@berri.ai / krrish@berri.ai - diff --git a/docs/my-website/docs/simple_proxy_old_doc.md b/docs/my-website/docs/simple_proxy_old_doc.md index 9dcb27797..195728d1b 100644 --- a/docs/my-website/docs/simple_proxy_old_doc.md +++ b/docs/my-website/docs/simple_proxy_old_doc.md @@ -310,7 +310,7 @@ In the [config.py](https://continue.dev/docs/reference/Models/openai) set this a ), ``` -Credits [@vividfog](https://github.com/jmorganca/ollama/issues/305#issuecomment-1751848077) for this tutorial. +Credits [@vividfog](https://github.com/ollama/ollama/issues/305#issuecomment-1751848077) for this tutorial. @@ -1351,5 +1351,3 @@ LiteLLM proxy adds **0.00325 seconds** latency as compared to using the Raw Open ```shell litellm --telemetry False ``` - - diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py index 612a8c32e..779896abf 100644 --- a/litellm/llms/ollama.py +++ b/litellm/llms/ollama.py @@ -20,7 +20,7 @@ class OllamaError(Exception): class OllamaConfig: """ - Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters + Reference: https://github.com/ollama/ollama/blob/main/docs/api.md#parameters The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters: @@ -69,7 +69,7 @@ class OllamaConfig: repeat_penalty: Optional[float] = None temperature: Optional[float] = None stop: Optional[list] = ( - None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442 + None # stop is a list based on this - https://github.com/ollama/ollama/pull/442 ) tfs_z: Optional[float] = None num_predict: Optional[int] = None diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py index 11e08fb72..d442ba5aa 100644 --- a/litellm/llms/ollama_chat.py +++ b/litellm/llms/ollama_chat.py @@ -20,7 +20,7 @@ class OllamaError(Exception): class OllamaChatConfig: """ - Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters + Reference: https://github.com/ollama/ollama/blob/main/docs/api.md#parameters The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters: @@ -69,7 +69,7 @@ class OllamaChatConfig: repeat_penalty: Optional[float] = None temperature: Optional[float] = None stop: Optional[list] = ( - None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442 + None # stop is a list based on this - https://github.com/ollama/ollama/pull/442 ) tfs_z: Optional[float] = None num_predict: Optional[int] = None diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 4492423f4..06faaf557 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -62,7 +62,7 @@ def llama_2_chat_pt(messages): def ollama_pt( model, messages -): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template +): # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template if "instruct" in model: prompt = custom_prompt( role_dict={ From 3f25049dc8fc60ee851eda5b3a6afe2e7fda4d75 Mon Sep 17 00:00:00 2001 From: DaxServer Date: Sun, 31 Mar 2024 20:10:00 +0200 Subject: [PATCH 15/20] fix(docs): Correct Docker pull command in deploy.md Corrected the Docker pull command in deploy.md to remove duplicated 'docker pull' command. --- docs/my-website/docs/proxy/deploy.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/my-website/docs/proxy/deploy.md b/docs/my-website/docs/proxy/deploy.md index 00606332e..c08cb7d0e 100644 --- a/docs/my-website/docs/proxy/deploy.md +++ b/docs/my-website/docs/proxy/deploy.md @@ -246,7 +246,7 @@ Your OpenAI proxy server is now running on `http://127.0.0.1:4000`. We maintain a [seperate Dockerfile](https://github.com/BerriAI/litellm/pkgs/container/litellm-database) for reducing build time when running LiteLLM proxy with a connected Postgres Database ```shell -docker pull docker pull ghcr.io/berriai/litellm-database:main-latest +docker pull ghcr.io/berriai/litellm-database:main-latest ``` ```shell From ddb35facc0f4a47958ef6f0734cb7ed73f2d0659 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 1 Apr 2024 07:40:05 -0700 Subject: [PATCH 16/20] ci/cd run again --- litellm/tests/test_completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index cb4ee84b5..29669a87d 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -41,7 +41,7 @@ def test_completion_custom_provider_model_name(): messages=messages, logger_fn=logger_fn, ) - # Add any assertions here to check the,response + # Add any assertions here to check the response print(response) print(response["choices"][0]["finish_reason"]) except litellm.Timeout as e: From 2dd5f2bc8c41e9060f1885b9640fd0553276efa5 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 30 Mar 2024 20:01:36 -0700 Subject: [PATCH 17/20] fix(tpm_rpm_limiter.py): enable redis caching for tpm/rpm checks on keys/user/teams allows tpm/rpm checks to work across instances https://github.com/BerriAI/litellm/issues/2730 --- litellm/proxy/utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 6e1f0521f..7d3e30f86 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -48,14 +48,12 @@ class ProxyLogging: """ def __init__( - self, - user_api_key_cache: DualCache, - redis_usage_cache: Optional[RedisCache] = None, + self, user_api_key_cache: DualCache, redis_usage_cache: Optional[RedisCache] ): ## INITIALIZE LITELLM CALLBACKS ## self.call_details: dict = {} self.call_details["user_api_key_cache"] = user_api_key_cache - self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler() + # self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler() self.max_tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter( redis_usage_cache=redis_usage_cache ) @@ -74,7 +72,7 @@ class ProxyLogging: def _init_litellm_callbacks(self): print_verbose(f"INITIALIZING LITELLM CALLBACKS!") - litellm.callbacks.append(self.max_parallel_request_limiter) + # litellm.callbacks.append(self.max_parallel_request_limiter) litellm.callbacks.append(self.max_tpm_rpm_limiter) litellm.callbacks.append(self.max_budget_limiter) litellm.callbacks.append(self.cache_control_check) From 1bb4f3ad6dec49fd9d0a6ba0f267ef1885354286 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 30 Mar 2024 20:10:56 -0700 Subject: [PATCH 18/20] fix(utils.py): set redis_usage_cache to none by default --- litellm/proxy/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 7d3e30f86..f75a3ff56 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -48,7 +48,9 @@ class ProxyLogging: """ def __init__( - self, user_api_key_cache: DualCache, redis_usage_cache: Optional[RedisCache] + self, + user_api_key_cache: DualCache, + redis_usage_cache: Optional[RedisCache] = None, ): ## INITIALIZE LITELLM CALLBACKS ## self.call_details: dict = {} From 19fc12008126ae224f6f6916a888c8b6801cb0dd Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 30 Mar 2024 20:51:59 -0700 Subject: [PATCH 19/20] fix(proxy/utils.py): uncomment max parallel request limit check --- litellm/proxy/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index f75a3ff56..6e1f0521f 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -55,7 +55,7 @@ class ProxyLogging: ## INITIALIZE LITELLM CALLBACKS ## self.call_details: dict = {} self.call_details["user_api_key_cache"] = user_api_key_cache - # self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler() + self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler() self.max_tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter( redis_usage_cache=redis_usage_cache ) @@ -74,7 +74,7 @@ class ProxyLogging: def _init_litellm_callbacks(self): print_verbose(f"INITIALIZING LITELLM CALLBACKS!") - # litellm.callbacks.append(self.max_parallel_request_limiter) + litellm.callbacks.append(self.max_parallel_request_limiter) litellm.callbacks.append(self.max_tpm_rpm_limiter) litellm.callbacks.append(self.max_budget_limiter) litellm.callbacks.append(self.cache_control_check) From f3e47323b98de100922bfc19fe88e96de5e9d6e0 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 1 Apr 2024 07:59:30 -0700 Subject: [PATCH 20/20] test(test_max_tpm_rpm_limiter.py): unit tests for key + team based tpm rpm limits on proxy --- litellm/proxy/hooks/tpm_rpm_limiter.py | 11 +- litellm/tests/test_max_tpm_rpm_limiter.py | 122 ++++++++++++++++++++++ 2 files changed, 128 insertions(+), 5 deletions(-) create mode 100644 litellm/tests/test_max_tpm_rpm_limiter.py diff --git a/litellm/proxy/hooks/tpm_rpm_limiter.py b/litellm/proxy/hooks/tpm_rpm_limiter.py index db1d1759a..a46337491 100644 --- a/litellm/proxy/hooks/tpm_rpm_limiter.py +++ b/litellm/proxy/hooks/tpm_rpm_limiter.py @@ -102,6 +102,7 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger): request_count_api_key: str, type: Literal["key", "user", "team"], ): + if type == "key" and user_api_key_dict.api_key is not None: current = current_minute_dict["key"].get(user_api_key_dict.api_key, None) elif type == "user" and user_api_key_dict.user_id is not None: @@ -110,7 +111,6 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger): current = current_minute_dict["team"].get(user_api_key_dict.team_id, None) else: return - if current is None: if tpm_limit == 0 or rpm_limit == 0: # base case @@ -138,10 +138,14 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger): ## get team tpm/rpm limits team_id = user_api_key_dict.team_id + self.user_api_key_cache = cache + _set_limits = self._check_limits_set( user_api_key_cache=cache, key=api_key, user_id=user_id, team_id=team_id ) + self.print_verbose(f"_set_limits: {_set_limits}") + if _set_limits == False: return @@ -149,8 +153,6 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger): # Setup values # ------------ - self.user_api_key_cache = cache - current_date = datetime.now().strftime("%Y-%m-%d") current_hour = datetime.now().strftime("%H") current_minute = datetime.now().strftime("%M") @@ -247,7 +249,6 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger): user_api_key_team_id = kwargs["litellm_params"]["metadata"].get( "user_api_key_team_id", None ) - _limits_set = self._check_limits_set( user_api_key_cache=self.user_api_key_cache, key=user_api_key, @@ -377,4 +378,4 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger): ) except Exception as e: - self.print_verbose(e) # noqa + self.print_verbose("{}\n{}".format(e, traceback.format_exc())) # noqa diff --git a/litellm/tests/test_max_tpm_rpm_limiter.py b/litellm/tests/test_max_tpm_rpm_limiter.py new file mode 100644 index 000000000..db1ab0f86 --- /dev/null +++ b/litellm/tests/test_max_tpm_rpm_limiter.py @@ -0,0 +1,122 @@ +# What is this? +## Unit tests for the max tpm / rpm limiter hook for proxy + +import sys, os, asyncio, time, random +from datetime import datetime +import traceback +from dotenv import load_dotenv + +load_dotenv() +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest +import litellm +from litellm import Router +from litellm.proxy.utils import ProxyLogging +from litellm.proxy._types import UserAPIKeyAuth +from litellm.caching import DualCache, RedisCache +from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter +from datetime import datetime + + +@pytest.mark.asyncio +async def test_pre_call_hook_rpm_limits(): + """ + Test if error raised on hitting rpm limits + """ + litellm.set_verbose = True + _api_key = "sk-12345" + user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, tpm_limit=9, rpm_limit=1) + local_cache = DualCache() + # redis_usage_cache = RedisCache() + + local_cache.set_cache( + key=_api_key, value={"api_key": _api_key, "tpm_limit": 9, "rpm_limit": 1} + ) + + tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(redis_usage_cache=None) + + await tpm_rpm_limiter.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type="" + ) + + kwargs = {"litellm_params": {"metadata": {"user_api_key": _api_key}}} + + await tpm_rpm_limiter.async_log_success_event( + kwargs=kwargs, + response_obj="", + start_time="", + end_time="", + ) + + ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1} + + try: + await tpm_rpm_limiter.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=local_cache, + data={}, + call_type="", + ) + + pytest.fail(f"Expected call to fail") + except Exception as e: + assert e.status_code == 429 + + +@pytest.mark.asyncio +async def test_pre_call_hook_team_rpm_limits(): + """ + Test if error raised on hitting team rpm limits + """ + litellm.set_verbose = True + _api_key = "sk-12345" + _team_id = "unique-team-id" + _user_api_key_dict = { + "api_key": _api_key, + "max_parallel_requests": 1, + "tpm_limit": 9, + "rpm_limit": 10, + "team_rpm_limit": 1, + "team_id": _team_id, + } + user_api_key_dict = UserAPIKeyAuth(**_user_api_key_dict) + local_cache = DualCache() + local_cache.set_cache(key=_api_key, value=_user_api_key_dict) + tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(redis_usage_cache=None) + + await tpm_rpm_limiter.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type="" + ) + + kwargs = { + "litellm_params": { + "metadata": {"user_api_key": _api_key, "user_api_key_team_id": _team_id} + } + } + + await tpm_rpm_limiter.async_log_success_event( + kwargs=kwargs, + response_obj="", + start_time="", + end_time="", + ) + + print(f"local_cache: {local_cache}") + + ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1} + + try: + await tpm_rpm_limiter.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=local_cache, + data={}, + call_type="", + ) + + pytest.fail(f"Expected call to fail") + except Exception as e: + assert e.status_code == 429