From 5985c7e9336d1cfb85e414e4c729ff5c55b64e64 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Sat, 17 Aug 2024 10:46:04 -0700
Subject: [PATCH] feat - use common helper for getting model group

---
 litellm/integrations/prometheus.py            |  5 ++-
 litellm/proxy/auth/auth_utils.py              | 17 +++++++++
 .../{init_callbacks.py => callback_utils.py}  | 10 +++++
 litellm/proxy/guardrails/init_guardrails.py   |  2 +-
 .../proxy/hooks/parallel_request_limiter.py   | 37 ++++++++++++-------
 litellm/proxy/proxy_server.py                 |  3 +-
 6 files changed, 56 insertions(+), 18 deletions(-)
 rename litellm/proxy/common_utils/{init_callbacks.py => callback_utils.py} (97%)

diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py
index 3141f36fb0..51c3bda4c7 100644
--- a/litellm/integrations/prometheus.py
+++ b/litellm/integrations/prometheus.py
@@ -204,6 +204,9 @@ class PrometheusLogger(CustomLogger):
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         # Define prometheus client
+        from litellm.proxy.common_utils.callback_utils import (
+            get_model_group_from_litellm_kwargs,
+        )
         from litellm.proxy.proxy_server import premium_user
 
         verbose_logger.debug(
@@ -306,7 +309,7 @@
 
             # Set remaining rpm/tpm for API Key + model
             # see parallel_request_limiter.py - variables are set there
-            model_group = _metadata.get("model_group", None)
+            model_group = get_model_group_from_litellm_kwargs(kwargs)
             remaining_requests_variable_name = (
                 f"litellm-key-remaining-requests-{model_group}"
             )

diff --git a/litellm/proxy/auth/auth_utils.py b/litellm/proxy/auth/auth_utils.py
index 2f26ec6533..aff51624af 100644
--- a/litellm/proxy/auth/auth_utils.py
+++ b/litellm/proxy/auth/auth_utils.py
@@ -210,3 +210,20 @@ def bytes_to_mb(bytes_value: int):
     Helper to convert bytes to MB
     """
     return bytes_value / (1024 * 1024)
+
+
+# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
+def get_key_model_rpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]:
+    if user_api_key_dict.metadata:
+        if "model_rpm_limit" in user_api_key_dict.metadata:
+            return user_api_key_dict.metadata["model_rpm_limit"]
+
+    return None
+
+
+def get_key_model_tpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]:
+    if user_api_key_dict.metadata:
+        if "model_tpm_limit" in user_api_key_dict.metadata:
+            return user_api_key_dict.metadata["model_tpm_limit"]
+
+    return None

diff --git a/litellm/proxy/common_utils/init_callbacks.py b/litellm/proxy/common_utils/callback_utils.py
similarity index 97%
rename from litellm/proxy/common_utils/init_callbacks.py
rename to litellm/proxy/common_utils/callback_utils.py
index fbbfdcf018..2430ef3dcb 100644
--- a/litellm/proxy/common_utils/init_callbacks.py
+++ b/litellm/proxy/common_utils/callback_utils.py
@@ -249,3 +249,13 @@ def initialize_callbacks_on_proxy(
     verbose_proxy_logger.debug(
         f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
     )
+
+
+def get_model_group_from_litellm_kwargs(kwargs: dict) -> Optional[str]:
+    _litellm_params = kwargs.get("litellm_params", None) or {}
+    _metadata = _litellm_params.get("metadata", None) or {}
+    _model_group = _metadata.get("model_group", None)
+    if _model_group is not None:
+        return _model_group
+
+    return None

diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py
index de61818689..bfd3150ad4 100644
--- a/litellm/proxy/guardrails/init_guardrails.py
+++ b/litellm/proxy/guardrails/init_guardrails.py
@@ -5,7 +5,7 @@ from pydantic import BaseModel, RootModel
 
 import litellm
 from litellm._logging import verbose_proxy_logger
-from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
+from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
 from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
 
 all_guardrails: List[GuardrailItem] = []

diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index b7f923631e..c9e2234d56 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -11,6 +11,10 @@ from litellm._logging import verbose_proxy_logger
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.auth.auth_utils import (
+    get_key_model_rpm_limit,
+    get_key_model_tpm_limit,
+)
 
 
 class _PROXY_MaxParallelRequestsHandler(CustomLogger):
@@ -204,8 +208,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
 
         # Check if request under RPM/TPM per model for a given API Key
         if (
-            user_api_key_dict.tpm_limit_per_model
-            or user_api_key_dict.rpm_limit_per_model
+            get_key_model_tpm_limit(user_api_key_dict) is not None
+            or get_key_model_rpm_limit(user_api_key_dict) is not None
         ):
             _model = data.get("model", None)
             request_count_api_key = (
@@ -219,15 +223,16 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
 
             tpm_limit_for_model = None
             rpm_limit_for_model = None
+            _tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
+            _rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)
+
             if _model is not None:
-                if user_api_key_dict.tpm_limit_per_model:
-                    tpm_limit_for_model = user_api_key_dict.tpm_limit_per_model.get(
-                        _model
-                    )
-                if user_api_key_dict.rpm_limit_per_model:
-                    rpm_limit_for_model = user_api_key_dict.rpm_limit_per_model.get(
-                        _model
-                    )
+
+                if _tpm_limit_for_key_model:
+                    tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
+
+                if _rpm_limit_for_key_model:
+                    rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
             if current is None:
                 new_val = {
                     "current_requests": 1,
@@ -371,6 +376,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         return
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        from litellm.proxy.common_utils.callback_utils import (
+            get_model_group_from_litellm_kwargs,
+        )
+
         try:
             self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
             global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
@@ -438,12 +447,12 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             ) # store in cache for 1 min.
 
             # ------------
-            # Update usage - model + API Key
+            # Update usage - model group + API Key
             # ------------
-            _model = kwargs.get("model")
-            if user_api_key is not None and _model is not None:
+            model_group = get_model_group_from_litellm_kwargs(kwargs)
+            if user_api_key is not None and model_group is not None:
                 request_count_api_key = (
-                    f"{user_api_key}::{_model}::{precise_minute}::request_count"
+                    f"{user_api_key}::{model_group}::{precise_minute}::request_count"
                 )
 
                 current = await self.internal_usage_cache.async_get_cache(

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 10c06b2ece..ad4dc0ddea 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -148,6 +148,7 @@ from litellm.proxy.common_utils.admin_ui_utils import (
     html_form,
     show_missing_vars_in_env,
 )
+from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
 from litellm.proxy.common_utils.debug_utils import init_verbose_loggers
 from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
 from litellm.proxy.common_utils.encrypt_decrypt_utils import (
@@ -158,7 +159,6 @@ from litellm.proxy.common_utils.http_parsing_utils import (
     _read_request_body,
     check_file_size_under_limit,
 )
-from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
 from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3
 from litellm.proxy.common_utils.openai_endpoint_utils import (
     remove_sensitive_info_from_deployment,
@@ -199,7 +199,6 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
     router as pass_through_router,
 )
 from litellm.proxy.route_llm_request import route_request
-
 from litellm.proxy.secret_managers.aws_secret_manager import (
     load_aws_kms,
     load_aws_secret_manager,
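
Reviewer note: below is a minimal, self-contained sketch of how the helpers introduced in this
patch compose, for anyone reviewing without the full tree checked out. FakeUserAPIKeyAuth is a
hypothetical stand-in for litellm's UserAPIKeyAuth, and the sample key, model group, and limit
values are invented for illustration; only the two helper bodies mirror the patch.

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeUserAPIKeyAuth:
    # Hypothetical stand-in for litellm.proxy._types.UserAPIKeyAuth;
    # only the metadata field matters for these helpers.
    metadata: dict = field(default_factory=dict)


def get_key_model_tpm_limit(user_api_key_dict: FakeUserAPIKeyAuth) -> Optional[dict]:
    # Mirrors litellm/proxy/auth/auth_utils.py: per-model TPM limits are
    # stored on the key's metadata under "model_tpm_limit".
    if user_api_key_dict.metadata and "model_tpm_limit" in user_api_key_dict.metadata:
        return user_api_key_dict.metadata["model_tpm_limit"]
    return None


def get_model_group_from_litellm_kwargs(kwargs: dict) -> Optional[str]:
    # Mirrors litellm/proxy/common_utils/callback_utils.py: the model group
    # is nested under litellm_params -> metadata -> model_group.
    _litellm_params = kwargs.get("litellm_params", None) or {}
    _metadata = _litellm_params.get("metadata", None) or {}
    return _metadata.get("model_group", None)


# Illustrative inputs (invented values).
kwargs = {"litellm_params": {"metadata": {"model_group": "gpt-4"}}}
key = FakeUserAPIKeyAuth(metadata={"model_tpm_limit": {"gpt-4": 100000}})

model_group = get_model_group_from_litellm_kwargs(kwargs)  # -> "gpt-4"
_tpm_limits = get_key_model_tpm_limit(key)  # -> {"gpt-4": 100000}
tpm_limit_for_model = _tpm_limits.get(model_group) if _tpm_limits else None

# The limiter's per-minute cache key is built from the model group rather
# than the raw model name, matching what the Prometheus logger reads.
user_api_key = "hashed-key"  # invented
precise_minute = "2024-08-17-10-46"  # invented
request_count_api_key = f"{user_api_key}::{model_group}::{precise_minute}::request_count"
print(tpm_limit_for_model, request_count_api_key)

Deriving the group through one shared helper means the Prometheus logger and the parallel
request limiter compute the same identifier, so the remaining-request/token metrics line up
with the limiter's cache entries.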