Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
feat - use common helper for getting model group
This commit is contained in:
parent d630f77b73
commit 5985c7e933
6 changed files with 56 additions and 18 deletions

@@ -204,6 +204,9 @@ class PrometheusLogger(CustomLogger):
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         # Define prometheus client
+        from litellm.proxy.common_utils.callback_utils import (
+            get_model_group_from_litellm_kwargs,
+        )
         from litellm.proxy.proxy_server import premium_user
 
         verbose_logger.debug(

@@ -306,7 +309,7 @@ class PrometheusLogger(CustomLogger):
         # Set remaining rpm/tpm for API Key + model
         # see parallel_request_limiter.py - variables are set there
-        model_group = _metadata.get("model_group", None)
+        model_group = get_model_group_from_litellm_kwargs(kwargs)
         remaining_requests_variable_name = (
             f"litellm-key-remaining-requests-{model_group}"
         )
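
For reference, a minimal sketch of the lookup this helper centralizes, using an illustrative kwargs payload (the real call sites pass the logging kwargs litellm builds for async_log_success_event):

    # illustrative shape of the logging kwargs
    kwargs = {"litellm_params": {"metadata": {"model_group": "gpt-4"}}}

    _litellm_params = kwargs.get("litellm_params", None) or {}
    _metadata = _litellm_params.get("metadata", None) or {}
    model_group = _metadata.get("model_group", None)  # -> "gpt-4"

    # gauge name derived from the model group, as in the hunk above
    remaining_requests_variable_name = f"litellm-key-remaining-requests-{model_group}"
    # -> "litellm-key-remaining-requests-gpt-4"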

@@ -210,3 +210,20 @@ def bytes_to_mb(bytes_value: int):
     Helper to convert bytes to MB
     """
     return bytes_value / (1024 * 1024)
+
+
+# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
+def get_key_model_rpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]:
+    if user_api_key_dict.metadata:
+        if "model_rpm_limit" in user_api_key_dict.metadata:
+            return user_api_key_dict.metadata["model_rpm_limit"]
+
+    return None
+
+
+def get_key_model_tpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]:
+    if user_api_key_dict.metadata:
+        if "model_tpm_limit" in user_api_key_dict.metadata:
+            return user_api_key_dict.metadata["model_tpm_limit"]
+
+    return None
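
A runnable sketch of how these two helpers resolve per-model limits, using a plain dict in place of a real UserAPIKeyAuth object and illustrative limit values:

    from typing import Optional

    def _lookup(metadata: Optional[dict], limit_key: str) -> Optional[dict]:
        # the same lookup get_key_model_rpm_limit / get_key_model_tpm_limit
        # perform against user_api_key_dict.metadata
        if metadata and limit_key in metadata:
            return metadata[limit_key]
        return None

    key_metadata = {"model_rpm_limit": {"gpt-4": 10}, "model_tpm_limit": {"gpt-4": 1000}}
    print(_lookup(key_metadata, "model_rpm_limit"))  # {'gpt-4': 10}
    print(_lookup(key_metadata, "model_tpm_limit"))  # {'gpt-4': 1000}
    print(_lookup(None, "model_rpm_limit"))          # None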

@@ -249,3 +249,13 @@ def initialize_callbacks_on_proxy(
     verbose_proxy_logger.debug(
         f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
     )
+
+
+def get_model_group_from_litellm_kwargs(kwargs: dict) -> Optional[str]:
+    _litellm_params = kwargs.get("litellm_params", None) or {}
+    _metadata = _litellm_params.get("metadata", None) or {}
+    _model_group = _metadata.get("model_group", None)
+    if _model_group is not None:
+        return _model_group
+
+    return None
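
Usage of the new shared helper, with an illustrative payload; the import path matches the one added in the Prometheus logger above:

    from litellm.proxy.common_utils.callback_utils import (
        get_model_group_from_litellm_kwargs,
    )

    kwargs = {"litellm_params": {"metadata": {"model_group": "azure-gpt-4"}}}
    assert get_model_group_from_litellm_kwargs(kwargs) == "azure-gpt-4"
    # a missing litellm_params, metadata, or model_group key falls through to None
    assert get_model_group_from_litellm_kwargs({}) is None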

@@ -5,7 +5,7 @@ from pydantic import BaseModel, RootModel
 
 import litellm
 from litellm._logging import verbose_proxy_logger
-from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
+from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
 from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
 
 all_guardrails: List[GuardrailItem] = []

@@ -11,6 +11,10 @@ from litellm._logging import verbose_proxy_logger
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.auth.auth_utils import (
+    get_key_model_rpm_limit,
+    get_key_model_tpm_limit,
+)
 
 
 class _PROXY_MaxParallelRequestsHandler(CustomLogger):

@@ -204,8 +208,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
 
         # Check if request under RPM/TPM per model for a given API Key
         if (
-            user_api_key_dict.tpm_limit_per_model
-            or user_api_key_dict.rpm_limit_per_model
+            get_key_model_tpm_limit(user_api_key_dict) is not None
+            or get_key_model_rpm_limit(user_api_key_dict) is not None
         ):
             _model = data.get("model", None)
             request_count_api_key = (

@@ -219,15 +223,16 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             tpm_limit_for_model = None
             rpm_limit_for_model = None
 
+            _tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
+            _rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)
+
             if _model is not None:
-                if user_api_key_dict.tpm_limit_per_model:
-                    tpm_limit_for_model = user_api_key_dict.tpm_limit_per_model.get(
-                        _model
-                    )
-                if user_api_key_dict.rpm_limit_per_model:
-                    rpm_limit_for_model = user_api_key_dict.rpm_limit_per_model.get(
-                        _model
-                    )
+                if _tpm_limit_for_key_model:
+                    tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
+
+                if _rpm_limit_for_key_model:
+                    rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
+
             if current is None:
                 new_val = {
                     "current_requests": 1,

@@ -371,6 +376,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             return
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        from litellm.proxy.common_utils.callback_utils import (
+            get_model_group_from_litellm_kwargs,
+        )
+
         try:
             self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
             global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(

@@ -438,12 +447,12 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             )  # store in cache for 1 min.
 
             # ------------
-            # Update usage - model + API Key
+            # Update usage - model group + API Key
             # ------------
-            _model = kwargs.get("model")
-            if user_api_key is not None and _model is not None:
+            model_group = get_model_group_from_litellm_kwargs(kwargs)
+            if user_api_key is not None and model_group is not None:
                 request_count_api_key = (
-                    f"{user_api_key}::{_model}::{precise_minute}::request_count"
+                    f"{user_api_key}::{model_group}::{precise_minute}::request_count"
                 )
 
                 current = await self.internal_usage_cache.async_get_cache(
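
Concretely, the per-minute usage counter is now keyed by model group rather than the raw model name, so post-call accounting lines up with the per-model-group limits checked before the call. A sketch with assumed values:

    user_api_key = "88dc28"              # illustrative hashed key
    precise_minute = "2024-07-20-15-32"  # illustrative minute bucket
    model_group = "gpt-4"                # from get_model_group_from_litellm_kwargs(kwargs)

    # old key used _model = kwargs.get("model")
    request_count_api_key = f"{user_api_key}::{model_group}::{precise_minute}::request_count"
    # -> "88dc28::gpt-4::2024-07-20-15-32::request_count"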

@@ -148,6 +148,7 @@ from litellm.proxy.common_utils.admin_ui_utils import (
     html_form,
     show_missing_vars_in_env,
 )
+from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
 from litellm.proxy.common_utils.debug_utils import init_verbose_loggers
 from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
 from litellm.proxy.common_utils.encrypt_decrypt_utils import (

@@ -158,7 +159,6 @@ from litellm.proxy.common_utils.http_parsing_utils import (
     _read_request_body,
     check_file_size_under_limit,
 )
-from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
 from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3
 from litellm.proxy.common_utils.openai_endpoint_utils import (
     remove_sensitive_info_from_deployment,

@@ -199,7 +199,6 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
     router as pass_through_router,
 )
 from litellm.proxy.route_llm_request import route_request
-
 from litellm.proxy.secret_managers.aws_secret_manager import (
     load_aws_kms,
     load_aws_secret_manager,