feat(dynamic_rate_limiter.py): update cache with active project

Krrish Dholakia 2024-06-21 20:25:40 -07:00
parent 2545da777b
commit a028600932
5 changed files with 253 additions and 521 deletions

View file

@@ -1,11 +1,12 @@
-from datetime import datetime
+from datetime import datetime, timedelta
+from typing import TYPE_CHECKING, Any, Optional, Union
+
 import litellm
 from litellm.proxy._types import UserAPIKeyAuth
-from .types.services import ServiceTypes, ServiceLoggerPayload
-from .integrations.prometheus_services import PrometheusServicesLogger
 from .integrations.custom_logger import CustomLogger
-from datetime import timedelta
-from typing import Union, Optional, TYPE_CHECKING, Any
+from .integrations.prometheus_services import PrometheusServicesLogger
+from .types.services import ServiceLoggerPayload, ServiceTypes

 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -53,8 +54,8 @@ class ServiceLogging(CustomLogger):
         call_type: str,
         duration: float,
         parent_otel_span: Optional[Span] = None,
-        start_time: Optional[datetime] = None,
-        end_time: Optional[datetime] = None,
+        start_time: Optional[Union[datetime, float]] = None,
+        end_time: Optional[Union[datetime, float]] = None,
     ):
         """
         - For counting if the redis, postgres call is successful
@@ -92,8 +93,8 @@ class ServiceLogging(CustomLogger):
         error: Union[str, Exception],
         call_type: str,
         parent_otel_span: Optional[Span] = None,
-        start_time: Optional[datetime] = None,
-        end_time: Optional[datetime] = None,
+        start_time: Optional[Union[datetime, float]] = None,
+        end_time: Optional[Union[float, datetime]] = None,
     ):
         """
         - For counting if the redis, postgres call is unsuccessful
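Note on the widened timestamp types above: these hooks can now receive either datetime objects or the float epoch seconds that time.time() produces (the Redis cache below passes the latter). A minimal normalization sketch, with illustrative values and a hypothetical helper name not taken from this diff:

    from datetime import datetime
    from typing import Union

    def to_datetime(value: Union[datetime, float]) -> datetime:
        # floats are treated as POSIX timestamps, i.e. what time.time() returns
        if isinstance(value, float):
            return datetime.fromtimestamp(value)
        return value

    # duration in seconds, regardless of which form the caller used
    duration = (to_datetime(1719027600.5) - to_datetime(1719027540.0)).total_seconds()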

View file

@@ -7,14 +7,21 @@
 #
 # Thank you users! We ❤️ you! - Krrish & Ishaan

-import litellm
-import time, logging, asyncio
-import json, traceback, ast, hashlib
-from typing import Optional, Literal, List, Union, Any, BinaryIO
+import ast
+import asyncio
+import hashlib
+import json
+import logging
+import time
+import traceback
+from datetime import timedelta
+from typing import Any, BinaryIO, List, Literal, Optional, Union
+
 from openai._models import BaseModel as OpenAIObject
+
+import litellm
 from litellm._logging import verbose_logger
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes
-import traceback


 def print_verbose(print_statement):
@@ -78,6 +85,17 @@ class InMemoryCache(BaseCache):
         else:
             self.set_cache(key=cache_key, value=cache_value)

+    async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]):
+        """
+        Add value to set
+        """
+        # get the existing set for this key (or start a new one)
+        init_value = self.get_cache(key=key) or set()
+        for val in value:
+            init_value.add(val)
+        self.set_cache(key, init_value, ttl=ttl)
+        return value
+
     def get_cache(self, key, **kwargs):
         if key in self.cache_dict:
             if key in self.ttl_dict:
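A minimal usage sketch of the new in-memory set-add helper (assuming InMemoryCache is importable from litellm.caching with its default constructor; the key and team names are illustrative):

    import asyncio

    from litellm.caching import InMemoryCache

    async def demo():
        cache = InMemoryCache()
        # adding the same member twice stores it once, since the value is a set
        await cache.async_set_cache_sadd(key="20-25:my-model", value=["team-a"], ttl=60)
        await cache.async_set_cache_sadd(key="20-25:my-model", value=["team-a", "team-b"], ttl=60)
        print(cache.get_cache(key="20-25:my-model"))  # {'team-a', 'team-b'}

    asyncio.run(demo())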
@@ -147,10 +165,12 @@ class RedisCache(BaseCache):
         namespace: Optional[str] = None,
         **kwargs,
     ):
-        from ._redis import get_redis_client, get_redis_connection_pool
-        from litellm._service_logger import ServiceLogging
         import redis

+        from litellm._service_logger import ServiceLogging
+
+        from ._redis import get_redis_client, get_redis_connection_pool
+
         redis_kwargs = {}
         if host is not None:
             redis_kwargs["host"] = host
@@ -329,6 +349,7 @@ class RedisCache(BaseCache):
                     start_time=start_time,
                     end_time=end_time,
                     parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                    call_type="async_set_cache",
                 )
             )
             # NON blocking - notify users Redis is throwing an exception
@@ -448,6 +469,80 @@ class RedisCache(BaseCache):
                 cache_value,
             )

+    async def async_set_cache_sadd(
+        self, key, value: List, ttl: Optional[float], **kwargs
+    ):
+        start_time = time.time()
+        try:
+            _redis_client = self.init_async_client()
+        except Exception as e:
+            end_time = time.time()
+            _duration = end_time - start_time
+            asyncio.create_task(
+                self.service_logger_obj.async_service_failure_hook(
+                    service=ServiceTypes.REDIS,
+                    duration=_duration,
+                    error=e,
+                    start_time=start_time,
+                    end_time=end_time,
+                    parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                    call_type="async_set_cache_sadd",
+                )
+            )
+            # NON blocking - notify users Redis is throwing an exception
+            verbose_logger.error(
+                "LiteLLM Redis Caching: async set_cache_sadd() - Got exception from REDIS %s, Writing value=%s",
+                str(e),
+                value,
+            )
+            raise e
+
+        key = self.check_and_fix_namespace(key=key)
+        async with _redis_client as redis_client:
+            print_verbose(
+                f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
+            )
+            try:
+                await redis_client.sadd(key, *value)
+                if ttl is not None:
+                    _td = timedelta(seconds=ttl)
+                    await redis_client.expire(key, _td)
+                print_verbose(
+                    f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}"
+                )
+                end_time = time.time()
+                _duration = end_time - start_time
+                asyncio.create_task(
+                    self.service_logger_obj.async_service_success_hook(
+                        service=ServiceTypes.REDIS,
+                        duration=_duration,
+                        call_type="async_set_cache_sadd",
+                        start_time=start_time,
+                        end_time=end_time,
+                        parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                    )
+                )
+            except Exception as e:
+                end_time = time.time()
+                _duration = end_time - start_time
+                asyncio.create_task(
+                    self.service_logger_obj.async_service_failure_hook(
+                        service=ServiceTypes.REDIS,
+                        duration=_duration,
+                        error=e,
+                        call_type="async_set_cache_sadd",
+                        start_time=start_time,
+                        end_time=end_time,
+                        parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                    )
+                )
+                # NON blocking - notify users Redis is throwing an exception
+                verbose_logger.error(
+                    "LiteLLM Redis Caching: async set_cache_sadd() - Got exception from REDIS %s, Writing value=%s",
+                    str(e),
+                    value,
+                )
+
     async def batch_cache_write(self, key, value, **kwargs):
         print_verbose(
             f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}",
@@ -886,11 +981,10 @@ class RedisSemanticCache(BaseCache):
     def get_cache(self, key, **kwargs):
         print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
-        from redisvl.query import VectorQuery
         import numpy as np
+        from redisvl.query import VectorQuery

         # query
         # get the messages
         messages = kwargs["messages"]
         prompt = "".join(message["content"] for message in messages)
@@ -943,7 +1037,8 @@ class RedisSemanticCache(BaseCache):
     async def async_set_cache(self, key, value, **kwargs):
         import numpy as np
-        from litellm.proxy.proxy_server import llm_router, llm_model_list
+
+        from litellm.proxy.proxy_server import llm_model_list, llm_router

         try:
             await self.index.acreate(overwrite=False)  # don't overwrite existing index
@@ -998,12 +1093,12 @@ class RedisSemanticCache(BaseCache):
     async def async_get_cache(self, key, **kwargs):
         print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
-        from redisvl.query import VectorQuery
         import numpy as np
-        from litellm.proxy.proxy_server import llm_router, llm_model_list
+        from redisvl.query import VectorQuery
+
+        from litellm.proxy.proxy_server import llm_model_list, llm_router

         # query
         # get the messages
         messages = kwargs["messages"]
         prompt = "".join(message["content"] for message in messages)
@@ -1161,7 +1256,8 @@ class S3Cache(BaseCache):
             self.set_cache(key=key, value=value, **kwargs)

     def get_cache(self, key, **kwargs):
-        import boto3, botocore
+        import boto3
+        import botocore

         try:
             key = self.key_prefix + key
@@ -1471,7 +1567,7 @@ class DualCache(BaseCache):
                     key, value, **kwargs
                 )

-            if self.redis_cache is not None and local_only == False:
+            if self.redis_cache is not None and local_only is False:
                 result = await self.redis_cache.async_increment(key, value, **kwargs)

             return result
@@ -1480,6 +1576,38 @@ class DualCache(BaseCache):
             verbose_logger.debug(traceback.format_exc())
             raise e

+    async def async_set_cache_sadd(
+        self, key, value: List, local_only: bool = False, **kwargs
+    ) -> None:
+        """
+        Add value to a set
+
+        Key - the key in cache
+        Value - List - the values to add to the set
+
+        Returns - None
+        """
+        try:
+            if self.in_memory_cache is not None:
+                _ = await self.in_memory_cache.async_set_cache_sadd(
+                    key, value, ttl=kwargs.get("ttl", None)
+                )
+
+            if self.redis_cache is not None and local_only is False:
+                _ = await self.redis_cache.async_set_cache_sadd(
+                    key, value, ttl=kwargs.get("ttl", None)
+                )
+
+            return None
+        except Exception as e:
+            verbose_logger.error(
+                "LiteLLM Cache: Exception async set_cache_sadd: {}\n{}".format(
+                    str(e), traceback.format_exc()
+                )
+            )
+            raise e
+
     def flush_cache(self):
         if self.in_memory_cache is not None:
             self.in_memory_cache.flush_cache()
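DualCache.async_set_cache_sadd is a write-through: it always updates the in-memory set and also writes to Redis unless local_only=True or no Redis cache is configured. A minimal sketch using only the in-memory layer (key and value are illustrative):

    import asyncio

    from litellm.caching import DualCache, InMemoryCache

    async def demo():
        cache = DualCache(in_memory_cache=InMemoryCache(), redis_cache=None)
        await cache.async_set_cache_sadd(key="20-25:my-model", value=["team-a"], ttl=60)
        print(await cache.async_get_cache(key="20-25:my-model"))  # {'team-a'}

    asyncio.run(demo())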

View file

@@ -105,8 +105,8 @@ class OpenTelemetry(CustomLogger):
         self,
         payload: ServiceLoggerPayload,
         parent_otel_span: Optional[Span] = None,
-        start_time: Optional[datetime] = None,
-        end_time: Optional[datetime] = None,
+        start_time: Optional[Union[datetime, float]] = None,
+        end_time: Optional[Union[datetime, float]] = None,
     ):
         from datetime import datetime
@@ -144,8 +144,8 @@ class OpenTelemetry(CustomLogger):
         self,
         payload: ServiceLoggerPayload,
         parent_otel_span: Optional[Span] = None,
-        start_time: Optional[datetime] = None,
-        end_time: Optional[datetime] = None,
+        start_time: Optional[Union[datetime, float]] = None,
+        end_time: Optional[Union[float, datetime]] = None,
     ):
         from datetime import datetime

View file

@@ -1,10 +1,12 @@
 # What is this?
 ## Allocates dynamic tpm/rpm quota for a project based on current traffic
+## Tracks num active projects per minute

+import asyncio
 import sys
 import traceback
 from datetime import datetime
-from typing import Optional
+from typing import List, Literal, Optional, Tuple, Union

 from fastapi import HTTPException
@@ -15,6 +17,7 @@ from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.types.router import ModelGroupInfo
+from litellm.utils import get_utc_datetime


 class DynamicRateLimiterCache:
@@ -29,13 +32,34 @@
         self.ttl = 60  # 1 min ttl

     async def async_get_cache(self, model: str) -> Optional[int]:
-        key_name = "{}".format(model)
-        response = await self.cache.async_get_cache(key=key_name)
+        dt = get_utc_datetime()
+        current_minute = dt.strftime("%H-%M")
+        key_name = "{}:{}".format(current_minute, model)
+        _response = await self.cache.async_get_cache(key=key_name)
+        response: Optional[int] = None
+        if _response is not None:
+            response = len(_response)
         return response

-    async def async_increment_cache(self, model: str, value: int):
-        key_name = "{}".format(model)
-        await self.cache.async_increment_cache(key=key_name, value=value, ttl=self.ttl)
+    async def async_set_cache_sadd(self, model: str, value: List):
+        """
+        Add value to set.
+
+        Parameters:
+        - model: str, the name of the model group
+        - value: List[str], the team ids to record as active
+
+        Returns:
+        - None
+
+        Raises:
+        - Exception, if unable to connect to cache client (if redis caching enabled)
+        """
+        dt = get_utc_datetime()
+        current_minute = dt.strftime("%H-%M")
+        key_name = "{}:{}".format(current_minute, model)
+        await self.cache.async_set_cache_sadd(key=key_name, value=value, ttl=self.ttl)


 class _PROXY_DynamicRateLimitHandler(CustomLogger):
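With the minute-prefixed key, each UTC minute gets its own set of active teams, and the 60s TTL lets stale minutes expire on their own; async_get_cache then reports how many members are in the current minute's set. A sketch of the key shape and the derived count (model name and team ids are illustrative):

    from datetime import datetime, timezone

    dt = datetime.now(timezone.utc)        # roughly what get_utc_datetime() returns
    current_minute = dt.strftime("%H-%M")  # e.g. "20-25"
    key_name = "{}:{}".format(current_minute, "my-model")

    # the DualCache stores a set of team ids under this key;
    # async_get_cache returns its length, i.e. the active-project count
    active_teams = {"team-a", "team-b", "team-c"}
    print(key_name, len(active_teams))  # e.g. 20-25:my-model 3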
@@ -47,13 +71,17 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
     def update_variables(self, llm_router: Router):
         self.llm_router = llm_router

-    async def check_available_tpm(self, model: str) -> Optional[int]:
+    async def check_available_tpm(
+        self, model: str
+    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
         """
-        For a given model, get it's available tpm
+        For a given model, get its available tpm

         Returns
-        - int: if number found
-        - None: if not found
+        - Tuple[available_tpm, model_tpm, active_projects]
+          - available_tpm: int or None
+          - model_tpm: int or None. If available_tpm is an int, this will be too.
+          - active_projects: int or None
         """
         active_projects = await self.internal_usage_cache.async_get_cache(model=model)
         model_group_info: Optional[ModelGroupInfo] = (
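A worked example of the new return tuple (numbers are illustrative): with a model group configured at tpm=100 and 4 teams active in the current minute, the method returns (25, 100, 4); with no active teams recorded yet it returns (100, 100, None).

    model_tpm = 100
    active_projects = 4
    available_tpm = int(model_tpm / active_projects)
    print((available_tpm, model_tpm, active_projects))  # (25, 100, 4)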
@@ -61,490 +89,60 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
         )
         available_tpm: Optional[int] = None
+        model_tpm: Optional[int] = None
         if model_group_info is not None and model_group_info.tpm is not None:
+            model_tpm = model_group_info.tpm
             if active_projects is not None:
                 available_tpm = int(model_group_info.tpm / active_projects)
             else:
                 available_tpm = model_group_info.tpm
-        return available_tpm
+        return available_tpm, model_tpm, active_projects

+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+        ],
+    ) -> Optional[
+        Union[Exception, str, dict]
+    ]:  # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
+        """
+        - For a model group
+        - Check if tpm available
+        - Raise RateLimitError if no tpm available
+        """
+        if "model" in data:
+            available_tpm, model_tpm, active_projects = await self.check_available_tpm(
+                model=data["model"]
+            )
+            if available_tpm is not None and available_tpm == 0:
+                raise HTTPException(
+                    status_code=429,
+                    detail={
+                        "error": "Team={} over available TPM={}. Model TPM={}, Active teams={}".format(
+                            user_api_key_dict.team_id,
+                            available_tpm,
+                            model_tpm,
+                            active_projects,
+                        )
+                    },
+                )
+            elif available_tpm is not None:
+                ## UPDATE CACHE WITH ACTIVE PROJECT
+                asyncio.create_task(
+                    self.internal_usage_cache.async_set_cache_sadd(
+                        model=data["model"],  # type: ignore
+                        value=[user_api_key_dict.team_id or "default_team"],
+                    )
+                )
+        return None

# async def check_key_in_limits(
# self,
# user_api_key_dict: UserAPIKeyAuth,
# cache: DualCache,
# data: dict,
# call_type: str,
# max_parallel_requests: int,
# tpm_limit: int,
# rpm_limit: int,
# request_count_api_key: str,
# ):
# current = await self.internal_usage_cache.async_get_cache(
# key=request_count_api_key
# ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
# if current is None:
# if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
# # base case
# raise HTTPException(
# status_code=429, detail="Max parallel request limit reached."
# )
# new_val = {
# "current_requests": 1,
# "current_tpm": 0,
# "current_rpm": 0,
# }
# await self.internal_usage_cache.async_set_cache(
# request_count_api_key, new_val
# )
# elif (
# int(current["current_requests"]) < max_parallel_requests
# and current["current_tpm"] < tpm_limit
# and current["current_rpm"] < rpm_limit
# ):
# # Increase count for this token
# new_val = {
# "current_requests": current["current_requests"] + 1,
# "current_tpm": current["current_tpm"],
# "current_rpm": current["current_rpm"],
# }
# await self.internal_usage_cache.async_set_cache(
# request_count_api_key, new_val
# )
# else:
# raise HTTPException(
# status_code=429,
# detail=f"LiteLLM Rate Limit Handler: Crossed TPM, RPM Limit. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}",
# )
# async def async_pre_call_hook(
# self,
# user_api_key_dict: UserAPIKeyAuth,
# cache: DualCache,
# data: dict,
# call_type: str,
# ):
# self.print_verbose("Inside Dynamic Rate Limit Handler Pre-Call Hook")
# api_key = user_api_key_dict.api_key
# max_parallel_requests = user_api_key_dict.max_parallel_requests
# if max_parallel_requests is None:
# max_parallel_requests = sys.maxsize
# global_max_parallel_requests = data.get("metadata", {}).get(
# "global_max_parallel_requests", None
# )
# tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
# if tpm_limit is None:
# tpm_limit = sys.maxsize
# rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
# if rpm_limit is None:
# rpm_limit = sys.maxsize
# # ------------
# # Setup values
# # ------------
# if global_max_parallel_requests is not None:
# # get value from cache
# _key = "global_max_parallel_requests"
# current_global_requests = await self.internal_usage_cache.async_get_cache(
# key=_key, local_only=True
# )
# # check if below limit
# if current_global_requests is None:
# current_global_requests = 1
# # if above -> raise error
# if current_global_requests >= global_max_parallel_requests:
# raise HTTPException(
# status_code=429, detail="Max parallel request limit reached."
# )
# # if below -> increment
# else:
# await self.internal_usage_cache.async_increment_cache(
# key=_key, value=1, local_only=True
# )
# current_date = datetime.now().strftime("%Y-%m-%d")
# current_hour = datetime.now().strftime("%H")
# current_minute = datetime.now().strftime("%M")
# precise_minute = f"{current_date}-{current_hour}-{current_minute}"
# if api_key is not None:
# request_count_api_key = f"{api_key}::{precise_minute}::request_count"
# # CHECK IF REQUEST ALLOWED for key
# current = await self.internal_usage_cache.async_get_cache(
# key=request_count_api_key
# ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
# self.print_verbose(f"current: {current}")
# if (
# max_parallel_requests == sys.maxsize
# and tpm_limit == sys.maxsize
# and rpm_limit == sys.maxsize
# ):
# pass
# elif max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
# raise HTTPException(
# status_code=429, detail="Max parallel request limit reached."
# )
# elif current is None:
# new_val = {
# "current_requests": 1,
# "current_tpm": 0,
# "current_rpm": 0,
# }
# await self.internal_usage_cache.async_set_cache(
# request_count_api_key, new_val
# )
# elif (
# int(current["current_requests"]) < max_parallel_requests
# and current["current_tpm"] < tpm_limit
# and current["current_rpm"] < rpm_limit
# ):
# # Increase count for this token
# new_val = {
# "current_requests": current["current_requests"] + 1,
# "current_tpm": current["current_tpm"],
# "current_rpm": current["current_rpm"],
# }
# await self.internal_usage_cache.async_set_cache(
# request_count_api_key, new_val
# )
# else:
# raise HTTPException(
# status_code=429, detail="Max parallel request limit reached."
# )
# # check if REQUEST ALLOWED for user_id
# user_id = user_api_key_dict.user_id
# if user_id is not None:
# _user_id_rate_limits = await self.internal_usage_cache.async_get_cache(
# key=user_id
# )
# # get user tpm/rpm limits
# if _user_id_rate_limits is not None and isinstance(
# _user_id_rate_limits, dict
# ):
# user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None)
# user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None)
# if user_tpm_limit is None:
# user_tpm_limit = sys.maxsize
# if user_rpm_limit is None:
# user_rpm_limit = sys.maxsize
# # now do the same tpm/rpm checks
# request_count_api_key = f"{user_id}::{precise_minute}::request_count"
# # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
# await self.check_key_in_limits(
# user_api_key_dict=user_api_key_dict,
# cache=cache,
# data=data,
# call_type=call_type,
# max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
# request_count_api_key=request_count_api_key,
# tpm_limit=user_tpm_limit,
# rpm_limit=user_rpm_limit,
# )
# # TEAM RATE LIMITS
# ## get team tpm/rpm limits
# team_id = user_api_key_dict.team_id
# if team_id is not None:
# team_tpm_limit = user_api_key_dict.team_tpm_limit
# team_rpm_limit = user_api_key_dict.team_rpm_limit
# if team_tpm_limit is None:
# team_tpm_limit = sys.maxsize
# if team_rpm_limit is None:
# team_rpm_limit = sys.maxsize
# # now do the same tpm/rpm checks
# request_count_api_key = f"{team_id}::{precise_minute}::request_count"
# # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
# await self.check_key_in_limits(
# user_api_key_dict=user_api_key_dict,
# cache=cache,
# data=data,
# call_type=call_type,
# max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a team
# request_count_api_key=request_count_api_key,
# tpm_limit=team_tpm_limit,
# rpm_limit=team_rpm_limit,
# )
# # End-User Rate Limits
# # Only enforce if user passed `user` to /chat, /completions, /embeddings
# if user_api_key_dict.end_user_id:
# end_user_tpm_limit = getattr(
# user_api_key_dict, "end_user_tpm_limit", sys.maxsize
# )
# end_user_rpm_limit = getattr(
# user_api_key_dict, "end_user_rpm_limit", sys.maxsize
# )
# if end_user_tpm_limit is None:
# end_user_tpm_limit = sys.maxsize
# if end_user_rpm_limit is None:
# end_user_rpm_limit = sys.maxsize
# # now do the same tpm/rpm checks
# request_count_api_key = (
# f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
# )
# # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
# await self.check_key_in_limits(
# user_api_key_dict=user_api_key_dict,
# cache=cache,
# data=data,
# call_type=call_type,
# max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for an End-User
# request_count_api_key=request_count_api_key,
# tpm_limit=end_user_tpm_limit,
# rpm_limit=end_user_rpm_limit,
# )
# return
# async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
# try:
# self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
# global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
# "global_max_parallel_requests", None
# )
# user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
# user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
# "user_api_key_user_id", None
# )
# user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
# "user_api_key_team_id", None
# )
# user_api_key_end_user_id = kwargs.get("user")
# # ------------
# # Setup values
# # ------------
# if global_max_parallel_requests is not None:
# # get value from cache
# _key = "global_max_parallel_requests"
# # decrement
# await self.internal_usage_cache.async_increment_cache(
# key=_key, value=-1, local_only=True
# )
# current_date = datetime.now().strftime("%Y-%m-%d")
# current_hour = datetime.now().strftime("%H")
# current_minute = datetime.now().strftime("%M")
# precise_minute = f"{current_date}-{current_hour}-{current_minute}"
# total_tokens = 0
# if isinstance(response_obj, ModelResponse):
# total_tokens = response_obj.usage.total_tokens
# # ------------
# # Update usage - API Key
# # ------------
# if user_api_key is not None:
# request_count_api_key = (
# f"{user_api_key}::{precise_minute}::request_count"
# )
# current = await self.internal_usage_cache.async_get_cache(
# key=request_count_api_key
# ) or {
# "current_requests": 1,
# "current_tpm": total_tokens,
# "current_rpm": 1,
# }
# new_val = {
# "current_requests": max(current["current_requests"] - 1, 0),
# "current_tpm": current["current_tpm"] + total_tokens,
# "current_rpm": current["current_rpm"] + 1,
# }
# self.print_verbose(
# f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
# )
# await self.internal_usage_cache.async_set_cache(
# request_count_api_key, new_val, ttl=60
# ) # store in cache for 1 min.
# # ------------
# # Update usage - User
# # ------------
# if user_api_key_user_id is not None:
# total_tokens = 0
# if isinstance(response_obj, ModelResponse):
# total_tokens = response_obj.usage.total_tokens
# request_count_api_key = (
# f"{user_api_key_user_id}::{precise_minute}::request_count"
# )
# current = await self.internal_usage_cache.async_get_cache(
# key=request_count_api_key
# ) or {
# "current_requests": 1,
# "current_tpm": total_tokens,
# "current_rpm": 1,
# }
# new_val = {
# "current_requests": max(current["current_requests"] - 1, 0),
# "current_tpm": current["current_tpm"] + total_tokens,
# "current_rpm": current["current_rpm"] + 1,
# }
# self.print_verbose(
# f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
# )
# await self.internal_usage_cache.async_set_cache(
# request_count_api_key, new_val, ttl=60
# ) # store in cache for 1 min.
# # ------------
# # Update usage - Team
# # ------------
# if user_api_key_team_id is not None:
# total_tokens = 0
# if isinstance(response_obj, ModelResponse):
# total_tokens = response_obj.usage.total_tokens
# request_count_api_key = (
# f"{user_api_key_team_id}::{precise_minute}::request_count"
# )
# current = await self.internal_usage_cache.async_get_cache(
# key=request_count_api_key
# ) or {
# "current_requests": 1,
# "current_tpm": total_tokens,
# "current_rpm": 1,
# }
# new_val = {
# "current_requests": max(current["current_requests"] - 1, 0),
# "current_tpm": current["current_tpm"] + total_tokens,
# "current_rpm": current["current_rpm"] + 1,
# }
# self.print_verbose(
# f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
# )
# await self.internal_usage_cache.async_set_cache(
# request_count_api_key, new_val, ttl=60
# ) # store in cache for 1 min.
# # ------------
# # Update usage - End User
# # ------------
# if user_api_key_end_user_id is not None:
# total_tokens = 0
# if isinstance(response_obj, ModelResponse):
# total_tokens = response_obj.usage.total_tokens
# request_count_api_key = (
# f"{user_api_key_end_user_id}::{precise_minute}::request_count"
# )
# current = await self.internal_usage_cache.async_get_cache(
# key=request_count_api_key
# ) or {
# "current_requests": 1,
# "current_tpm": total_tokens,
# "current_rpm": 1,
# }
# new_val = {
# "current_requests": max(current["current_requests"] - 1, 0),
# "current_tpm": current["current_tpm"] + total_tokens,
# "current_rpm": current["current_rpm"] + 1,
# }
# self.print_verbose(
# f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
# )
# await self.internal_usage_cache.async_set_cache(
# request_count_api_key, new_val, ttl=60
# ) # store in cache for 1 min.
# except Exception as e:
# self.print_verbose(e) # noqa
# async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
# try:
# self.print_verbose(f"Inside Max Parallel Request Failure Hook")
# global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
# "global_max_parallel_requests", None
# )
# user_api_key = (
# kwargs["litellm_params"].get("metadata", {}).get("user_api_key", None)
# )
# self.print_verbose(f"user_api_key: {user_api_key}")
# if user_api_key is None:
# return
# ## decrement call count if call failed
# if "Max parallel request limit reached" in str(kwargs["exception"]):
# pass # ignore failed calls due to max limit being reached
# else:
# # ------------
# # Setup values
# # ------------
# if global_max_parallel_requests is not None:
# # get value from cache
# _key = "global_max_parallel_requests"
# current_global_requests = (
# await self.internal_usage_cache.async_get_cache(
# key=_key, local_only=True
# )
# )
# # decrement
# await self.internal_usage_cache.async_increment_cache(
# key=_key, value=-1, local_only=True
# )
# current_date = datetime.now().strftime("%Y-%m-%d")
# current_hour = datetime.now().strftime("%H")
# current_minute = datetime.now().strftime("%M")
# precise_minute = f"{current_date}-{current_hour}-{current_minute}"
# request_count_api_key = (
# f"{user_api_key}::{precise_minute}::request_count"
# )
# # ------------
# # Update usage
# # ------------
# current = await self.internal_usage_cache.async_get_cache(
# key=request_count_api_key
# ) or {
# "current_requests": 1,
# "current_tpm": 0,
# "current_rpm": 0,
# }
# new_val = {
# "current_requests": max(current["current_requests"] - 1, 0),
# "current_tpm": current["current_tpm"],
# "current_rpm": current["current_rpm"],
# }
# self.print_verbose(f"updated_value in failure call: {new_val}")
# await self.internal_usage_cache.async_set_cache(
# request_count_api_key, new_val, ttl=60
# ) # save in cache for up to 1 min.
# except Exception as e:
# verbose_proxy_logger.info(
# f"Inside Parallel Request Limiter: An exception occurred - {str(e)}."
# )
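Taken together, the replacement async_pre_call_hook does two things per request: reject with HTTP 429 when the evenly split quota is exhausted, otherwise record the calling team as active for the current minute. A restatement of that decision logic outside the class (values are illustrative, not from the diff):

    # as returned by check_available_tpm(model=...)
    available_tpm, model_tpm, active_projects = 0, 100, 5

    if available_tpm is not None and available_tpm == 0:
        # request is rejected with a 429 before it reaches the model
        print("429: over available TPM={}, model TPM={}, active teams={}".format(
            available_tpm, model_tpm, active_projects))
    elif available_tpm is not None:
        # otherwise the team id is SADD-ed into the current-minute set
        print("record team as active, then let the request through")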

View file

@@ -6,6 +6,7 @@ import random
 import sys
 import time
 import traceback
+import uuid
 from datetime import datetime
 from typing import Tuple
@@ -44,8 +45,10 @@ def dynamic_rate_limit_handler() -> DynamicRateLimitHandler:
 async def test_available_tpm(num_projects, dynamic_rate_limit_handler):
     model = "my-fake-model"
     ## SET CACHE W/ ACTIVE PROJECTS
-    await dynamic_rate_limit_handler.internal_usage_cache.async_increment_cache(
-        model=model, value=num_projects
+    projects = [str(uuid.uuid4()) for _ in range(num_projects)]
+
+    await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
+        model=model, value=projects
     )

     model_tpm = 100
@@ -66,7 +69,9 @@ async def test_available_tpm(num_projects, dynamic_rate_limit_handler):
     ## CHECK AVAILABLE TPM PER PROJECT
-    availability = await dynamic_rate_limit_handler.check_available_tpm(model=model)
+    availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
+        model=model
+    )

     expected_availability = int(model_tpm / num_projects)
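The updated test seeds the per-minute set with num_projects distinct UUIDs and expects the quota to split evenly across them; a quick check of that arithmetic (the project counts here are illustrative, not necessarily the test's parametrization):

    model_tpm = 100
    for num_projects in (1, 2, 10):
        expected_availability = int(model_tpm / num_projects)
        print(num_projects, expected_availability)  # 1 -> 100, 2 -> 50, 10 -> 10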