diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py index f2edb403e9..0aa0312b80 100644 --- a/litellm/_service_logger.py +++ b/litellm/_service_logger.py @@ -1,11 +1,12 @@ -from datetime import datetime +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any, Optional, Union + import litellm from litellm.proxy._types import UserAPIKeyAuth -from .types.services import ServiceTypes, ServiceLoggerPayload -from .integrations.prometheus_services import PrometheusServicesLogger + from .integrations.custom_logger import CustomLogger -from datetime import timedelta -from typing import Union, Optional, TYPE_CHECKING, Any +from .integrations.prometheus_services import PrometheusServicesLogger +from .types.services import ServiceLoggerPayload, ServiceTypes if TYPE_CHECKING: from opentelemetry.trace import Span as _Span @@ -53,8 +54,8 @@ class ServiceLogging(CustomLogger): call_type: str, duration: float, parent_otel_span: Optional[Span] = None, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, + start_time: Optional[Union[datetime, float]] = None, + end_time: Optional[Union[datetime, float]] = None, ): """ - For counting if the redis, postgres call is successful @@ -92,8 +93,8 @@ class ServiceLogging(CustomLogger): error: Union[str, Exception], call_type: str, parent_otel_span: Optional[Span] = None, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, + start_time: Optional[Union[datetime, float]] = None, + end_time: Optional[Union[float, datetime]] = None, ): """ - For counting if the redis, postgres call is unsuccessful diff --git a/litellm/caching.py b/litellm/caching.py index 6b58cf5276..95cad01cfd 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -7,14 +7,21 @@ # # Thank you users! We ❤️ you! 
- Krrish & Ishaan -import litellm -import time, logging, asyncio -import json, traceback, ast, hashlib -from typing import Optional, Literal, List, Union, Any, BinaryIO +import ast +import asyncio +import hashlib +import json +import logging +import time +import traceback +from datetime import timedelta +from typing import Any, BinaryIO, List, Literal, Optional, Union + from openai._models import BaseModel as OpenAIObject + +import litellm from litellm._logging import verbose_logger from litellm.types.services import ServiceLoggerPayload, ServiceTypes -import traceback def print_verbose(print_statement): @@ -78,6 +85,17 @@ class InMemoryCache(BaseCache): else: self.set_cache(key=cache_key, value=cache_value) + async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]): + """ + Add value to set + """ + # get the value + init_value = self.get_cache(key=key) or set() + for val in value: + init_value.add(val) + self.set_cache(key, init_value, ttl=ttl) + return value + def get_cache(self, key, **kwargs): if key in self.cache_dict: if key in self.ttl_dict: @@ -147,10 +165,12 @@ class RedisCache(BaseCache): namespace: Optional[str] = None, **kwargs, ): - from ._redis import get_redis_client, get_redis_connection_pool - from litellm._service_logger import ServiceLogging import redis + from litellm._service_logger import ServiceLogging + + from ._redis import get_redis_client, get_redis_connection_pool + redis_kwargs = {} if host is not None: redis_kwargs["host"] = host @@ -329,6 +349,7 @@ class RedisCache(BaseCache): start_time=start_time, end_time=end_time, parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + call_type="async_set_cache", ) ) # NON blocking - notify users Redis is throwing an exception @@ -448,6 +469,80 @@ class RedisCache(BaseCache): cache_value, ) + async def async_set_cache_sadd( + self, key, value: List, ttl: Optional[float], **kwargs + ): + start_time = time.time() + try: + _redis_client = self.init_async_client() + except Exception as e: + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + call_type="async_set_cache_sadd", + ) + ) + # NON blocking - notify users Redis is throwing an exception + verbose_logger.error( + "LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s", + str(e), + value, + ) + raise e + + key = self.check_and_fix_namespace(key=key) + async with _redis_client as redis_client: + print_verbose( + f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}" + ) + try: + await redis_client.sadd(key, *value) + if ttl is not None: + _td = timedelta(seconds=ttl) + await redis_client.expire(key, _td) + print_verbose( + f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}" + ) + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="async_set_cache_sadd", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) + except Exception as e: + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + 
error=e, + call_type="async_set_cache_sadd", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) + # NON blocking - notify users Redis is throwing an exception + verbose_logger.error( + "LiteLLM Redis Caching: async set_cache_sadd() - Got exception from REDIS %s, Writing value=%s", + str(e), + value, + ) + async def batch_cache_write(self, key, value, **kwargs): print_verbose( f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}", ) @@ -886,11 +981,10 @@ class RedisSemanticCache(BaseCache): def get_cache(self, key, **kwargs): print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}") - from redisvl.query import VectorQuery import numpy as np + from redisvl.query import VectorQuery # query - # get the messages messages = kwargs["messages"] prompt = "".join(message["content"] for message in messages) @@ -943,7 +1037,8 @@ class RedisSemanticCache(BaseCache): async def async_set_cache(self, key, value, **kwargs): import numpy as np - from litellm.proxy.proxy_server import llm_router, llm_model_list + + from litellm.proxy.proxy_server import llm_model_list, llm_router try: await self.index.acreate(overwrite=False) # don't overwrite existing index @@ -998,12 +1093,12 @@ class RedisSemanticCache(BaseCache): async def async_get_cache(self, key, **kwargs): print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}") - from redisvl.query import VectorQuery import numpy as np - from litellm.proxy.proxy_server import llm_router, llm_model_list + from redisvl.query import VectorQuery + + from litellm.proxy.proxy_server import llm_model_list, llm_router # query - # get the messages messages = kwargs["messages"] prompt = "".join(message["content"] for message in messages) @@ -1161,7 +1256,8 @@ class S3Cache(BaseCache): self.set_cache(key=key, value=value, **kwargs) def get_cache(self, key, **kwargs): - import boto3, botocore + import boto3 + import botocore try: key = self.key_prefix + key @@ -1471,7 +1567,7 @@ class DualCache(BaseCache): key, value, **kwargs ) - if self.redis_cache is not None and local_only == False: + if self.redis_cache is not None and local_only is False: result = await self.redis_cache.async_increment(key, value, **kwargs) return result @@ -1480,6 +1576,38 @@ class DualCache(BaseCache): verbose_logger.debug(traceback.format_exc()) raise e + async def async_set_cache_sadd( + self, key, value: List, local_only: bool = False, **kwargs + ) -> None: + """ + Add value to a set + + Key - the key in cache + + Value - List - the values you want to add to the set + + Returns - None + """ + try: + if self.in_memory_cache is not None: + _ = await self.in_memory_cache.async_set_cache_sadd( + key, value, ttl=kwargs.get("ttl", None) + ) + + if self.redis_cache is not None and local_only is False: + _ = await self.redis_cache.async_set_cache_sadd( + key, value, ttl=kwargs.get("ttl", None) + ) + + return None + except Exception as e: + verbose_logger.error( + "LiteLLM Cache: Exception async set_cache_sadd: {}\n{}".format( + str(e), traceback.format_exc() + ) + ) + raise e + def flush_cache(self): if self.in_memory_cache is not None: self.in_memory_cache.flush_cache() diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index e4daff1218..fa7be1d574 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -105,8 +105,8 @@ class OpenTelemetry(CustomLogger): self, payload: ServiceLoggerPayload,
parent_otel_span: Optional[Span] = None, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, + start_time: Optional[Union[datetime, float]] = None, + end_time: Optional[Union[datetime, float]] = None, ): from datetime import datetime @@ -144,8 +144,8 @@ class OpenTelemetry(CustomLogger): self, payload: ServiceLoggerPayload, parent_otel_span: Optional[Span] = None, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, + start_time: Optional[Union[datetime, float]] = None, + end_time: Optional[Union[float, datetime]] = None, ): from datetime import datetime diff --git a/litellm/proxy/hooks/dynamic_rate_limiter.py b/litellm/proxy/hooks/dynamic_rate_limiter.py index 87e9ed8d4a..dc29597f6e 100644 --- a/litellm/proxy/hooks/dynamic_rate_limiter.py +++ b/litellm/proxy/hooks/dynamic_rate_limiter.py @@ -1,10 +1,12 @@ # What is this? ## Allocates dynamic tpm/rpm quota for a project based on current traffic +## Tracks num active projects per minute +import asyncio import sys import traceback from datetime import datetime -from typing import Optional +from typing import List, Literal, Optional, Tuple, Union from fastapi import HTTPException @@ -15,6 +17,7 @@ from litellm.caching import DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import UserAPIKeyAuth from litellm.types.router import ModelGroupInfo +from litellm.utils import get_utc_datetime class DynamicRateLimiterCache: @@ -29,13 +32,34 @@ class DynamicRateLimiterCache: self.ttl = 60 # 1 min ttl async def async_get_cache(self, model: str) -> Optional[int]: - key_name = "{}".format(model) - response = await self.cache.async_get_cache(key=key_name) + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + key_name = "{}:{}".format(current_minute, model) + _response = await self.cache.async_get_cache(key=key_name) + response: Optional[int] = None + if _response is not None: + response = len(_response) return response - async def async_increment_cache(self, model: str, value: int): - key_name = "{}".format(model) - await self.cache.async_increment_cache(key=key_name, value=value, ttl=self.ttl) + async def async_set_cache_sadd(self, model: str, value: List): + """ + Add value to set. + + Parameters: + - model: str, the name of the model group + - value: List, the team ids to add to the set + + Returns: + - None + + Raises: + - Exception, if unable to connect to cache client (if redis caching enabled) + """ + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + + key_name = "{}:{}".format(current_minute, model) + await self.cache.async_set_cache_sadd(key=key_name, value=value, ttl=self.ttl) class _PROXY_DynamicRateLimitHandler(CustomLogger): @@ -47,13 +71,17 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger): def update_variables(self, llm_router: Router): self.llm_router = llm_router - async def check_available_tpm(self, model: str) -> Optional[int]: + async def check_available_tpm( + self, model: str + ) -> Tuple[Optional[int], Optional[int], Optional[int]]: """ - For a given model, get it's available tpm + For a given model, get its available tpm Returns - - int: if number found - - None: if not found + - Tuple[available_tpm, model_tpm, active_projects] + - available_tpm: int or null + - model_tpm: int or null. If available tpm is int, then this will be too.
+ - active_projects: int or null """ active_projects = await self.internal_usage_cache.async_get_cache(model=model) model_group_info: Optional[ModelGroupInfo] = ( @@ -61,490 +89,60 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger): ) available_tpm: Optional[int] = None + model_tpm: Optional[int] = None + if model_group_info is not None and model_group_info.tpm is not None: + model_tpm = model_group_info.tpm if active_projects is not None: available_tpm = int(model_group_info.tpm / active_projects) else: available_tpm = model_group_info.tpm - return available_tpm + return available_tpm, model_tpm, active_projects - # async def check_key_in_limits( - # self, - # user_api_key_dict: UserAPIKeyAuth, - # cache: DualCache, - # data: dict, - # call_type: str, - # max_parallel_requests: int, - # tpm_limit: int, - # rpm_limit: int, - # request_count_api_key: str, - # ): - # current = await self.internal_usage_cache.async_get_cache( - # key=request_count_api_key - # ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10} - # if current is None: - # if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0: - # # base case - # raise HTTPException( - # status_code=429, detail="Max parallel request limit reached." - # ) - # new_val = { - # "current_requests": 1, - # "current_tpm": 0, - # "current_rpm": 0, - # } - # await self.internal_usage_cache.async_set_cache( - # request_count_api_key, new_val - # ) - # elif ( - # int(current["current_requests"]) < max_parallel_requests - # and current["current_tpm"] < tpm_limit - # and current["current_rpm"] < rpm_limit - # ): - # # Increase count for this token - # new_val = { - # "current_requests": current["current_requests"] + 1, - # "current_tpm": current["current_tpm"], - # "current_rpm": current["current_rpm"], - # } - # await self.internal_usage_cache.async_set_cache( - # request_count_api_key, new_val - # ) - # else: - # raise HTTPException( - # status_code=429, - # detail=f"LiteLLM Rate Limit Handler: Crossed TPM, RPM Limit. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}", - # ) - - # async def async_pre_call_hook( - # self, - # user_api_key_dict: UserAPIKeyAuth, - # cache: DualCache, - # data: dict, - # call_type: str, - # ): - # self.print_verbose("Inside Dynamic Rate Limit Handler Pre-Call Hook") - # api_key = user_api_key_dict.api_key - # max_parallel_requests = user_api_key_dict.max_parallel_requests - # if max_parallel_requests is None: - # max_parallel_requests = sys.maxsize - # global_max_parallel_requests = data.get("metadata", {}).get( - # "global_max_parallel_requests", None - # ) - # tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize) - # if tpm_limit is None: - # tpm_limit = sys.maxsize - # rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize) - # if rpm_limit is None: - # rpm_limit = sys.maxsize - - # # ------------ - # # Setup values - # # ------------ - - # if global_max_parallel_requests is not None: - # # get value from cache - # _key = "global_max_parallel_requests" - # current_global_requests = await self.internal_usage_cache.async_get_cache( - # key=_key, local_only=True - # ) - # # check if below limit - # if current_global_requests is None: - # current_global_requests = 1 - # # if above -> raise error - # if current_global_requests >= global_max_parallel_requests: - # raise HTTPException( - # status_code=429, detail="Max parallel request limit reached." 
- # ) - # # if below -> increment - # else: - # await self.internal_usage_cache.async_increment_cache( - # key=_key, value=1, local_only=True - # ) - - # current_date = datetime.now().strftime("%Y-%m-%d") - # current_hour = datetime.now().strftime("%H") - # current_minute = datetime.now().strftime("%M") - # precise_minute = f"{current_date}-{current_hour}-{current_minute}" - - # if api_key is not None: - # request_count_api_key = f"{api_key}::{precise_minute}::request_count" - - # # CHECK IF REQUEST ALLOWED for key - - # current = await self.internal_usage_cache.async_get_cache( - # key=request_count_api_key - # ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10} - # self.print_verbose(f"current: {current}") - # if ( - # max_parallel_requests == sys.maxsize - # and tpm_limit == sys.maxsize - # and rpm_limit == sys.maxsize - # ): - # pass - # elif max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0: - # raise HTTPException( - # status_code=429, detail="Max parallel request limit reached." - # ) - # elif current is None: - # new_val = { - # "current_requests": 1, - # "current_tpm": 0, - # "current_rpm": 0, - # } - # await self.internal_usage_cache.async_set_cache( - # request_count_api_key, new_val - # ) - # elif ( - # int(current["current_requests"]) < max_parallel_requests - # and current["current_tpm"] < tpm_limit - # and current["current_rpm"] < rpm_limit - # ): - # # Increase count for this token - # new_val = { - # "current_requests": current["current_requests"] + 1, - # "current_tpm": current["current_tpm"], - # "current_rpm": current["current_rpm"], - # } - # await self.internal_usage_cache.async_set_cache( - # request_count_api_key, new_val - # ) - # else: - # raise HTTPException( - # status_code=429, detail="Max parallel request limit reached." 
- # ) - - # # check if REQUEST ALLOWED for user_id - # user_id = user_api_key_dict.user_id - # if user_id is not None: - # _user_id_rate_limits = await self.internal_usage_cache.async_get_cache( - # key=user_id - # ) - # # get user tpm/rpm limits - # if _user_id_rate_limits is not None and isinstance( - # _user_id_rate_limits, dict - # ): - # user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None) - # user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None) - # if user_tpm_limit is None: - # user_tpm_limit = sys.maxsize - # if user_rpm_limit is None: - # user_rpm_limit = sys.maxsize - - # # now do the same tpm/rpm checks - # request_count_api_key = f"{user_id}::{precise_minute}::request_count" - - # # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") - # await self.check_key_in_limits( - # user_api_key_dict=user_api_key_dict, - # cache=cache, - # data=data, - # call_type=call_type, - # max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user - # request_count_api_key=request_count_api_key, - # tpm_limit=user_tpm_limit, - # rpm_limit=user_rpm_limit, - # ) - - # # TEAM RATE LIMITS - # ## get team tpm/rpm limits - # team_id = user_api_key_dict.team_id - # if team_id is not None: - # team_tpm_limit = user_api_key_dict.team_tpm_limit - # team_rpm_limit = user_api_key_dict.team_rpm_limit - - # if team_tpm_limit is None: - # team_tpm_limit = sys.maxsize - # if team_rpm_limit is None: - # team_rpm_limit = sys.maxsize - - # # now do the same tpm/rpm checks - # request_count_api_key = f"{team_id}::{precise_minute}::request_count" - - # # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") - # await self.check_key_in_limits( - # user_api_key_dict=user_api_key_dict, - # cache=cache, - # data=data, - # call_type=call_type, - # max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a team - # request_count_api_key=request_count_api_key, - # tpm_limit=team_tpm_limit, - # rpm_limit=team_rpm_limit, - # ) - - # # End-User Rate Limits - # # Only enforce if user passed `user` to /chat, /completions, /embeddings - # if user_api_key_dict.end_user_id: - # end_user_tpm_limit = getattr( - # user_api_key_dict, "end_user_tpm_limit", sys.maxsize - # ) - # end_user_rpm_limit = getattr( - # user_api_key_dict, "end_user_rpm_limit", sys.maxsize - # ) - - # if end_user_tpm_limit is None: - # end_user_tpm_limit = sys.maxsize - # if end_user_rpm_limit is None: - # end_user_rpm_limit = sys.maxsize - - # # now do the same tpm/rpm checks - # request_count_api_key = ( - # f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count" - # ) - - # # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") - # await self.check_key_in_limits( - # user_api_key_dict=user_api_key_dict, - # cache=cache, - # data=data, - # call_type=call_type, - # max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for an End-User - # request_count_api_key=request_count_api_key, - # tpm_limit=end_user_tpm_limit, - # rpm_limit=end_user_rpm_limit, - # ) - - # return - - # async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): - # try: - # self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING") - # global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get( - # "global_max_parallel_requests", None - # ) - # user_api_key = 
kwargs["litellm_params"]["metadata"]["user_api_key"] - # user_api_key_user_id = kwargs["litellm_params"]["metadata"].get( - # "user_api_key_user_id", None - # ) - # user_api_key_team_id = kwargs["litellm_params"]["metadata"].get( - # "user_api_key_team_id", None - # ) - # user_api_key_end_user_id = kwargs.get("user") - - # # ------------ - # # Setup values - # # ------------ - - # if global_max_parallel_requests is not None: - # # get value from cache - # _key = "global_max_parallel_requests" - # # decrement - # await self.internal_usage_cache.async_increment_cache( - # key=_key, value=-1, local_only=True - # ) - - # current_date = datetime.now().strftime("%Y-%m-%d") - # current_hour = datetime.now().strftime("%H") - # current_minute = datetime.now().strftime("%M") - # precise_minute = f"{current_date}-{current_hour}-{current_minute}" - - # total_tokens = 0 - - # if isinstance(response_obj, ModelResponse): - # total_tokens = response_obj.usage.total_tokens - - # # ------------ - # # Update usage - API Key - # # ------------ - - # if user_api_key is not None: - # request_count_api_key = ( - # f"{user_api_key}::{precise_minute}::request_count" - # ) - - # current = await self.internal_usage_cache.async_get_cache( - # key=request_count_api_key - # ) or { - # "current_requests": 1, - # "current_tpm": total_tokens, - # "current_rpm": 1, - # } - - # new_val = { - # "current_requests": max(current["current_requests"] - 1, 0), - # "current_tpm": current["current_tpm"] + total_tokens, - # "current_rpm": current["current_rpm"] + 1, - # } - - # self.print_verbose( - # f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" - # ) - # await self.internal_usage_cache.async_set_cache( - # request_count_api_key, new_val, ttl=60 - # ) # store in cache for 1 min. - - # # ------------ - # # Update usage - User - # # ------------ - # if user_api_key_user_id is not None: - # total_tokens = 0 - - # if isinstance(response_obj, ModelResponse): - # total_tokens = response_obj.usage.total_tokens - - # request_count_api_key = ( - # f"{user_api_key_user_id}::{precise_minute}::request_count" - # ) - - # current = await self.internal_usage_cache.async_get_cache( - # key=request_count_api_key - # ) or { - # "current_requests": 1, - # "current_tpm": total_tokens, - # "current_rpm": 1, - # } - - # new_val = { - # "current_requests": max(current["current_requests"] - 1, 0), - # "current_tpm": current["current_tpm"] + total_tokens, - # "current_rpm": current["current_rpm"] + 1, - # } - - # self.print_verbose( - # f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" - # ) - # await self.internal_usage_cache.async_set_cache( - # request_count_api_key, new_val, ttl=60 - # ) # store in cache for 1 min. 
- - # # ------------ - # # Update usage - Team - # # ------------ - # if user_api_key_team_id is not None: - # total_tokens = 0 - - # if isinstance(response_obj, ModelResponse): - # total_tokens = response_obj.usage.total_tokens - - # request_count_api_key = ( - # f"{user_api_key_team_id}::{precise_minute}::request_count" - # ) - - # current = await self.internal_usage_cache.async_get_cache( - # key=request_count_api_key - # ) or { - # "current_requests": 1, - # "current_tpm": total_tokens, - # "current_rpm": 1, - # } - - # new_val = { - # "current_requests": max(current["current_requests"] - 1, 0), - # "current_tpm": current["current_tpm"] + total_tokens, - # "current_rpm": current["current_rpm"] + 1, - # } - - # self.print_verbose( - # f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" - # ) - # await self.internal_usage_cache.async_set_cache( - # request_count_api_key, new_val, ttl=60 - # ) # store in cache for 1 min. - - # # ------------ - # # Update usage - End User - # # ------------ - # if user_api_key_end_user_id is not None: - # total_tokens = 0 - - # if isinstance(response_obj, ModelResponse): - # total_tokens = response_obj.usage.total_tokens - - # request_count_api_key = ( - # f"{user_api_key_end_user_id}::{precise_minute}::request_count" - # ) - - # current = await self.internal_usage_cache.async_get_cache( - # key=request_count_api_key - # ) or { - # "current_requests": 1, - # "current_tpm": total_tokens, - # "current_rpm": 1, - # } - - # new_val = { - # "current_requests": max(current["current_requests"] - 1, 0), - # "current_tpm": current["current_tpm"] + total_tokens, - # "current_rpm": current["current_rpm"] + 1, - # } - - # self.print_verbose( - # f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" - # ) - # await self.internal_usage_cache.async_set_cache( - # request_count_api_key, new_val, ttl=60 - # ) # store in cache for 1 min. 
- - # except Exception as e: - # self.print_verbose(e) # noqa - - # async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): - # try: - # self.print_verbose(f"Inside Max Parallel Request Failure Hook") - # global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get( - # "global_max_parallel_requests", None - # ) - # user_api_key = ( - # kwargs["litellm_params"].get("metadata", {}).get("user_api_key", None) - # ) - # self.print_verbose(f"user_api_key: {user_api_key}") - # if user_api_key is None: - # return - - # ## decrement call count if call failed - # if "Max parallel request limit reached" in str(kwargs["exception"]): - # pass # ignore failed calls due to max limit being reached - # else: - # # ------------ - # # Setup values - # # ------------ - - # if global_max_parallel_requests is not None: - # # get value from cache - # _key = "global_max_parallel_requests" - # current_global_requests = ( - # await self.internal_usage_cache.async_get_cache( - # key=_key, local_only=True - # ) - # ) - # # decrement - # await self.internal_usage_cache.async_increment_cache( - # key=_key, value=-1, local_only=True - # ) - - # current_date = datetime.now().strftime("%Y-%m-%d") - # current_hour = datetime.now().strftime("%H") - # current_minute = datetime.now().strftime("%M") - # precise_minute = f"{current_date}-{current_hour}-{current_minute}" - - # request_count_api_key = ( - # f"{user_api_key}::{precise_minute}::request_count" - # ) - - # # ------------ - # # Update usage - # # ------------ - # current = await self.internal_usage_cache.async_get_cache( - # key=request_count_api_key - # ) or { - # "current_requests": 1, - # "current_tpm": 0, - # "current_rpm": 0, - # } - - # new_val = { - # "current_requests": max(current["current_requests"] - 1, 0), - # "current_tpm": current["current_tpm"], - # "current_rpm": current["current_rpm"], - # } - - # self.print_verbose(f"updated_value in failure call: {new_val}") - # await self.internal_usage_cache.async_set_cache( - # request_count_api_key, new_val, ttl=60 - # ) # save in cache for up to 1 min. - # except Exception as e: - # verbose_proxy_logger.info( - # f"Inside Parallel Request Limiter: An exception occurred - {str(e)}." - # ) + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: Literal[ + "completion", + "text_completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + ], + ) -> Optional[ + Union[Exception, str, dict] + ]: # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm + """ + - For a model group + - Check if tpm available + - Raise RateLimitError if no tpm available + """ + if "model" in data: + available_tpm, model_tpm, active_projects = await self.check_available_tpm( + model=data["model"] + ) + if available_tpm is not None and available_tpm == 0: + raise HTTPException( + status_code=429, + detail={ + "error": "Team={} over available TPM={}. 
Model TPM={}, Active teams={}".format( + user_api_key_dict.team_id, + available_tpm, + model_tpm, + active_projects, + ) + }, + ) + elif available_tpm is not None: + ## UPDATE CACHE WITH ACTIVE PROJECT + asyncio.create_task( + self.internal_usage_cache.async_set_cache_sadd( + model=data["model"], # type: ignore + value=[user_api_key_dict.team_id or "default_team"], + ) + ) + return None diff --git a/litellm/tests/test_dynamic_rate_limit_handler.py b/litellm/tests/test_dynamic_rate_limit_handler.py index 1efe6ef260..71e3ac5359 100644 --- a/litellm/tests/test_dynamic_rate_limit_handler.py +++ b/litellm/tests/test_dynamic_rate_limit_handler.py @@ -6,6 +6,7 @@ import random import sys import time import traceback +import uuid from datetime import datetime from typing import Tuple @@ -44,8 +45,10 @@ def dynamic_rate_limit_handler() -> DynamicRateLimitHandler: async def test_available_tpm(num_projects, dynamic_rate_limit_handler): model = "my-fake-model" ## SET CACHE W/ ACTIVE PROJECTS - await dynamic_rate_limit_handler.internal_usage_cache.async_increment_cache( - model=model, value=num_projects + projects = [str(uuid.uuid4()) for _ in range(num_projects)] + + await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd( + model=model, value=projects ) model_tpm = 100 @@ -66,7 +69,9 @@ async def test_available_tpm(num_projects, dynamic_rate_limit_handler): ## CHECK AVAILABLE TPM PER PROJECT - availability = await dynamic_rate_limit_handler.check_available_tpm(model=model) + availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm( + model=model + ) expected_availability = int(model_tpm / num_projects)
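Reviewer note on the new `RedisCache.async_set_cache_sadd` above: it adds members with `SADD` and then applies the TTL to the set key with `EXPIRE`, so the per-minute membership data ages out on its own. Below is a minimal standalone sketch of that pattern using `redis.asyncio` directly, assuming a local Redis server; the connection details and the `track_active_project` helper are illustrative only, not litellm APIs.

```python
import asyncio
from datetime import timedelta
from typing import List

from redis import asyncio as aioredis  # requires the redis-py package


async def track_active_project(key: str, members: List[str], ttl: float) -> int:
    """Add members to a Redis set and let the set key expire after `ttl` seconds."""
    client = aioredis.Redis(host="localhost", port=6379)  # illustrative connection
    try:
        await client.sadd(key, *members)                   # same call the diff makes
        await client.expire(key, timedelta(seconds=ttl))   # TTL applied to the set key
        return await client.scard(key)                     # current number of members
    finally:
        await client.aclose()  # redis-py >= 5; use close() on older versions


if __name__ == "__main__":
    # e.g. key "17-42:my-fake-model" holding the teams active in that minute
    count = asyncio.run(track_active_project("17-42:my-fake-model", ["team-a"], ttl=60))
    print(count)
```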
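And a rough sketch of the allocation scheme the dynamic rate limiter implements: active projects for a model group are recorded as set members under a key built from the current UTC minute (`"%H-%M:{model}"`), and the model group's TPM is split evenly across however many projects were seen in that minute. The in-memory `fake_cache` dict and the helper names below are stand-ins for `DualCache` and the hook methods, not real litellm interfaces.

```python
from datetime import datetime, timezone
from typing import Dict, Optional, Set, Tuple

# Stand-in for DualCache: maps "<HH-MM>:<model>" -> set of team ids active that minute.
fake_cache: Dict[str, Set[str]] = {}


def _minute_key(model: str) -> str:
    # Mirrors the diff: key is the current UTC minute plus the model group name.
    current_minute = datetime.now(timezone.utc).strftime("%H-%M")
    return "{}:{}".format(current_minute, model)


def record_active_project(model: str, team_id: str) -> None:
    # Rough equivalent of async_set_cache_sadd: add the team to this minute's set.
    fake_cache.setdefault(_minute_key(model), set()).add(team_id)


def check_available_tpm(
    model: str, model_tpm: Optional[int]
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
    # Rough equivalent of check_available_tpm: split the model group's TPM evenly
    # across the projects seen this minute.
    active = fake_cache.get(_minute_key(model))
    active_projects = len(active) if active else None
    if model_tpm is None:
        return None, None, active_projects
    if active_projects is None:
        return model_tpm, model_tpm, active_projects
    return int(model_tpm / active_projects), model_tpm, active_projects


if __name__ == "__main__":
    record_active_project("my-fake-model", "team-a")
    record_active_project("my-fake-model", "team-b")
    # With model_tpm=100 and 2 active teams, each team's share is int(100 / 2) = 50.
    print(check_available_tpm("my-fake-model", model_tpm=100))  # (50, 100, 2)
```

A request would then be rejected with a 429 once the per-project share reaches 0, which is what the new `async_pre_call_hook` in the diff does.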