Merge pull request #5259 from BerriAI/litellm_return_remaining_tokens_in_header

[Feat] return `x-litellm-key-remaining-requests-{model}`: 1, `x-litellm-key-remaining-tokens-{model}`: None in response headers
Ishaan Jaff 2024-08-17 12:41:16 -07:00 committed by GitHub
commit feb8c3c5b4
9 changed files with 518 additions and 11 deletions
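As a quick illustration of the feature (not part of the diff): once an API key has `model_rpm_limit` / `model_tpm_limit` set in its metadata, proxy responses carry the new headers, which a client can read roughly like this. The base URL, API key, and model name below are placeholders.

# Minimal sketch: reading the per-model rate-limit headers from a proxy response.
# URL, key, and model name are placeholders, not values from this PR.
import requests

resp = requests.post(
    "http://localhost:4000/v1/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={"model": "azure-model", "messages": [{"role": "user", "content": "hi"}]},
)
print(resp.headers.get("x-litellm-key-remaining-requests-azure-model"))
print(resp.headers.get("x-litellm-key-remaining-tokens-azure-model"))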


@@ -103,13 +103,30 @@ class PrometheusLogger(CustomLogger):
"Remaining budget for api key",
labelnames=["hashed_api_key", "api_key_alias"],
)
# Litellm-Enterprise Metrics
if premium_user is True:
########################################
# LiteLLM Virtual API KEY metrics
########################################
# Remaining MODEL RPM limit for API Key
self.litellm_remaining_api_key_requests_for_model = Gauge(
"litellm_remaining_api_key_requests_for_model",
"Remaining Requests API Key can make for model (model based rpm limit on key)",
labelnames=["hashed_api_key", "api_key_alias", "model"],
)
# Remaining MODEL TPM limit for API Key
self.litellm_remaining_api_key_tokens_for_model = Gauge(
"litellm_remaining_api_key_tokens_for_model",
"Remaining Tokens API Key can make for model (model based tpm limit on key)",
labelnames=["hashed_api_key", "api_key_alias", "model"],
)
########################################
# LLM API Deployment Metrics / analytics
########################################
# Litellm-Enterprise Metrics
if premium_user is True:
# Remaining Rate Limit for model
self.litellm_remaining_requests_metric = Gauge(
"litellm_remaining_requests",
@@ -187,6 +204,9 @@ class PrometheusLogger(CustomLogger):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
# Define prometheus client
from litellm.proxy.common_utils.callback_utils import (
get_model_group_from_litellm_kwargs,
)
from litellm.proxy.proxy_server import premium_user
verbose_logger.debug(
@@ -197,6 +217,7 @@ class PrometheusLogger(CustomLogger):
model = kwargs.get("model", "")
response_cost = kwargs.get("response_cost", 0.0) or 0
litellm_params = kwargs.get("litellm_params", {}) or {}
_metadata = litellm_params.get("metadata", {})
proxy_server_request = litellm_params.get("proxy_server_request") or {}
end_user_id = proxy_server_request.get("body", {}).get("user", None)
user_id = litellm_params.get("metadata", {}).get("user_api_key_user_id", None)
@@ -286,6 +307,27 @@ class PrometheusLogger(CustomLogger):
user_api_key, user_api_key_alias
).set(_remaining_api_key_budget)
# Set remaining rpm/tpm for API Key + model
# see parallel_request_limiter.py - variables are set there
model_group = get_model_group_from_litellm_kwargs(kwargs)
remaining_requests_variable_name = (
f"litellm-key-remaining-requests-{model_group}"
)
remaining_tokens_variable_name = f"litellm-key-remaining-tokens-{model_group}"
remaining_requests = _metadata.get(
remaining_requests_variable_name, sys.maxsize
)
remaining_tokens = _metadata.get(remaining_tokens_variable_name, sys.maxsize)
self.litellm_remaining_api_key_requests_for_model.labels(
user_api_key, user_api_key_alias, model_group
).set(remaining_requests)
self.litellm_remaining_api_key_tokens_for_model.labels(
user_api_key, user_api_key_alias, model_group
).set(remaining_tokens)
# set x-ratelimit headers
if premium_user is True:
self.set_llm_deployment_success_metrics(


@@ -1337,6 +1337,8 @@ class UserAPIKeyAuth(
] = None
allowed_model_region: Optional[Literal["eu"]] = None
parent_otel_span: Optional[Span] = None
rpm_limit_per_model: Optional[Dict[str, int]] = None
tpm_limit_per_model: Optional[Dict[str, int]] = None
@model_validator(mode="before")
@classmethod


@@ -210,3 +210,20 @@ def bytes_to_mb(bytes_value: int):
Helper to convert bytes to MB
"""
return bytes_value / (1024 * 1024)
# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
def get_key_model_rpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]:
if user_api_key_dict.metadata:
if "model_rpm_limit" in user_api_key_dict.metadata:
return user_api_key_dict.metadata["model_rpm_limit"]
return None
def get_key_model_tpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]:
if user_api_key_dict.metadata:
if "model_tpm_limit" in user_api_key_dict.metadata:
return user_api_key_dict.metadata["model_tpm_limit"]
return None
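A usage sketch for the two helpers above (values invented; assumes the `UserAPIKeyAuth` object carries the key's metadata, as the parallel request limiter below relies on):

# Hypothetical key whose metadata carries per-model limits
key = UserAPIKeyAuth(
    api_key="hashed-key",
    metadata={
        "model_rpm_limit": {"azure-model": 5},
        "model_tpm_limit": {"azure-model": 1000},
    },
)
assert get_key_model_rpm_limit(key) == {"azure-model": 5}
assert get_key_model_tpm_limit(key) == {"azure-model": 1000}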


@@ -1,4 +1,5 @@
-from typing import Any, List, Optional, get_args
+import sys
+from typing import Any, Dict, List, Optional, get_args
import litellm
from litellm._logging import verbose_proxy_logger
@@ -249,3 +250,46 @@ def initialize_callbacks_on_proxy(
verbose_proxy_logger.debug(
f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
)
def get_model_group_from_litellm_kwargs(kwargs: dict) -> Optional[str]:
_litellm_params = kwargs.get("litellm_params", None) or {}
_metadata = _litellm_params.get("metadata", None) or {}
_model_group = _metadata.get("model_group", None)
if _model_group is not None:
return _model_group
return None
def get_model_group_from_request_data(data: dict) -> Optional[str]:
_metadata = data.get("metadata", None) or {}
_model_group = _metadata.get("model_group", None)
if _model_group is not None:
return _model_group
return None
def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str, str]:
"""
Helper function to return x-litellm-key-remaining-tokens-{model_group} and x-litellm-key-remaining-requests-{model_group}
Returns {} when api_key + model rpm/tpm limit is not set
"""
_metadata = data.get("metadata", None) or {}
model_group = get_model_group_from_request_data(data)
# Remaining Requests
remaining_requests_variable_name = f"litellm-key-remaining-requests-{model_group}"
remaining_requests = _metadata.get(remaining_requests_variable_name, None)
# Remaining Tokens
remaining_tokens_variable_name = f"litellm-key-remaining-tokens-{model_group}"
remaining_tokens = _metadata.get(remaining_tokens_variable_name, None)
return {
f"x-litellm-key-remaining-requests-{model_group}": str(remaining_requests),
f"x-litellm-key-remaining-tokens-{model_group}": str(remaining_tokens),
}
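Rough input/output sketch for this helper (metadata values invented; the `litellm-key-remaining-*` keys are the ones written into request metadata by the parallel request limiter later in this PR):

# Invented request data showing the metadata shape the helper expects
data = {
    "metadata": {
        "model_group": "azure-model",
        "litellm-key-remaining-requests-azure-model": 4,
        "litellm-key-remaining-tokens-azure-model": 980,
    }
}
print(get_remaining_tokens_and_requests_from_request_data(data))
# {'x-litellm-key-remaining-requests-azure-model': '4',
#  'x-litellm-key-remaining-tokens-azure-model': '980'}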


@@ -5,7 +5,7 @@ from pydantic import BaseModel, RootModel
import litellm
from litellm._logging import verbose_proxy_logger
-from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
+from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
all_guardrails: List[GuardrailItem] = []


@@ -11,6 +11,10 @@ from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_utils import (
get_key_model_rpm_limit,
get_key_model_tpm_limit,
)
class _PROXY_MaxParallelRequestsHandler(CustomLogger):
@@ -202,6 +206,82 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
additional_details=f"Hit limit for api_key: {api_key}. tpm_limit: {tpm_limit}, current_tpm {current['current_tpm']} , rpm_limit: {rpm_limit} current rpm {current['current_rpm']} "
)
# Check if request under RPM/TPM per model for a given API Key
if (
get_key_model_tpm_limit(user_api_key_dict) is not None
or get_key_model_rpm_limit(user_api_key_dict) is not None
):
_model = data.get("model", None)
request_count_api_key = (
f"{api_key}::{_model}::{precise_minute}::request_count"
)
current = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key
) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
tpm_limit_for_model = None
rpm_limit_for_model = None
_tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
_rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)
if _model is not None:
if _tpm_limit_for_key_model:
tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
if _rpm_limit_for_key_model:
rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
if current is None:
new_val = {
"current_requests": 1,
"current_tpm": 0,
"current_rpm": 0,
}
await self.internal_usage_cache.async_set_cache(
request_count_api_key, new_val
)
elif tpm_limit_for_model is not None or rpm_limit_for_model is not None:
# Increase count for this token
new_val = {
"current_requests": current["current_requests"] + 1,
"current_tpm": current["current_tpm"],
"current_rpm": current["current_rpm"],
}
if (
tpm_limit_for_model is not None
and current["current_tpm"] >= tpm_limit_for_model
):
return self.raise_rate_limit_error(
additional_details=f"Hit TPM limit for model: {_model} on api_key: {api_key}. tpm_limit: {tpm_limit_for_model}, current_tpm {current['current_tpm']} "
)
elif (
rpm_limit_for_model is not None
and current["current_rpm"] >= rpm_limit_for_model
):
return self.raise_rate_limit_error(
additional_details=f"Hit RPM limit for model: {_model} on api_key: {api_key}. rpm_limit: {rpm_limit_for_model}, current_rpm {current['current_rpm']} "
)
else:
await self.internal_usage_cache.async_set_cache(
request_count_api_key, new_val
)
_remaining_tokens = None
_remaining_requests = None
# Add remaining tokens, requests to metadata
if tpm_limit_for_model is not None:
_remaining_tokens = tpm_limit_for_model - new_val["current_tpm"]
if rpm_limit_for_model is not None:
_remaining_requests = rpm_limit_for_model - new_val["current_rpm"]
_remaining_limits_data = {
f"litellm-key-remaining-tokens-{_model}": _remaining_tokens,
f"litellm-key-remaining-requests-{_model}": _remaining_requests,
}
data["metadata"].update(_remaining_limits_data)
# check if REQUEST ALLOWED for user_id
user_id = user_api_key_dict.user_id
if user_id is not None:
@@ -299,6 +379,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
return
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
from litellm.proxy.common_utils.callback_utils import (
get_model_group_from_litellm_kwargs,
)
try:
self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
@@ -365,6 +449,36 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
request_count_api_key, new_val, ttl=60
) # store in cache for 1 min.
# ------------
# Update usage - model group + API Key
# ------------
model_group = get_model_group_from_litellm_kwargs(kwargs)
if user_api_key is not None and model_group is not None:
request_count_api_key = (
f"{user_api_key}::{model_group}::{precise_minute}::request_count"
)
current = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key
) or {
"current_requests": 1,
"current_tpm": total_tokens,
"current_rpm": 1,
}
new_val = {
"current_requests": max(current["current_requests"] - 1, 0),
"current_tpm": current["current_tpm"] + total_tokens,
"current_rpm": current["current_rpm"] + 1,
}
self.print_verbose(
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
)
await self.internal_usage_cache.async_set_cache(
request_count_api_key, new_val, ttl=60
)
# ------------
# Update usage - User
# ------------


@@ -42,7 +42,5 @@ general_settings:
litellm_settings:
fallbacks: [{"gemini-1.5-pro-001": ["gpt-4o"]}]
-callbacks: ["gcs_bucket"]
-success_callback: ["langfuse"]
+success_callback: ["langfuse", "prometheus"]
langfuse_default_tags: ["cache_hit", "cache_key", "user_api_key_alias", "user_api_key_team_alias"]
-cache: True
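The new headers and Prometheus metrics only populate for keys that carry per-model limits; a hedged sketch of creating such a key via the proxy's /key/generate endpoint follows (the metadata shape mirrors what get_key_model_rpm_limit / get_key_model_tpm_limit read; URL, master key, and limit values are placeholders):

# Placeholder URL, master key, and limit values
import requests

requests.post(
    "http://localhost:4000/key/generate",
    headers={"Authorization": "Bearer sk-master-key"},
    json={
        "metadata": {
            "model_rpm_limit": {"azure-model": 5},
            "model_tpm_limit": {"azure-model": 1000},
        }
    },
)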


@@ -148,6 +148,10 @@ from litellm.proxy.common_utils.admin_ui_utils import (
html_form,
show_missing_vars_in_env,
)
from litellm.proxy.common_utils.callback_utils import (
get_remaining_tokens_and_requests_from_request_data,
initialize_callbacks_on_proxy,
)
from litellm.proxy.common_utils.debug_utils import init_verbose_loggers
from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
@@ -158,7 +162,6 @@ from litellm.proxy.common_utils.http_parsing_utils import (
_read_request_body,
check_file_size_under_limit,
)
-from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3
from litellm.proxy.common_utils.openai_endpoint_utils import (
remove_sensitive_info_from_deployment,
@@ -503,6 +506,7 @@ def get_custom_headers(
model_region: Optional[str] = None,
response_cost: Optional[Union[float, str]] = None,
fastest_response_batch_completion: Optional[bool] = None,
request_data: Optional[dict] = {},
**kwargs,
) -> dict:
exclude_values = {"", None}
@@ -523,6 +527,12 @@ def get_custom_headers(
),
**{k: str(v) for k, v in kwargs.items()},
}
if request_data:
remaining_tokens_header = get_remaining_tokens_and_requests_from_request_data(
request_data
)
headers.update(remaining_tokens_header)
try:
return {
key: value for key, value in headers.items() if value not in exclude_values
@@ -3107,6 +3117,7 @@ async def chat_completion(
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=fastest_response_batch_completion,
request_data=data,
**additional_headers,
)
selected_data_generator = select_data_generator(
@@ -3141,6 +3152,7 @@ async def chat_completion(
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=fastest_response_batch_completion,
request_data=data,
**additional_headers,
)
)
@@ -3322,6 +3334,7 @@ async def completion(
api_base=api_base,
version=version,
response_cost=response_cost,
request_data=data,
)
selected_data_generator = select_data_generator(
response=response,
@@ -3343,6 +3356,7 @@ async def completion(
api_base=api_base,
version=version,
response_cost=response_cost,
request_data=data,
)
)
await check_response_size_is_safe(response=response)
@@ -3550,6 +3564,7 @@ async def embeddings(
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
call_id=litellm_call_id,
request_data=data,
)
)
await check_response_size_is_safe(response=response)
@@ -3676,6 +3691,7 @@ async def image_generation(
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
call_id=litellm_call_id,
request_data=data,
)
)
@@ -3797,6 +3813,7 @@ async def audio_speech(
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=None,
call_id=litellm_call_id,
request_data=data,
)
selected_data_generator = select_data_generator(
@@ -3934,6 +3951,7 @@ async def audio_transcriptions(
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
call_id=litellm_call_id,
request_data=data,
)
)
@@ -4037,6 +4055,7 @@ async def get_assistants(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@@ -4132,6 +4151,7 @@ async def create_assistant(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@@ -4227,6 +4247,7 @@ async def delete_assistant(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@@ -4322,6 +4343,7 @@ async def create_threads(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@@ -4416,6 +4438,7 @@ async def get_thread(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@@ -4513,6 +4536,7 @@ async def add_messages(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@@ -4606,6 +4630,7 @@ async def get_messages(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@@ -4713,6 +4738,7 @@ async def run_thread(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@@ -4835,6 +4861,7 @@ async def create_batch(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@@ -4930,6 +4957,7 @@ async def retrieve_batch(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@@ -5148,6 +5176,7 @@ async def moderations(
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
@@ -5317,6 +5346,7 @@ async def anthropic_response(
api_base=api_base,
version=version,
response_cost=response_cost,
request_data=data,
)
)


@@ -908,3 +908,263 @@ async def test_bad_router_tpm_limit():
)["current_tpm"]
== 0
)
@pytest.mark.asyncio
async def test_bad_router_tpm_limit_per_model():
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"rpm": 1440,
},
"model_info": {"id": 1},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-35-turbo",
"api_key": "os.environ/AZURE_EUROPE_API_KEY",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 2},
},
]
router = Router(
model_list=model_list,
set_verbose=False,
num_retries=3,
) # type: ignore
_api_key = "sk-12345"
_api_key = hash_token(_api_key)
model = "azure-model"
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key,
max_parallel_requests=10,
tpm_limit=10,
tpm_limit_per_model={model: 5},
rpm_limit_per_model={model: 5},
)
local_cache = DualCache()
pl = ProxyLogging(user_api_key_cache=local_cache)
pl._init_litellm_callbacks()
print(f"litellm callbacks: {litellm.callbacks}")
parallel_request_handler = pl.max_parallel_request_limiter
await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=local_cache,
data={"model": model},
call_type="",
)
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{model}::{precise_minute}::request_count"
print(
"internal usage cache: ",
parallel_request_handler.internal_usage_cache.in_memory_cache.cache_dict,
)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_requests"]
== 1
)
# bad call
try:
response = await router.acompletion(
model=model,
messages=[{"role": "user2", "content": "Write me a paragraph on the moon"}],
stream=True,
metadata={"user_api_key": _api_key},
)
except:
pass
await asyncio.sleep(1) # success is done in a separate thread
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_tpm"]
== 0
)
@pytest.mark.asyncio
async def test_pre_call_hook_rpm_limits_per_model():
"""
Test if error raised on hitting rpm limits for a given model
"""
import logging
from litellm._logging import (
verbose_logger,
verbose_proxy_logger,
verbose_router_logger,
)
verbose_logger.setLevel(logging.DEBUG)
verbose_proxy_logger.setLevel(logging.DEBUG)
verbose_router_logger.setLevel(logging.DEBUG)
_api_key = "sk-12345"
_api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key,
max_parallel_requests=100,
tpm_limit=900000,
rpm_limit=100000,
rpm_limit_per_model={"azure-model": 1},
)
local_cache = DualCache()
pl = ProxyLogging(user_api_key_cache=local_cache)
pl._init_litellm_callbacks()
print(f"litellm callbacks: {litellm.callbacks}")
parallel_request_handler = pl.max_parallel_request_limiter
await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
)
model = "azure-model"
kwargs = {
"model": model,
"litellm_params": {"metadata": {"user_api_key": _api_key}},
}
await parallel_request_handler.async_log_success_event(
kwargs=kwargs,
response_obj="",
start_time="",
end_time="",
)
## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
try:
await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=local_cache,
data={"model": model},
call_type="",
)
pytest.fail(f"Expected call to fail")
except Exception as e:
assert e.status_code == 429
print("got error=", e)
assert (
"limit reached Hit RPM limit for model: azure-model on api_key: c11e7177eb60c80cf983ddf8ca98f2dc1272d4c612204ce9bedd2460b18939cc"
in str(e)
)
@pytest.mark.asyncio
async def test_pre_call_hook_tpm_limits_per_model():
"""
Test if error raised on hitting tpm limits for a given model
"""
import logging
from litellm._logging import (
verbose_logger,
verbose_proxy_logger,
verbose_router_logger,
)
verbose_logger.setLevel(logging.DEBUG)
verbose_proxy_logger.setLevel(logging.DEBUG)
verbose_router_logger.setLevel(logging.DEBUG)
_api_key = "sk-12345"
_api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key,
max_parallel_requests=100,
tpm_limit=900000,
rpm_limit=100000,
rpm_limit_per_model={"azure-model": 100},
tpm_limit_per_model={"azure-model": 10},
)
local_cache = DualCache()
pl = ProxyLogging(user_api_key_cache=local_cache)
pl._init_litellm_callbacks()
print(f"litellm callbacks: {litellm.callbacks}")
parallel_request_handler = pl.max_parallel_request_limiter
model = "azure-model"
await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=local_cache,
data={"model": model},
call_type="",
)
kwargs = {
"model": model,
"litellm_params": {"metadata": {"user_api_key": _api_key}},
}
await parallel_request_handler.async_log_success_event(
kwargs=kwargs,
response_obj=litellm.ModelResponse(usage=litellm.Usage(total_tokens=11)),
start_time="",
end_time="",
)
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{model}::{precise_minute}::request_count"
print(
"internal usage cache: ",
parallel_request_handler.internal_usage_cache.in_memory_cache.cache_dict,
)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_tpm"]
== 11
)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
)["current_rpm"]
== 1
)
## Expected cache val: {"current_requests": 0, "current_tpm": 11, "current_rpm": "1"}
try:
await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=local_cache,
data={"model": model},
call_type="",
)
pytest.fail(f"Expected call to fail")
except Exception as e:
assert e.status_code == 429
print("got error=", e)
assert (
"request limit reached Hit TPM limit for model: azure-model on api_key"
in str(e)
)