feat - use commong helper for getting model group

This commit is contained in:
Ishaan Jaff 2024-08-17 10:46:04 -07:00
parent d630f77b73
commit 5985c7e933
6 changed files with 56 additions and 18 deletions

View file

@ -204,6 +204,9 @@ class PrometheusLogger(CustomLogger):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
# Define prometheus client # Define prometheus client
from litellm.proxy.common_utils.callback_utils import (
get_model_group_from_litellm_kwargs,
)
from litellm.proxy.proxy_server import premium_user from litellm.proxy.proxy_server import premium_user
verbose_logger.debug( verbose_logger.debug(
@ -306,7 +309,7 @@ class PrometheusLogger(CustomLogger):
# Set remaining rpm/tpm for API Key + model # Set remaining rpm/tpm for API Key + model
# see parallel_request_limiter.py - variables are set there # see parallel_request_limiter.py - variables are set there
model_group = _metadata.get("model_group", None) model_group = get_model_group_from_litellm_kwargs(kwargs)
remaining_requests_variable_name = ( remaining_requests_variable_name = (
f"litellm-key-remaining-requests-{model_group}" f"litellm-key-remaining-requests-{model_group}"
) )

View file

@ -210,3 +210,20 @@ def bytes_to_mb(bytes_value: int):
Helper to convert bytes to MB Helper to convert bytes to MB
""" """
return bytes_value / (1024 * 1024) return bytes_value / (1024 * 1024)
# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
def get_key_model_rpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]:
if user_api_key_dict.metadata:
if "model_rpm_limit" in user_api_key_dict.metadata:
return user_api_key_dict.metadata["model_rpm_limit"]
return None
def get_key_model_tpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]:
if user_api_key_dict.metadata:
if "model_tpm_limit" in user_api_key_dict.metadata:
return user_api_key_dict.metadata["model_tpm_limit"]
return None

View file

@ -249,3 +249,13 @@ def initialize_callbacks_on_proxy(
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}" f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
) )
def get_model_group_from_litellm_kwargs(kwargs: dict) -> Optional[str]:
_litellm_params = kwargs.get("litellm_params", None) or {}
_metadata = _litellm_params.get("metadata", None) or {}
_model_group = _metadata.get("model_group", None)
if _model_group is not None:
return _model_group
return None

View file

@ -5,7 +5,7 @@ from pydantic import BaseModel, RootModel
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
all_guardrails: List[GuardrailItem] = [] all_guardrails: List[GuardrailItem] = []

View file

@ -11,6 +11,10 @@ from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_utils import (
get_key_model_rpm_limit,
get_key_model_tpm_limit,
)
class _PROXY_MaxParallelRequestsHandler(CustomLogger): class _PROXY_MaxParallelRequestsHandler(CustomLogger):
@ -204,8 +208,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# Check if request under RPM/TPM per model for a given API Key # Check if request under RPM/TPM per model for a given API Key
if ( if (
user_api_key_dict.tpm_limit_per_model get_key_model_tpm_limit(user_api_key_dict) is not None
or user_api_key_dict.rpm_limit_per_model or get_key_model_rpm_limit(user_api_key_dict) is not None
): ):
_model = data.get("model", None) _model = data.get("model", None)
request_count_api_key = ( request_count_api_key = (
@ -219,15 +223,16 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
tpm_limit_for_model = None tpm_limit_for_model = None
rpm_limit_for_model = None rpm_limit_for_model = None
_tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
_rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)
if _model is not None: if _model is not None:
if user_api_key_dict.tpm_limit_per_model:
tpm_limit_for_model = user_api_key_dict.tpm_limit_per_model.get( if _tpm_limit_for_key_model:
_model tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
)
if user_api_key_dict.rpm_limit_per_model: if _rpm_limit_for_key_model:
rpm_limit_for_model = user_api_key_dict.rpm_limit_per_model.get( rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
_model
)
if current is None: if current is None:
new_val = { new_val = {
"current_requests": 1, "current_requests": 1,
@ -371,6 +376,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
return return
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
from litellm.proxy.common_utils.callback_utils import (
get_model_group_from_litellm_kwargs,
)
try: try:
self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING") self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get( global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
@ -438,12 +447,12 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
) # store in cache for 1 min. ) # store in cache for 1 min.
# ------------ # ------------
# Update usage - model + API Key # Update usage - model group + API Key
# ------------ # ------------
_model = kwargs.get("model") model_group = get_model_group_from_litellm_kwargs(kwargs)
if user_api_key is not None and _model is not None: if user_api_key is not None and model_group is not None:
request_count_api_key = ( request_count_api_key = (
f"{user_api_key}::{_model}::{precise_minute}::request_count" f"{user_api_key}::{model_group}::{precise_minute}::request_count"
) )
current = await self.internal_usage_cache.async_get_cache( current = await self.internal_usage_cache.async_get_cache(

View file

@ -148,6 +148,7 @@ from litellm.proxy.common_utils.admin_ui_utils import (
html_form, html_form,
show_missing_vars_in_env, show_missing_vars_in_env,
) )
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
from litellm.proxy.common_utils.debug_utils import init_verbose_loggers from litellm.proxy.common_utils.debug_utils import init_verbose_loggers
from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
from litellm.proxy.common_utils.encrypt_decrypt_utils import ( from litellm.proxy.common_utils.encrypt_decrypt_utils import (
@ -158,7 +159,6 @@ from litellm.proxy.common_utils.http_parsing_utils import (
_read_request_body, _read_request_body,
check_file_size_under_limit, check_file_size_under_limit,
) )
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3 from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3
from litellm.proxy.common_utils.openai_endpoint_utils import ( from litellm.proxy.common_utils.openai_endpoint_utils import (
remove_sensitive_info_from_deployment, remove_sensitive_info_from_deployment,
@ -199,7 +199,6 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
router as pass_through_router, router as pass_through_router,
) )
from litellm.proxy.route_llm_request import route_request from litellm.proxy.route_llm_request import route_request
from litellm.proxy.secret_managers.aws_secret_manager import ( from litellm.proxy.secret_managers.aws_secret_manager import (
load_aws_kms, load_aws_kms,
load_aws_secret_manager, load_aws_secret_manager,