feat - use commong helper for getting model group

This commit is contained in:
Ishaan Jaff 2024-08-17 10:46:04 -07:00
parent d630f77b73
commit 5985c7e933
6 changed files with 56 additions and 18 deletions

View file

@ -11,6 +11,10 @@ from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_utils import (
get_key_model_rpm_limit,
get_key_model_tpm_limit,
)
class _PROXY_MaxParallelRequestsHandler(CustomLogger):
@ -204,8 +208,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# Check if request under RPM/TPM per model for a given API Key
if (
user_api_key_dict.tpm_limit_per_model
or user_api_key_dict.rpm_limit_per_model
get_key_model_tpm_limit(user_api_key_dict) is not None
or get_key_model_rpm_limit(user_api_key_dict) is not None
):
_model = data.get("model", None)
request_count_api_key = (
@ -219,15 +223,16 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
tpm_limit_for_model = None
rpm_limit_for_model = None
_tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
_rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)
if _model is not None:
if user_api_key_dict.tpm_limit_per_model:
tpm_limit_for_model = user_api_key_dict.tpm_limit_per_model.get(
_model
)
if user_api_key_dict.rpm_limit_per_model:
rpm_limit_for_model = user_api_key_dict.rpm_limit_per_model.get(
_model
)
if _tpm_limit_for_key_model:
tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
if _rpm_limit_for_key_model:
rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
if current is None:
new_val = {
"current_requests": 1,
@ -371,6 +376,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
return
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
from litellm.proxy.common_utils.callback_utils import (
get_model_group_from_litellm_kwargs,
)
try:
self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
@ -438,12 +447,12 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
) # store in cache for 1 min.
# ------------
# Update usage - model + API Key
# Update usage - model group + API Key
# ------------
_model = kwargs.get("model")
if user_api_key is not None and _model is not None:
model_group = get_model_group_from_litellm_kwargs(kwargs)
if user_api_key is not None and model_group is not None:
request_count_api_key = (
f"{user_api_key}::{_model}::{precise_minute}::request_count"
f"{user_api_key}::{model_group}::{precise_minute}::request_count"
)
current = await self.internal_usage_cache.async_get_cache(