mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat - use commong helper for getting model group
This commit is contained in:
parent
d630f77b73
commit
5985c7e933
6 changed files with 56 additions and 18 deletions
|
@ -204,6 +204,9 @@ class PrometheusLogger(CustomLogger):
|
||||||
|
|
||||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
# Define prometheus client
|
# Define prometheus client
|
||||||
|
from litellm.proxy.common_utils.callback_utils import (
|
||||||
|
get_model_group_from_litellm_kwargs,
|
||||||
|
)
|
||||||
from litellm.proxy.proxy_server import premium_user
|
from litellm.proxy.proxy_server import premium_user
|
||||||
|
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
|
@ -306,7 +309,7 @@ class PrometheusLogger(CustomLogger):
|
||||||
|
|
||||||
# Set remaining rpm/tpm for API Key + model
|
# Set remaining rpm/tpm for API Key + model
|
||||||
# see parallel_request_limiter.py - variables are set there
|
# see parallel_request_limiter.py - variables are set there
|
||||||
model_group = _metadata.get("model_group", None)
|
model_group = get_model_group_from_litellm_kwargs(kwargs)
|
||||||
remaining_requests_variable_name = (
|
remaining_requests_variable_name = (
|
||||||
f"litellm-key-remaining-requests-{model_group}"
|
f"litellm-key-remaining-requests-{model_group}"
|
||||||
)
|
)
|
||||||
|
|
|
@ -210,3 +210,20 @@ def bytes_to_mb(bytes_value: int):
|
||||||
Helper to convert bytes to MB
|
Helper to convert bytes to MB
|
||||||
"""
|
"""
|
||||||
return bytes_value / (1024 * 1024)
|
return bytes_value / (1024 * 1024)
|
||||||
|
|
||||||
|
|
||||||
|
# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
|
||||||
|
def get_key_model_rpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]:
|
||||||
|
if user_api_key_dict.metadata:
|
||||||
|
if "model_rpm_limit" in user_api_key_dict.metadata:
|
||||||
|
return user_api_key_dict.metadata["model_rpm_limit"]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_key_model_tpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]:
|
||||||
|
if user_api_key_dict.metadata:
|
||||||
|
if "model_tpm_limit" in user_api_key_dict.metadata:
|
||||||
|
return user_api_key_dict.metadata["model_tpm_limit"]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
|
@ -249,3 +249,13 @@ def initialize_callbacks_on_proxy(
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
|
f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_group_from_litellm_kwargs(kwargs: dict) -> Optional[str]:
|
||||||
|
_litellm_params = kwargs.get("litellm_params", None) or {}
|
||||||
|
_metadata = _litellm_params.get("metadata", None) or {}
|
||||||
|
_model_group = _metadata.get("model_group", None)
|
||||||
|
if _model_group is not None:
|
||||||
|
return _model_group
|
||||||
|
|
||||||
|
return None
|
|
@ -5,7 +5,7 @@ from pydantic import BaseModel, RootModel
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
|
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
|
||||||
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
|
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
|
||||||
|
|
||||||
all_guardrails: List[GuardrailItem] = []
|
all_guardrails: List[GuardrailItem] = []
|
||||||
|
|
|
@ -11,6 +11,10 @@ from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.proxy.auth.auth_utils import (
|
||||||
|
get_key_model_rpm_limit,
|
||||||
|
get_key_model_tpm_limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
|
@ -204,8 +208,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
|
|
||||||
# Check if request under RPM/TPM per model for a given API Key
|
# Check if request under RPM/TPM per model for a given API Key
|
||||||
if (
|
if (
|
||||||
user_api_key_dict.tpm_limit_per_model
|
get_key_model_tpm_limit(user_api_key_dict) is not None
|
||||||
or user_api_key_dict.rpm_limit_per_model
|
or get_key_model_rpm_limit(user_api_key_dict) is not None
|
||||||
):
|
):
|
||||||
_model = data.get("model", None)
|
_model = data.get("model", None)
|
||||||
request_count_api_key = (
|
request_count_api_key = (
|
||||||
|
@ -219,15 +223,16 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
tpm_limit_for_model = None
|
tpm_limit_for_model = None
|
||||||
rpm_limit_for_model = None
|
rpm_limit_for_model = None
|
||||||
|
|
||||||
|
_tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
|
||||||
|
_rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)
|
||||||
|
|
||||||
if _model is not None:
|
if _model is not None:
|
||||||
if user_api_key_dict.tpm_limit_per_model:
|
|
||||||
tpm_limit_for_model = user_api_key_dict.tpm_limit_per_model.get(
|
if _tpm_limit_for_key_model:
|
||||||
_model
|
tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
|
||||||
)
|
|
||||||
if user_api_key_dict.rpm_limit_per_model:
|
if _rpm_limit_for_key_model:
|
||||||
rpm_limit_for_model = user_api_key_dict.rpm_limit_per_model.get(
|
rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
|
||||||
_model
|
|
||||||
)
|
|
||||||
if current is None:
|
if current is None:
|
||||||
new_val = {
|
new_val = {
|
||||||
"current_requests": 1,
|
"current_requests": 1,
|
||||||
|
@ -371,6 +376,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
return
|
return
|
||||||
|
|
||||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
from litellm.proxy.common_utils.callback_utils import (
|
||||||
|
get_model_group_from_litellm_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
|
self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
|
||||||
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
|
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
|
||||||
|
@ -438,12 +447,12 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
) # store in cache for 1 min.
|
) # store in cache for 1 min.
|
||||||
|
|
||||||
# ------------
|
# ------------
|
||||||
# Update usage - model + API Key
|
# Update usage - model group + API Key
|
||||||
# ------------
|
# ------------
|
||||||
_model = kwargs.get("model")
|
model_group = get_model_group_from_litellm_kwargs(kwargs)
|
||||||
if user_api_key is not None and _model is not None:
|
if user_api_key is not None and model_group is not None:
|
||||||
request_count_api_key = (
|
request_count_api_key = (
|
||||||
f"{user_api_key}::{_model}::{precise_minute}::request_count"
|
f"{user_api_key}::{model_group}::{precise_minute}::request_count"
|
||||||
)
|
)
|
||||||
|
|
||||||
current = await self.internal_usage_cache.async_get_cache(
|
current = await self.internal_usage_cache.async_get_cache(
|
||||||
|
|
|
@ -148,6 +148,7 @@ from litellm.proxy.common_utils.admin_ui_utils import (
|
||||||
html_form,
|
html_form,
|
||||||
show_missing_vars_in_env,
|
show_missing_vars_in_env,
|
||||||
)
|
)
|
||||||
|
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
|
||||||
from litellm.proxy.common_utils.debug_utils import init_verbose_loggers
|
from litellm.proxy.common_utils.debug_utils import init_verbose_loggers
|
||||||
from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
|
from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
|
||||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||||
|
@ -158,7 +159,6 @@ from litellm.proxy.common_utils.http_parsing_utils import (
|
||||||
_read_request_body,
|
_read_request_body,
|
||||||
check_file_size_under_limit,
|
check_file_size_under_limit,
|
||||||
)
|
)
|
||||||
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
|
|
||||||
from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3
|
from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3
|
||||||
from litellm.proxy.common_utils.openai_endpoint_utils import (
|
from litellm.proxy.common_utils.openai_endpoint_utils import (
|
||||||
remove_sensitive_info_from_deployment,
|
remove_sensitive_info_from_deployment,
|
||||||
|
@ -199,7 +199,6 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
|
||||||
router as pass_through_router,
|
router as pass_through_router,
|
||||||
)
|
)
|
||||||
from litellm.proxy.route_llm_request import route_request
|
from litellm.proxy.route_llm_request import route_request
|
||||||
|
|
||||||
from litellm.proxy.secret_managers.aws_secret_manager import (
|
from litellm.proxy.secret_managers.aws_secret_manager import (
|
||||||
load_aws_kms,
|
load_aws_kms,
|
||||||
load_aws_secret_manager,
|
load_aws_secret_manager,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue