Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)
Merge pull request #5259 from BerriAI/litellm_return_remaining_tokens_in_header
[Feat] return `x-litellm-key-remaining-requests-{model}`: 1, `x-litellm-key-remaining-tokens-{model}: None` in response headers
commit feb8c3c5b4
9 changed files with 518 additions and 11 deletions
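The PR title above describes the feature: for virtual keys that carry per-model rpm/tpm limits, the proxy now returns `x-litellm-key-remaining-requests-{model}` and `x-litellm-key-remaining-tokens-{model}` response headers. A minimal client-side sketch of reading them, using the requests library; the proxy URL, key `sk-1234`, and model `gpt-4o` are assumed example values, not part of this PR:

import requests

# Assumed local proxy and virtual key; adjust to your deployment.
resp = requests.post(
    "http://localhost:4000/v1/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]},
)

# Only populated when the key carries model_rpm_limit / model_tpm_limit metadata.
print(resp.headers.get("x-litellm-key-remaining-requests-gpt-4o"))
print(resp.headers.get("x-litellm-key-remaining-tokens-gpt-4o"))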
@@ -103,13 +103,30 @@ class PrometheusLogger(CustomLogger):
             "Remaining budget for api key",
             labelnames=["hashed_api_key", "api_key_alias"],
         )
+        # Litellm-Enterprise Metrics
+        if premium_user is True:
 
+            ########################################
+            # LiteLLM Virtual API KEY metrics
+            ########################################
+            # Remaining MODEL RPM limit for API Key
+            self.litellm_remaining_api_key_requests_for_model = Gauge(
+                "litellm_remaining_api_key_requests_for_model",
+                "Remaining Requests API Key can make for model (model based rpm limit on key)",
+                labelnames=["hashed_api_key", "api_key_alias", "model"],
+            )
+
+            # Remaining MODEL TPM limit for API Key
+            self.litellm_remaining_api_key_tokens_for_model = Gauge(
+                "litellm_remaining_api_key_tokens_for_model",
+                "Remaining Tokens API Key can make for model (model based tpm limit on key)",
+                labelnames=["hashed_api_key", "api_key_alias", "model"],
+            )
+
         ########################################
         # LLM API Deployment Metrics / analytics
         ########################################
 
-        # Litellm-Enterprise Metrics
-        if premium_user is True:
             # Remaining Rate Limit for model
             self.litellm_remaining_requests_metric = Gauge(
                 "litellm_remaining_requests",

@@ -187,6 +204,9 @@ class PrometheusLogger(CustomLogger):
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         # Define prometheus client
+        from litellm.proxy.common_utils.callback_utils import (
+            get_model_group_from_litellm_kwargs,
+        )
         from litellm.proxy.proxy_server import premium_user
 
         verbose_logger.debug(

@@ -197,6 +217,7 @@ class PrometheusLogger(CustomLogger):
         model = kwargs.get("model", "")
         response_cost = kwargs.get("response_cost", 0.0) or 0
         litellm_params = kwargs.get("litellm_params", {}) or {}
+        _metadata = litellm_params.get("metadata", {})
         proxy_server_request = litellm_params.get("proxy_server_request") or {}
         end_user_id = proxy_server_request.get("body", {}).get("user", None)
         user_id = litellm_params.get("metadata", {}).get("user_api_key_user_id", None)

@@ -286,6 +307,27 @@ class PrometheusLogger(CustomLogger):
                 user_api_key, user_api_key_alias
             ).set(_remaining_api_key_budget)
 
+            # Set remaining rpm/tpm for API Key + model
+            # see parallel_request_limiter.py - variables are set there
+            model_group = get_model_group_from_litellm_kwargs(kwargs)
+            remaining_requests_variable_name = (
+                f"litellm-key-remaining-requests-{model_group}"
+            )
+            remaining_tokens_variable_name = f"litellm-key-remaining-tokens-{model_group}"
+
+            remaining_requests = _metadata.get(
+                remaining_requests_variable_name, sys.maxsize
+            )
+            remaining_tokens = _metadata.get(remaining_tokens_variable_name, sys.maxsize)
+
+            self.litellm_remaining_api_key_requests_for_model.labels(
+                user_api_key, user_api_key_alias, model_group
+            ).set(remaining_requests)
+
+            self.litellm_remaining_api_key_tokens_for_model.labels(
+                user_api_key, user_api_key_alias, model_group
+            ).set(remaining_tokens)
+
         # set x-ratelimit headers
         if premium_user is True:
             self.set_llm_deployment_success_metrics(
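The two new gauges are labelled by hashed key, key alias, and model, and are fed from metadata written by the parallel request limiter (see the hunks below). A standalone sketch of the same pattern, assuming prometheus_client is installed; the key hash, alias, and values are made-up examples:

import sys

from prometheus_client import Gauge, generate_latest

# Same name/labels as the gauge added above, recreated here for illustration only.
remaining_requests_for_model = Gauge(
    "litellm_remaining_api_key_requests_for_model",
    "Remaining Requests API Key can make for model (model based rpm limit on key)",
    labelnames=["hashed_api_key", "api_key_alias", "model"],
)

# Metadata shaped the way the limiter writes it (sample values).
_metadata = {"litellm-key-remaining-requests-gpt-4o": 4}
remaining = _metadata.get("litellm-key-remaining-requests-gpt-4o", sys.maxsize)

remaining_requests_for_model.labels("hashed-key-abc123", "my-key-alias", "gpt-4o").set(remaining)
print(generate_latest().decode())  # Prometheus text exposition of the gauge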
@@ -1337,6 +1337,8 @@ class UserAPIKeyAuth(
     ] = None
     allowed_model_region: Optional[Literal["eu"]] = None
     parent_otel_span: Optional[Span] = None
+    rpm_limit_per_model: Optional[Dict[str, int]] = None
+    tpm_limit_per_model: Optional[Dict[str, int]] = None
 
     @model_validator(mode="before")
     @classmethod
@@ -210,3 +210,20 @@ def bytes_to_mb(bytes_value: int):
     Helper to convert bytes to MB
     """
     return bytes_value / (1024 * 1024)
+
+
+# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
+def get_key_model_rpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]:
+    if user_api_key_dict.metadata:
+        if "model_rpm_limit" in user_api_key_dict.metadata:
+            return user_api_key_dict.metadata["model_rpm_limit"]
+
+    return None
+
+
+def get_key_model_tpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]:
+    if user_api_key_dict.metadata:
+        if "model_tpm_limit" in user_api_key_dict.metadata:
+            return user_api_key_dict.metadata["model_tpm_limit"]
+
+    return None
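These helpers simply read `model_rpm_limit` / `model_tpm_limit` dicts out of the key's metadata. A small sketch of the expected metadata shape, using a stand-in object rather than the real UserAPIKeyAuth:

from typing import Optional


class KeyStub:
    """Stand-in with only the metadata attribute the helpers look at."""

    def __init__(self, metadata: Optional[dict] = None):
        self.metadata = metadata or {}


def get_key_model_rpm_limit(user_api_key_dict) -> Optional[dict]:
    # mirrors the helper added above
    if user_api_key_dict.metadata:
        if "model_rpm_limit" in user_api_key_dict.metadata:
            return user_api_key_dict.metadata["model_rpm_limit"]
    return None


key = KeyStub(metadata={"model_rpm_limit": {"gpt-4o": 10}, "model_tpm_limit": {"gpt-4o": 1000}})
print(get_key_model_rpm_limit(key))       # {'gpt-4o': 10}
print(get_key_model_rpm_limit(KeyStub())) # None -> no per-model limit configured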
@@ -1,4 +1,5 @@
-from typing import Any, List, Optional, get_args
+import sys
+from typing import Any, Dict, List, Optional, get_args
 
 import litellm
 from litellm._logging import verbose_proxy_logger

@@ -249,3 +250,46 @@ def initialize_callbacks_on_proxy(
     verbose_proxy_logger.debug(
         f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
     )
+
+
+def get_model_group_from_litellm_kwargs(kwargs: dict) -> Optional[str]:
+    _litellm_params = kwargs.get("litellm_params", None) or {}
+    _metadata = _litellm_params.get("metadata", None) or {}
+    _model_group = _metadata.get("model_group", None)
+    if _model_group is not None:
+        return _model_group
+
+    return None
+
+
+def get_model_group_from_request_data(data: dict) -> Optional[str]:
+    _metadata = data.get("metadata", None) or {}
+    _model_group = _metadata.get("model_group", None)
+    if _model_group is not None:
+        return _model_group
+
+    return None
+
+
+def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str, str]:
+    """
+    Helper function to return x-litellm-key-remaining-tokens-{model_group} and x-litellm-key-remaining-requests-{model_group}
+
+    Returns {} when api_key + model rpm/tpm limit is not set
+
+    """
+    _metadata = data.get("metadata", None) or {}
+    model_group = get_model_group_from_request_data(data)
+
+    # Remaining Requests
+    remaining_requests_variable_name = f"litellm-key-remaining-requests-{model_group}"
+    remaining_requests = _metadata.get(remaining_requests_variable_name, None)
+
+    # Remaining Tokens
+    remaining_tokens_variable_name = f"litellm-key-remaining-tokens-{model_group}"
+    remaining_tokens = _metadata.get(remaining_tokens_variable_name, None)
+
+    return {
+        f"x-litellm-key-remaining-requests-{model_group}": str(remaining_requests),
+        f"x-litellm-key-remaining-tokens-{model_group}": str(remaining_tokens),
+    }
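The last helper turns limiter-written request metadata into response headers. A quick self-contained illustration of that mapping with made-up values (not an import of the proxy module):

def remaining_headers(data: dict) -> dict:
    # mirrors get_remaining_tokens_and_requests_from_request_data for illustration
    _metadata = data.get("metadata") or {}
    model_group = _metadata.get("model_group")
    return {
        f"x-litellm-key-remaining-requests-{model_group}": str(
            _metadata.get(f"litellm-key-remaining-requests-{model_group}")
        ),
        f"x-litellm-key-remaining-tokens-{model_group}": str(
            _metadata.get(f"litellm-key-remaining-tokens-{model_group}")
        ),
    }


sample = {
    "metadata": {
        "model_group": "azure-model",
        "litellm-key-remaining-requests-azure-model": 4,
        "litellm-key-remaining-tokens-azure-model": 980,
    }
}
print(remaining_headers(sample))
# {'x-litellm-key-remaining-requests-azure-model': '4',
#  'x-litellm-key-remaining-tokens-azure-model': '980'}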
@@ -5,7 +5,7 @@ from pydantic import BaseModel, RootModel
 
 import litellm
 from litellm._logging import verbose_proxy_logger
-from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
+from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
 from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
 
 all_guardrails: List[GuardrailItem] = []
@@ -11,6 +11,10 @@ from litellm._logging import verbose_proxy_logger
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.auth.auth_utils import (
+    get_key_model_rpm_limit,
+    get_key_model_tpm_limit,
+)
 
 
 class _PROXY_MaxParallelRequestsHandler(CustomLogger):
@@ -202,6 +206,82 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     additional_details=f"Hit limit for api_key: {api_key}. tpm_limit: {tpm_limit}, current_tpm {current['current_tpm']} , rpm_limit: {rpm_limit} current rpm {current['current_rpm']} "
                 )
 
+        # Check if request under RPM/TPM per model for a given API Key
+        if (
+            get_key_model_tpm_limit(user_api_key_dict) is not None
+            or get_key_model_rpm_limit(user_api_key_dict) is not None
+        ):
+            _model = data.get("model", None)
+            request_count_api_key = (
+                f"{api_key}::{_model}::{precise_minute}::request_count"
+            )
+
+            current = await self.internal_usage_cache.async_get_cache(
+                key=request_count_api_key
+            ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
+
+            tpm_limit_for_model = None
+            rpm_limit_for_model = None
+
+            _tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
+            _rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)
+
+            if _model is not None:
+
+                if _tpm_limit_for_key_model:
+                    tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
+
+                if _rpm_limit_for_key_model:
+                    rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
+            if current is None:
+                new_val = {
+                    "current_requests": 1,
+                    "current_tpm": 0,
+                    "current_rpm": 0,
+                }
+                await self.internal_usage_cache.async_set_cache(
+                    request_count_api_key, new_val
+                )
+            elif tpm_limit_for_model is not None or rpm_limit_for_model is not None:
+                # Increase count for this token
+                new_val = {
+                    "current_requests": current["current_requests"] + 1,
+                    "current_tpm": current["current_tpm"],
+                    "current_rpm": current["current_rpm"],
+                }
+                if (
+                    tpm_limit_for_model is not None
+                    and current["current_tpm"] >= tpm_limit_for_model
+                ):
+                    return self.raise_rate_limit_error(
+                        additional_details=f"Hit TPM limit for model: {_model} on api_key: {api_key}. tpm_limit: {tpm_limit_for_model}, current_tpm {current['current_tpm']} "
+                    )
+                elif (
+                    rpm_limit_for_model is not None
+                    and current["current_rpm"] >= rpm_limit_for_model
+                ):
+                    return self.raise_rate_limit_error(
+                        additional_details=f"Hit RPM limit for model: {_model} on api_key: {api_key}. rpm_limit: {rpm_limit_for_model}, current_rpm {current['current_rpm']} "
+                    )
+                else:
+                    await self.internal_usage_cache.async_set_cache(
+                        request_count_api_key, new_val
+                    )
+
+            _remaining_tokens = None
+            _remaining_requests = None
+            # Add remaining tokens, requests to metadata
+            if tpm_limit_for_model is not None:
+                _remaining_tokens = tpm_limit_for_model - new_val["current_tpm"]
+            if rpm_limit_for_model is not None:
+                _remaining_requests = rpm_limit_for_model - new_val["current_rpm"]
+
+            _remaining_limits_data = {
+                f"litellm-key-remaining-tokens-{_model}": _remaining_tokens,
+                f"litellm-key-remaining-requests-{_model}": _remaining_requests,
+            }
+            data["metadata"].update(_remaining_limits_data)
 
         # check if REQUEST ALLOWED for user_id
         user_id = user_api_key_dict.user_id
         if user_id is not None:
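Per-key-per-model usage is tracked in the internal usage cache under a minute-scoped composite key. A small illustration of the key format used above (sample hashed key, not the production cache):

from datetime import datetime

api_key = "hashed-key-abc123"  # hashed virtual key (sample value)
model = "azure-model"

now = datetime.now()
precise_minute = f"{now.strftime('%Y-%m-%d')}-{now.strftime('%H')}-{now.strftime('%M')}"

request_count_api_key = f"{api_key}::{model}::{precise_minute}::request_count"
print(request_count_api_key)
# e.g. hashed-key-abc123::azure-model::2024-08-17-21-05::request_count
# cached value shape: {"current_requests": 1, "current_tpm": 0, "current_rpm": 0}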
@@ -299,6 +379,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         return
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        from litellm.proxy.common_utils.callback_utils import (
+            get_model_group_from_litellm_kwargs,
+        )
+
         try:
             self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
             global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
@@ -365,6 +449,36 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 request_count_api_key, new_val, ttl=60
             ) # store in cache for 1 min.
 
+            # ------------
+            # Update usage - model group + API Key
+            # ------------
+            model_group = get_model_group_from_litellm_kwargs(kwargs)
+            if user_api_key is not None and model_group is not None:
+                request_count_api_key = (
+                    f"{user_api_key}::{model_group}::{precise_minute}::request_count"
+                )
+
+                current = await self.internal_usage_cache.async_get_cache(
+                    key=request_count_api_key
+                ) or {
+                    "current_requests": 1,
+                    "current_tpm": total_tokens,
+                    "current_rpm": 1,
+                }
+
+                new_val = {
+                    "current_requests": max(current["current_requests"] - 1, 0),
+                    "current_tpm": current["current_tpm"] + total_tokens,
+                    "current_rpm": current["current_rpm"] + 1,
+                }
+
+                self.print_verbose(
+                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
+                )
+                await self.internal_usage_cache.async_set_cache(
+                    request_count_api_key, new_val, ttl=60
+                )
+
             # ------------
             # Update usage - User
             # ------------
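The success hook decrements the in-flight request count and adds the finished request's tokens plus one rpm unit to the same per-key-per-model entry. A tiny sketch of that update with sample numbers (matching the expected values in the tests below):

# Cached entry before the success callback runs (sample values)
current = {"current_requests": 1, "current_tpm": 0, "current_rpm": 0}
total_tokens = 11  # usage reported on the completed response

new_val = {
    "current_requests": max(current["current_requests"] - 1, 0),  # request finished
    "current_tpm": current["current_tpm"] + total_tokens,
    "current_rpm": current["current_rpm"] + 1,
}
print(new_val)  # {'current_requests': 0, 'current_tpm': 11, 'current_rpm': 1}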
@@ -42,7 +42,5 @@ general_settings:
 
 litellm_settings:
   fallbacks: [{"gemini-1.5-pro-001": ["gpt-4o"]}]
-  callbacks: ["gcs_bucket"]
-  success_callback: ["langfuse"]
+  success_callback: ["langfuse", "prometheus"]
   langfuse_default_tags: ["cache_hit", "cache_key", "user_api_key_alias", "user_api_key_team_alias"]
-  cache: True
@@ -148,6 +148,10 @@ from litellm.proxy.common_utils.admin_ui_utils import (
     html_form,
     show_missing_vars_in_env,
 )
+from litellm.proxy.common_utils.callback_utils import (
+    get_remaining_tokens_and_requests_from_request_data,
+    initialize_callbacks_on_proxy,
+)
 from litellm.proxy.common_utils.debug_utils import init_verbose_loggers
 from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
 from litellm.proxy.common_utils.encrypt_decrypt_utils import (
@@ -158,7 +162,6 @@ from litellm.proxy.common_utils.http_parsing_utils import (
     _read_request_body,
     check_file_size_under_limit,
 )
-from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
 from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3
 from litellm.proxy.common_utils.openai_endpoint_utils import (
     remove_sensitive_info_from_deployment,
@@ -503,6 +506,7 @@ def get_custom_headers(
     model_region: Optional[str] = None,
     response_cost: Optional[Union[float, str]] = None,
     fastest_response_batch_completion: Optional[bool] = None,
+    request_data: Optional[dict] = {},
     **kwargs,
 ) -> dict:
     exclude_values = {"", None}
@@ -523,6 +527,12 @@ def get_custom_headers(
         ),
         **{k: str(v) for k, v in kwargs.items()},
     }
+    if request_data:
+        remaining_tokens_header = get_remaining_tokens_and_requests_from_request_data(
+            request_data
+        )
+        headers.update(remaining_tokens_header)
+
     try:
         return {
             key: value for key, value in headers.items() if value not in exclude_values
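With request_data passed in, the remaining-limit headers are merged into the normal custom-header dict and then filtered against exclude_values. A reduced sketch of that merge-and-filter step; the pre-existing header names and values below are hypothetical:

exclude_values = {"", None}

headers = {
    "x-litellm-call-id": "abc-123",        # hypothetical existing header
    "x-litellm-response-cost": "0.00042",  # hypothetical existing header
}
remaining_tokens_header = {
    "x-litellm-key-remaining-requests-gpt-4o": "4",
    "x-litellm-key-remaining-tokens-gpt-4o": "980",
}
headers.update(remaining_tokens_header)

final_headers = {k: v for k, v in headers.items() if v not in exclude_values}
print(final_headers)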
@@ -3107,6 +3117,7 @@ async def chat_completion(
                 response_cost=response_cost,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 fastest_response_batch_completion=fastest_response_batch_completion,
+                request_data=data,
                 **additional_headers,
             )
             selected_data_generator = select_data_generator(

@@ -3141,6 +3152,7 @@ async def chat_completion(
                 response_cost=response_cost,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 fastest_response_batch_completion=fastest_response_batch_completion,
+                request_data=data,
                 **additional_headers,
             )
         )

@@ -3322,6 +3334,7 @@ async def completion(
                 api_base=api_base,
                 version=version,
                 response_cost=response_cost,
+                request_data=data,
             )
             selected_data_generator = select_data_generator(
                 response=response,

@@ -3343,6 +3356,7 @@ async def completion(
                 api_base=api_base,
                 version=version,
                 response_cost=response_cost,
+                request_data=data,
             )
         )
         await check_response_size_is_safe(response=response)

@@ -3550,6 +3564,7 @@ async def embeddings(
                 response_cost=response_cost,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 call_id=litellm_call_id,
+                request_data=data,
             )
         )
         await check_response_size_is_safe(response=response)

@@ -3676,6 +3691,7 @@ async def image_generation(
                 response_cost=response_cost,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 call_id=litellm_call_id,
+                request_data=data,
             )
         )
 
@@ -3797,6 +3813,7 @@ async def audio_speech(
             model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
             fastest_response_batch_completion=None,
             call_id=litellm_call_id,
+            request_data=data,
         )
 
         selected_data_generator = select_data_generator(

@@ -3934,6 +3951,7 @@ async def audio_transcriptions(
                 response_cost=response_cost,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 call_id=litellm_call_id,
+                request_data=data,
             )
         )
 
@@ -4037,6 +4055,7 @@ async def get_assistants(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
 
@@ -4132,6 +4151,7 @@ async def create_assistant(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
 
@@ -4227,6 +4247,7 @@ async def delete_assistant(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
 
@@ -4322,6 +4343,7 @@ async def create_threads(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
 
@@ -4416,6 +4438,7 @@ async def get_thread(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
 
@@ -4513,6 +4536,7 @@ async def add_messages(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
 
@@ -4606,6 +4630,7 @@ async def get_messages(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
 
@@ -4713,6 +4738,7 @@ async def run_thread(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
 
@@ -4835,6 +4861,7 @@ async def create_batch(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
 
@@ -4930,6 +4957,7 @@ async def retrieve_batch(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
 
@@ -5148,6 +5176,7 @@ async def moderations(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                request_data=data,
             )
         )
 
@@ -5317,6 +5346,7 @@ async def anthropic_response(
                 api_base=api_base,
                 version=version,
                 response_cost=response_cost,
+                request_data=data,
             )
         )
 
@@ -908,3 +908,263 @@ async def test_bad_router_tpm_limit():
         )["current_tpm"]
         == 0
     )
+
+
+@pytest.mark.asyncio
+async def test_bad_router_tpm_limit_per_model():
+    model_list = [
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-turbo",
+                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+                "api_base": "https://openai-france-1234.openai.azure.com",
+                "rpm": 1440,
+            },
+            "model_info": {"id": 1},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 2},
+        },
+    ]
+    router = Router(
+        model_list=model_list,
+        set_verbose=False,
+        num_retries=3,
+    ) # type: ignore
+
+    _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
+    model = "azure-model"
+
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key,
+        max_parallel_requests=10,
+        tpm_limit=10,
+        tpm_limit_per_model={model: 5},
+        rpm_limit_per_model={model: 5},
+    )
+    local_cache = DualCache()
+    pl = ProxyLogging(user_api_key_cache=local_cache)
+    pl._init_litellm_callbacks()
+    print(f"litellm callbacks: {litellm.callbacks}")
+    parallel_request_handler = pl.max_parallel_request_limiter
+
+    await parallel_request_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict,
+        cache=local_cache,
+        data={"model": model},
+        call_type="",
+    )
+
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    current_hour = datetime.now().strftime("%H")
+    current_minute = datetime.now().strftime("%M")
+    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+    request_count_api_key = f"{_api_key}::{model}::{precise_minute}::request_count"
+
+    print(
+        "internal usage cache: ",
+        parallel_request_handler.internal_usage_cache.in_memory_cache.cache_dict,
+    )
+
+    assert (
+        parallel_request_handler.internal_usage_cache.get_cache(
+            key=request_count_api_key
+        )["current_requests"]
+        == 1
+    )
+
+    # bad call
+    try:
+        response = await router.acompletion(
+            model=model,
+            messages=[{"role": "user2", "content": "Write me a paragraph on the moon"}],
+            stream=True,
+            metadata={"user_api_key": _api_key},
+        )
+    except:
+        pass
+    await asyncio.sleep(1) # success is done in a separate thread
+
+    assert (
+        parallel_request_handler.internal_usage_cache.get_cache(
+            key=request_count_api_key
+        )["current_tpm"]
+        == 0
+    )
+
+
+@pytest.mark.asyncio
+async def test_pre_call_hook_rpm_limits_per_model():
+    """
+    Test if error raised on hitting rpm limits for a given model
+    """
+    import logging
+
+    from litellm._logging import (
+        verbose_logger,
+        verbose_proxy_logger,
+        verbose_router_logger,
+    )
+
+    verbose_logger.setLevel(logging.DEBUG)
+    verbose_proxy_logger.setLevel(logging.DEBUG)
+    verbose_router_logger.setLevel(logging.DEBUG)
+
+    _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key,
+        max_parallel_requests=100,
+        tpm_limit=900000,
+        rpm_limit=100000,
+        rpm_limit_per_model={"azure-model": 1},
+    )
+    local_cache = DualCache()
+    pl = ProxyLogging(user_api_key_cache=local_cache)
+    pl._init_litellm_callbacks()
+    print(f"litellm callbacks: {litellm.callbacks}")
+    parallel_request_handler = pl.max_parallel_request_limiter
+
+    await parallel_request_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+    )
+
+    model = "azure-model"
+
+    kwargs = {
+        "model": model,
+        "litellm_params": {"metadata": {"user_api_key": _api_key}},
+    }
+
+    await parallel_request_handler.async_log_success_event(
+        kwargs=kwargs,
+        response_obj="",
+        start_time="",
+        end_time="",
+    )
+
+    ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
+
+    try:
+        await parallel_request_handler.async_pre_call_hook(
+            user_api_key_dict=user_api_key_dict,
+            cache=local_cache,
+            data={"model": model},
+            call_type="",
+        )
+
+        pytest.fail(f"Expected call to fail")
+    except Exception as e:
+        assert e.status_code == 429
+        print("got error=", e)
+        assert (
+            "limit reached Hit RPM limit for model: azure-model on api_key: c11e7177eb60c80cf983ddf8ca98f2dc1272d4c612204ce9bedd2460b18939cc"
+            in str(e)
+        )
+
+
+@pytest.mark.asyncio
+async def test_pre_call_hook_tpm_limits_per_model():
+    """
+    Test if error raised on hitting tpm limits for a given model
+    """
+    import logging
+
+    from litellm._logging import (
+        verbose_logger,
+        verbose_proxy_logger,
+        verbose_router_logger,
+    )
+
+    verbose_logger.setLevel(logging.DEBUG)
+    verbose_proxy_logger.setLevel(logging.DEBUG)
+    verbose_router_logger.setLevel(logging.DEBUG)
+
+    _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key,
+        max_parallel_requests=100,
+        tpm_limit=900000,
+        rpm_limit=100000,
+        rpm_limit_per_model={"azure-model": 100},
+        tpm_limit_per_model={"azure-model": 10},
+    )
+    local_cache = DualCache()
+    pl = ProxyLogging(user_api_key_cache=local_cache)
+    pl._init_litellm_callbacks()
+    print(f"litellm callbacks: {litellm.callbacks}")
+    parallel_request_handler = pl.max_parallel_request_limiter
+    model = "azure-model"
+
+    await parallel_request_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict,
+        cache=local_cache,
+        data={"model": model},
+        call_type="",
+    )
+
+    kwargs = {
+        "model": model,
+        "litellm_params": {"metadata": {"user_api_key": _api_key}},
+    }
+
+    await parallel_request_handler.async_log_success_event(
+        kwargs=kwargs,
+        response_obj=litellm.ModelResponse(usage=litellm.Usage(total_tokens=11)),
+        start_time="",
+        end_time="",
+    )
+
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    current_hour = datetime.now().strftime("%H")
+    current_minute = datetime.now().strftime("%M")
+    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+    request_count_api_key = f"{_api_key}::{model}::{precise_minute}::request_count"
+
+    print(
+        "internal usage cache: ",
+        parallel_request_handler.internal_usage_cache.in_memory_cache.cache_dict,
+    )
+
+    assert (
+        parallel_request_handler.internal_usage_cache.get_cache(
+            key=request_count_api_key
+        )["current_tpm"]
+        == 11
+    )
+
+    assert (
+        parallel_request_handler.internal_usage_cache.get_cache(
+            key=request_count_api_key
+        )["current_rpm"]
+        == 1
+    )
+
+    ## Expected cache val: {"current_requests": 0, "current_tpm": 11, "current_rpm": "1"}
+
+    try:
+        await parallel_request_handler.async_pre_call_hook(
+            user_api_key_dict=user_api_key_dict,
+            cache=local_cache,
+            data={"model": model},
+            call_type="",
+        )
+
+        pytest.fail(f"Expected call to fail")
+    except Exception as e:
+        assert e.status_code == 429
+        print("got error=", e)
+        assert (
+            "request limit reached Hit TPM limit for model: azure-model on api_key"
+            in str(e)
+        )