diff --git a/litellm/__init__.py b/litellm/__init__.py
index 6a5898ddb..eb59f6d6b 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -173,6 +173,7 @@ cache: Optional[Cache] = (
 )
 default_in_memory_ttl: Optional[float] = None
 default_redis_ttl: Optional[float] = None
+default_redis_batch_cache_expiry: Optional[float] = None
 model_alias_map: Dict[str, str] = {}
 model_group_alias_map: Dict[str, str] = {}
 max_budget: float = 0.0  # set the max budget across all providers
diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py
index 4db645e66..d4aad68bb 100644
--- a/litellm/_service_logger.py
+++ b/litellm/_service_logger.py
@@ -13,9 +13,13 @@ from .types.services import ServiceLoggerPayload, ServiceTypes
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
+    from litellm.integrations.opentelemetry import OpenTelemetry
+
     Span = _Span
+    OTELClass = OpenTelemetry
 else:
     Span = Any
+    OTELClass = Any
 
 
 class ServiceLogging(CustomLogger):
@@ -111,6 +115,7 @@ class ServiceLogging(CustomLogger):
         """
         - For counting if the redis, postgres call is successful
         """
+        from litellm.integrations.opentelemetry import OpenTelemetry
 
         if self.mock_testing:
             self.mock_testing_async_success_hook += 1
@@ -122,6 +127,7 @@
             duration=duration,
             call_type=call_type,
         )
+
         for callback in litellm.service_callback:
             if callback == "prometheus_system":
                 await self.init_prometheus_services_logger_if_none()
@@ -139,8 +145,7 @@
                     end_time=end_time,
                     event_metadata=event_metadata,
                 )
-            elif callback == "otel":
-                from litellm.integrations.opentelemetry import OpenTelemetry
+            elif callback == "otel" or isinstance(callback, OpenTelemetry):
                 from litellm.proxy.proxy_server import open_telemetry_logger
 
                 await self.init_otel_logger_if_none()
@@ -214,6 +219,8 @@
         """
         - For counting if the redis, postgres call is unsuccessful
         """
+        from litellm.integrations.opentelemetry import OpenTelemetry
+
         if self.mock_testing:
             self.mock_testing_async_failure_hook += 1
@@ -246,8 +253,7 @@
                     end_time=end_time,
                     event_metadata=event_metadata,
                 )
-            elif callback == "otel":
-                from litellm.integrations.opentelemetry import OpenTelemetry
+            elif callback == "otel" or isinstance(callback, OpenTelemetry):
                 from litellm.proxy.proxy_server import open_telemetry_logger
 
                 await self.init_otel_logger_if_none()
diff --git a/litellm/caching/dual_cache.py b/litellm/caching/dual_cache.py
index ef168f65f..a55a1a577 100644
--- a/litellm/caching/dual_cache.py
+++ b/litellm/caching/dual_cache.py
@@ -8,8 +8,10 @@ Has 4 primary methods:
     - async_get_cache
 """
 
+import asyncio
 import time
 import traceback
+from concurrent.futures import ThreadPoolExecutor
 from typing import TYPE_CHECKING, Any, List, Optional, Tuple
 
 import litellm
@@ -40,6 +42,7 @@ class LimitedSizeOrderedDict(OrderedDict):
             self.popitem(last=False)
         super().__setitem__(key, value)
 
+
 class DualCache(BaseCache):
     """
     DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously.
@@ -53,7 +56,7 @@ class DualCache(BaseCache):
         redis_cache: Optional[RedisCache] = None,
         default_in_memory_ttl: Optional[float] = None,
         default_redis_ttl: Optional[float] = None,
-        default_redis_batch_cache_expiry: float = 1,
+        default_redis_batch_cache_expiry: Optional[float] = None,
         default_max_redis_batch_cache_size: int = 100,
     ) -> None:
         super().__init__()
@@ -64,7 +67,11 @@
         self.last_redis_batch_access_time = LimitedSizeOrderedDict(
             max_size=default_max_redis_batch_cache_size
         )
-        self.redis_batch_cache_expiry = default_redis_batch_cache_expiry
+        self.redis_batch_cache_expiry = (
+            default_redis_batch_cache_expiry
+            or litellm.default_redis_batch_cache_expiry
+            or 5
+        )
         self.default_in_memory_ttl = (
             default_in_memory_ttl or litellm.default_in_memory_ttl
         )
@@ -156,52 +163,33 @@
         local_only: bool = False,
         **kwargs,
     ):
+        received_args = locals()
+        received_args.pop("self")
+
+        def run_in_new_loop():
+            """Run the coroutine in a new event loop within this thread."""
+            new_loop = asyncio.new_event_loop()
+            try:
+                asyncio.set_event_loop(new_loop)
+                return new_loop.run_until_complete(
+                    self.async_batch_get_cache(**received_args)
+                )
+            finally:
+                new_loop.close()
+                asyncio.set_event_loop(None)
+
         try:
-            result = [None for _ in range(len(keys))]
-            if self.in_memory_cache is not None:
-                in_memory_result = self.in_memory_cache.batch_get_cache(keys, **kwargs)
+            # First, try to get the current event loop
+            _ = asyncio.get_running_loop()
+            # If we're already in an event loop, run in a separate thread
+            # to avoid nested event loop issues
+            with ThreadPoolExecutor(max_workers=1) as executor:
+                future = executor.submit(run_in_new_loop)
+                return future.result()
-            if in_memory_result is not None:
-                result = in_memory_result
-
-            if None in result and self.redis_cache is not None and local_only is False:
-                """
-                - for the none values in the result
-                - check the redis cache
-                """
-                # Track the last access time for these keys
-                current_time = time.time()
-                key_tuple = tuple(keys)
-
-                # Only hit Redis if the last access time was more than 5 seconds ago
-                if (
-                    key_tuple not in self.last_redis_batch_access_time
-                    or current_time - self.last_redis_batch_access_time[key_tuple]
-                    >= self.redis_batch_cache_expiry
-                ):
-
-                    sublist_keys = [
-                        key for key, value in zip(keys, result) if value is None
-                    ]
-                    # If not found in in-memory cache, try fetching from Redis
-                    redis_result = self.redis_cache.batch_get_cache(
-                        sublist_keys, parent_otel_span=parent_otel_span
-                    )
-                    if redis_result is not None:
-                        # Update in-memory cache with the value from Redis
-                        for key in redis_result:
-                            self.in_memory_cache.set_cache(
-                                key, redis_result[key], **kwargs
-                            )
-
-
-                    for key, value in redis_result.items():
-                        result[keys.index(key)] = value
-
-            print_verbose(f"async batch get cache: cache result: {result}")
-            return result
-        except Exception:
-            verbose_logger.error(traceback.format_exc())
+        except RuntimeError:
+            # No running event loop, we can safely run in this thread
+            return run_in_new_loop()
 
     async def async_get_cache(
         self,
@@ -244,6 +232,23 @@
         except Exception:
             verbose_logger.error(traceback.format_exc())
 
+    def get_redis_batch_keys(
+        self,
+        current_time: float,
+        keys: List[str],
+        result: List[Any],
+    ) -> List[str]:
+        sublist_keys = []
+        for key, value in zip(keys, result):
+            if value is None:
+                if (
+                    key not in self.last_redis_batch_access_time
+                    or current_time - self.last_redis_batch_access_time[key]
+                    >= self.redis_batch_cache_expiry
+                ):
+                    sublist_keys.append(key)
+        return sublist_keys
+
     async def async_batch_get_cache(
         self,
         keys: list,
@@ -266,25 +271,16 @@
                 - for the none values in the result
                 - check the redis cache
                 """
-                # Track the last access time for these keys
                 current_time = time.time()
-                key_tuple = tuple(keys)
+                sublist_keys = self.get_redis_batch_keys(current_time, keys, result)
 
                 # Only hit Redis if the last access time was more than 5 seconds ago
-                if (
-                    key_tuple not in self.last_redis_batch_access_time
-                    or current_time - self.last_redis_batch_access_time[key_tuple]
-                    >= self.redis_batch_cache_expiry
-                ):
-                    sublist_keys = [
-                        key for key, value in zip(keys, result) if value is None
-                    ]
+                if len(sublist_keys) > 0:
                     # If not found in in-memory cache, try fetching from Redis
                     redis_result = await self.redis_cache.async_batch_get_cache(
                         sublist_keys, parent_otel_span=parent_otel_span
                     )
-
                     if redis_result is not None:
                         # Update in-memory cache with the value from Redis
                         for key, value in redis_result.items():
@@ -292,6 +288,9 @@
                             await self.in_memory_cache.async_set_cache(
                                 key, redis_result[key], **kwargs
                             )
+                            # Update the last access time for each key fetched from Redis
+                            self.last_redis_batch_access_time[key] = current_time
+
                 for key, value in redis_result.items():
                     index = keys.index(key)
                     result[index] = value
diff --git a/litellm/caching/redis_cache.py b/litellm/caching/redis_cache.py
index 042a083a4..40bb49f44 100644
--- a/litellm/caching/redis_cache.py
+++ b/litellm/caching/redis_cache.py
@@ -732,7 +732,6 @@ class RedisCache(BaseCache):
         """
         Use Redis for bulk read operations
         """
-
         _redis_client = await self.init_async_client()
         key_value_dict = {}
         start_time = time.time()
diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404.html
deleted file mode 100644
index eb9614a33..000000000
--- a/litellm/proxy/_experimental/out/404.html
+++ /dev/null
@@ -1 +0,0 @@
-404: This page could not be found.LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/model_hub.html b/litellm/proxy/_experimental/out/model_hub.html
deleted file mode 100644
index 656b238bf..000000000
--- a/litellm/proxy/_experimental/out/model_hub.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html
deleted file mode 100644
index 4f05d163e..000000000
--- a/litellm/proxy/_experimental/out/onboarding.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 3271d11d9..b9315670a 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -3,8 +3,7 @@ model_list:
     litellm_params:
       model: claude-3-5-sonnet-20240620
      api_key: os.environ/ANTHROPIC_API_KEY
-      api_base: "http://0.0.0.0:8000"
-  - model_name: my-fallback-openai-model
+  - model_name: claude-3-5-sonnet-aihubmix
    litellm_params:
      model: openai/claude-3-5-sonnet-20240620
      input_cost_per_token: 0.000003 # 3$/M
@@ -15,9 +14,35 @@
    litellm_params:
      model: gemini/gemini-1.5-flash-002
 
+# litellm_settings:
+#   fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
+#   callbacks: ["otel", "prometheus"]
+#   default_redis_batch_cache_expiry: 10
+
+
 litellm_settings:
-  fallbacks: [{ "claude-3-5-sonnet-20240620": ["my-fallback-openai-model"] }]
-  callbacks: ["otel", "prometheus"]
+  cache: True
+  cache_params:
+    type: redis
+
+    # disable caching on the actual API call
+    supported_call_types: []
+
+    # see https://docs.litellm.ai/docs/proxy/prod#3-use-redis-porthost-password-not-redis_url
+    host: os.environ/REDIS_HOST
+    port: os.environ/REDIS_PORT
+    password: os.environ/REDIS_PASSWORD
+
+  # see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
+  # see https://docs.litellm.ai/docs/proxy/prometheus
+  callbacks: ['prometheus', 'otel']
+
+  # # see https://docs.litellm.ai/docs/proxy/logging#logging-proxy-inputoutput---sentry
+  failure_callback: ['sentry']
+  service_callback: ['prometheus_system']
+
+  # redact_user_api_key_info: true
+
 router_settings:
   routing_strategy: latency-based-routing
@@ -29,4 +54,19 @@ router_settings:
    ttl: 300
    redis_host: os.environ/REDIS_HOST
    redis_port: os.environ/REDIS_PORT
-    redis_password: os.environ/REDIS_PASSWORD
\ No newline at end of file
+    redis_password: os.environ/REDIS_PASSWORD
+
+# see https://docs.litellm.ai/docs/proxy/prod#1-use-this-configyaml
+general_settings:
+  master_key: os.environ/LITELLM_MASTER_KEY
+  database_url: os.environ/DATABASE_URL
+  disable_master_key_return: true
+  # alerting: ['slack', 'email']
+  alerting: ['email']
+
+  # Batch write spend updates every 60s
+  proxy_batch_write_at: 60
+
+  # see https://docs.litellm.ai/docs/proxy/caching#advanced---user-api-key-cache-ttl
+  # our api keys rarely change
+  user_api_key_cache_ttl: 3600
\ No newline at end of file
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index ae50326ca..9aebd9071 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1419,6 +1419,8 @@ class UserAPIKeyAuth(
     parent_otel_span: Optional[Span] = None
     rpm_limit_per_model: Optional[Dict[str, int]] = None
     tpm_limit_per_model: Optional[Dict[str, int]] = None
+    user_tpm_limit: Optional[int] = None
+    user_rpm_limit: Optional[int] = None
 
     @model_validator(mode="before")
     @classmethod
diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index 87a7b9ce2..b3f249d6f 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -9,6 +9,7 @@ Run checks for:
 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
 """
 import time
+import traceback
 from datetime import datetime
 from typing import TYPE_CHECKING, Any, List, Literal, Optional
 
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index f6c3de22c..995a95f79 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -10,6 +10,7 @@ Returns a UserAPIKeyAuth object if the API key is valid
 import asyncio
 import json
 import secrets
+import time
 import traceback
 from datetime import datetime, timedelta, timezone
 from typing import Optional, Tuple
@@ -44,6 +45,7 @@ from pydantic import BaseModel
 
 import litellm
 from litellm._logging import verbose_logger, verbose_proxy_logger
+from litellm._service_logger import ServiceLogging
 from litellm.proxy._types import *
 from litellm.proxy.auth.auth_checks import (
     _cache_key_object,
@@ -73,6 +75,10 @@ from litellm.proxy.auth.route_checks import RouteChecks
 from litellm.proxy.auth.service_account_checks import service_account_checks
 from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
 from litellm.proxy.utils import _to_ns
+from litellm.types.services import ServiceTypes
+
+user_api_key_service_logger_obj = ServiceLogging()  # used for tracking latency on OTEL
+
 
 api_key_header = APIKeyHeader(
     name=SpecialHeaders.openai_authorization.value,
@@ -214,7 +220,7 @@ async def user_api_key_auth(  # noqa: PLR0915
     )
 
     parent_otel_span: Optional[Span] = None
-
+    start_time = datetime.now()
     try:
         route: str = get_request_route(request=request)
         # get the request body
@@ -255,7 +261,7 @@ async def user_api_key_auth(  # noqa: PLR0915
         if open_telemetry_logger is not None:
             parent_otel_span = open_telemetry_logger.tracer.start_span(
                 name="Received Proxy Server Request",
-                start_time=_to_ns(datetime.now()),
+                start_time=_to_ns(start_time),
                 context=open_telemetry_logger.get_traceparent_from_header(
                     headers=request.headers
                 ),
@@ -1165,6 +1171,7 @@
                 parent_otel_span=parent_otel_span,
                 valid_token_dict=valid_token_dict,
                 route=route,
+                start_time=start_time,
             )
         else:
             raise Exception()
@@ -1219,31 +1226,39 @@ def _return_user_api_key_auth_obj(
     parent_otel_span: Optional[Span],
     valid_token_dict: dict,
     route: str,
+    start_time: datetime,
 ) -> UserAPIKeyAuth:
+    end_time = datetime.now()
+    user_api_key_service_logger_obj.service_success_hook(
+        service=ServiceTypes.AUTH,
+        call_type=route,
+        start_time=start_time,
+        end_time=end_time,
+        duration=end_time.timestamp() - start_time.timestamp(),
+        parent_otel_span=parent_otel_span,
+    )
     retrieved_user_role = (
         _get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
     )
+
+    user_api_key_kwargs = {
+        "api_key": api_key,
+        "parent_otel_span": parent_otel_span,
+        "user_role": retrieved_user_role,
+        **valid_token_dict,
+    }
+    if user_obj is not None:
+        user_api_key_kwargs.update(
+            user_tpm_limit=user_obj.tpm_limit,
+            user_rpm_limit=user_obj.rpm_limit,
+        )
     if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
-        return UserAPIKeyAuth(
-            api_key=api_key,
+        user_api_key_kwargs.update(
             user_role=LitellmUserRoles.PROXY_ADMIN,
-            parent_otel_span=parent_otel_span,
-            **valid_token_dict,
-        )
-    elif _has_user_setup_sso() and route in LiteLLMRoutes.sso_only_routes.value:
-        return UserAPIKeyAuth(
-            api_key=api_key,
-            user_role=retrieved_user_role,
-            parent_otel_span=parent_otel_span,
-            **valid_token_dict,
         )
+        return UserAPIKeyAuth(**user_api_key_kwargs)
     else:
-        return UserAPIKeyAuth(
-            api_key=api_key,
-            user_role=retrieved_user_role,
-            parent_otel_span=parent_otel_span,
-            **valid_token_dict,
-        )
+        return UserAPIKeyAuth(**user_api_key_kwargs)
 
 
 def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]):
diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index 75fbb68e2..4d2913912 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -1,7 +1,8 @@
+import asyncio
 import sys
 import traceback
 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, TypedDict, Union
 
 from fastapi import HTTPException
 from pydantic import BaseModel
@@ -29,6 +30,14 @@ else:
     InternalUsageCache = Any
 
 
+class CacheObject(TypedDict):
+    current_global_requests: Optional[dict]
+    request_count_api_key: Optional[dict]
+    request_count_user_id: Optional[dict]
+    request_count_team_id: Optional[dict]
+    request_count_end_user_id: Optional[dict]
+
+
 class _PROXY_MaxParallelRequestsHandler(CustomLogger):
     # Class variables or attributes
     def __init__(self, internal_usage_cache: InternalUsageCache):
@@ -51,14 +60,15 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         max_parallel_requests: int,
         tpm_limit: int,
         rpm_limit: int,
+        current: Optional[dict],
         request_count_api_key: str,
         rate_limit_type: Literal["user", "customer", "team"],
         values_to_update_in_cache: List[Tuple[Any, Any]],
     ):
-        current = await self.internal_usage_cache.async_get_cache(
-            key=request_count_api_key,
-            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
-        )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
+        # current = await self.internal_usage_cache.async_get_cache(
+        #     key=request_count_api_key,
+        #     litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+        # )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
         if current is None:
             if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
                 # base case
@@ -117,6 +127,44 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 headers={"retry-after": str(self.time_to_next_minute())},
             )
 
+    async def get_all_cache_objects(
+        self,
+        current_global_requests: Optional[str],
+        request_count_api_key: Optional[str],
+        request_count_user_id: Optional[str],
+        request_count_team_id: Optional[str],
+        request_count_end_user_id: Optional[str],
+        parent_otel_span: Optional[Span] = None,
+    ) -> CacheObject:
+        keys = [
+            current_global_requests,
+            request_count_api_key,
+            request_count_user_id,
+            request_count_team_id,
+            request_count_end_user_id,
+        ]
+        results = await self.internal_usage_cache.async_batch_get_cache(
+            keys=keys,
+            parent_otel_span=parent_otel_span,
+        )
+
+        if results is None:
+            return CacheObject(
+                current_global_requests=None,
+                request_count_api_key=None,
+                request_count_user_id=None,
+                request_count_team_id=None,
+                request_count_end_user_id=None,
+            )
+
+        return CacheObject(
+            current_global_requests=results[0],
+            request_count_api_key=results[1],
+            request_count_user_id=results[2],
+            request_count_team_id=results[3],
+            request_count_end_user_id=results[4],
+        )
+
     async def async_pre_call_hook(  # noqa: PLR0915
         self,
         user_api_key_dict: UserAPIKeyAuth,
@@ -149,6 +197,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         # Setup values
         # ------------
         new_val: Optional[dict] = None
+
         if global_max_parallel_requests is not None:
             # get value from cache
             _key = "global_max_parallel_requests"
@@ -179,15 +228,40 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         current_minute = datetime.now().strftime("%M")
         precise_minute = f"{current_date}-{current_hour}-{current_minute}"
 
+        cache_objects: CacheObject = await self.get_all_cache_objects(
+            current_global_requests=(
+                "global_max_parallel_requests"
+                if global_max_parallel_requests is not None
+                else None
+            ),
+            request_count_api_key=(
+                f"{api_key}::{precise_minute}::request_count"
+                if api_key is not None
+                else None
+            ),
+            request_count_user_id=(
+                f"{user_api_key_dict.user_id}::{precise_minute}::request_count"
+                if user_api_key_dict.user_id is not None
+                else None
+            ),
+            request_count_team_id=(
+                f"{user_api_key_dict.team_id}::{precise_minute}::request_count"
+                if user_api_key_dict.team_id is not None
+                else None
+            ),
+            request_count_end_user_id=(
+                f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
+                if user_api_key_dict.end_user_id is not None
+                else None
+            ),
+            parent_otel_span=user_api_key_dict.parent_otel_span,
+        )
         if api_key is not None:
             request_count_api_key = f"{api_key}::{precise_minute}::request_count"
             # CHECK IF REQUEST ALLOWED for key
-            current = await self.internal_usage_cache.async_get_cache(
-                key=request_count_api_key,
-                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
-            )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
+            current = cache_objects["request_count_api_key"]
             self.print_verbose(f"current: {current}")
             if (
                 max_parallel_requests == sys.maxsize
@@ -303,42 +377,28 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         # check if REQUEST ALLOWED for user_id
         user_id = user_api_key_dict.user_id
         if user_id is not None:
-            _user_id_rate_limits = await self.get_internal_user_object(
-                user_id=user_id,
+            user_tpm_limit = user_api_key_dict.user_tpm_limit
+            user_rpm_limit = user_api_key_dict.user_rpm_limit
+            if user_tpm_limit is None:
+                user_tpm_limit = sys.maxsize
+            if user_rpm_limit is None:
+                user_rpm_limit = sys.maxsize
+
+            request_count_api_key = f"{user_id}::{precise_minute}::request_count"
+            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
+            await self.check_key_in_limits(
                 user_api_key_dict=user_api_key_dict,
+                cache=cache,
+                data=data,
+                call_type=call_type,
+                max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for a user
+                current=cache_objects["request_count_user_id"],
+                request_count_api_key=request_count_api_key,
+                tpm_limit=user_tpm_limit,
+                rpm_limit=user_rpm_limit,
+                rate_limit_type="user",
+                values_to_update_in_cache=values_to_update_in_cache,
             )
-            # get user tpm/rpm limits
-            if (
-                _user_id_rate_limits is not None
-                and isinstance(_user_id_rate_limits, dict)
-                and (
-                    _user_id_rate_limits.get("tpm_limit", None) is not None
-                    or _user_id_rate_limits.get("rpm_limit", None) is not None
-                )
-            ):
-                user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None)
-                user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None)
-                if user_tpm_limit is None:
-                    user_tpm_limit = sys.maxsize
-                if user_rpm_limit is None:
-                    user_rpm_limit = sys.maxsize
-
-                # now do the same tpm/rpm checks
-                request_count_api_key = f"{user_id}::{precise_minute}::request_count"
-
-                # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
-                await self.check_key_in_limits(
-                    user_api_key_dict=user_api_key_dict,
-                    cache=cache,
-                    data=data,
-                    call_type=call_type,
-                    max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for a user
-                    request_count_api_key=request_count_api_key,
-                    tpm_limit=user_tpm_limit,
-                    rpm_limit=user_rpm_limit,
-                    rate_limit_type="user",
-                    values_to_update_in_cache=values_to_update_in_cache,
-                )
 
         # TEAM RATE LIMITS
         ## get team tpm/rpm limits
@@ -352,9 +412,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             if team_rpm_limit is None:
                 team_rpm_limit = sys.maxsize
 
-            # now do the same tpm/rpm checks
             request_count_api_key = f"{team_id}::{precise_minute}::request_count"
-            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
 
             await self.check_key_in_limits(
                 user_api_key_dict=user_api_key_dict,
@@ -362,6 +420,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 data=data,
                 call_type=call_type,
                 max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for a team
+                current=cache_objects["request_count_team_id"],
                 request_count_api_key=request_count_api_key,
                 tpm_limit=team_tpm_limit,
                 rpm_limit=team_rpm_limit,
@@ -397,16 +456,19 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 call_type=call_type,
                 max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for an End-User
                 request_count_api_key=request_count_api_key,
+                current=cache_objects["request_count_end_user_id"],
                 tpm_limit=end_user_tpm_limit,
                 rpm_limit=end_user_rpm_limit,
                 rate_limit_type="customer",
                 values_to_update_in_cache=values_to_update_in_cache,
             )
 
-        await self.internal_usage_cache.async_batch_set_cache(
-            cache_list=values_to_update_in_cache,
-            ttl=60,
-            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+        asyncio.create_task(
+            self.internal_usage_cache.async_batch_set_cache(
+                cache_list=values_to_update_in_cache,
+                ttl=60,
+                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+            )  # don't block execution for cache updates
         )
 
         return
@@ -481,8 +543,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 litellm_parent_otel_span=litellm_parent_otel_span,
             ) or {
                 "current_requests": 1,
-                "current_tpm": total_tokens,
-                "current_rpm": 1,
+                "current_tpm": 0,
+                "current_rpm": 0,
             }
 
             new_val = {
@@ -517,8 +579,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 litellm_parent_otel_span=litellm_parent_otel_span,
             ) or {
                 "current_requests": 1,
-                "current_tpm": total_tokens,
-                "current_rpm": 1,
+                "current_tpm": 0,
+                "current_rpm": 0,
             }
 
             new_val = {
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 8919da978..82831b3b2 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -262,6 +262,18 @@ class InternalUsageCache:
             **kwargs,
         )
 
+    async def async_batch_get_cache(
+        self,
+        keys: list,
+        parent_otel_span: Optional[Span] = None,
+        local_only: bool = False,
+    ):
+        return await self.dual_cache.async_batch_get_cache(
+            keys=keys,
+            parent_otel_span=parent_otel_span,
+            local_only=local_only,
+        )
+
     async def async_increment_cache(
         self,
         key,
@@ -442,6 +454,8 @@ class ProxyLogging:
                 litellm._async_success_callback.append(callback)  # type: ignore
             if callback not in litellm._async_failure_callback:
                 litellm._async_failure_callback.append(callback)  # type: ignore
+            if callback not in litellm.service_callback:
+                litellm.service_callback.append(callback)  # type: ignore
 
         if (
             len(litellm.input_callback) > 0
diff --git a/litellm/types/services.py b/litellm/types/services.py
index 08259c741..5f690f328 100644
--- a/litellm/types/services.py
+++ b/litellm/types/services.py
@@ -15,6 +15,7 @@ class ServiceTypes(str, enum.Enum):
     BATCH_WRITE_TO_DB = "batch_write_to_db"
     LITELLM = "self"
     ROUTER = "router"
+    AUTH = "auth"
 
 
 class ServiceLoggerPayload(BaseModel):
diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py
index 1116840b5..479c1204e 100644
--- a/tests/local_testing/test_caching.py
+++ b/tests/local_testing/test_caching.py
@@ -59,12 +59,15 @@ async def test_dual_cache_async_batch_get_cache():
     redis_cache = RedisCache()  # get credentials from environment
     dual_cache = DualCache(in_memory_cache=in_memory_cache, redis_cache=redis_cache)
 
-    in_memory_cache.set_cache(key="test_value", value="hello world")
+    with patch.object(
+        dual_cache.redis_cache, "async_batch_get_cache", new=AsyncMock()
+    ) as mock_redis_cache:
+        mock_redis_cache.return_value = {"test_value_2": None, "test_value": "hello"}
 
-    result = await dual_cache.async_batch_get_cache(keys=["test_value", "test_value_2"])
+        await dual_cache.async_batch_get_cache(keys=["test_value", "test_value_2"])
+        await dual_cache.async_batch_get_cache(keys=["test_value", "test_value_2"])
 
-    assert result[0] == "hello world"
-    assert result[1] == None
+        assert mock_redis_cache.call_count == 1
 
 
 def test_dual_cache_batch_get_cache():
diff --git a/tests/local_testing/test_parallel_request_limiter.py b/tests/local_testing/test_parallel_request_limiter.py
index d0a9f9843..9bb2589aa 100644
--- a/tests/local_testing/test_parallel_request_limiter.py
+++ b/tests/local_testing/test_parallel_request_limiter.py
@@ -96,6 +96,7 @@ async def test_pre_call_hook():
             key=request_count_api_key
         )
     )
+    await asyncio.sleep(1)
     assert (
         parallel_request_handler.internal_usage_cache.get_cache(
             key=request_count_api_key
@@ -110,6 +111,7 @@ async def test_pre_call_hook_rpm_limits():
     Test if error raised on hitting rpm limits
     """
     _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
     user_api_key_dict = UserAPIKeyAuth(
         api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1
     )
@@ -152,6 +154,7 @@ async def test_pre_call_hook_rpm_limits_retry_after():
     Test if rate limit error, returns 'retry_after'
     """
     _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
     user_api_key_dict = UserAPIKeyAuth(
         api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1
     )
@@ -251,6 +254,7 @@ async def test_pre_call_hook_tpm_limits():
     Test if error raised on hitting tpm limits
     """
     _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
     user_api_key_dict = UserAPIKeyAuth(
         api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=10
     )
@@ -306,9 +310,9 @@ async def test_pre_call_hook_user_tpm_limits():
     local_cache.set_cache(key=user_id, value=user_obj)
 
     _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
     user_api_key_dict = UserAPIKeyAuth(
-        api_key=_api_key,
-        user_id=user_id,
+        api_key=_api_key, user_id=user_id, user_rpm_limit=10, user_tpm_limit=9
     )
     res = dict(user_api_key_dict)
     print("dict user", res)
@@ -372,7 +376,7 @@ async def test_success_call_hook():
     current_minute = datetime.now().strftime("%M")
     precise_minute = f"{current_date}-{current_hour}-{current_minute}"
     request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
-
+    await asyncio.sleep(1)
     assert (
         parallel_request_handler.internal_usage_cache.get_cache(
             key=request_count_api_key
@@ -416,7 +420,7 @@ async def test_failure_call_hook():
     current_minute = datetime.now().strftime("%M")
     precise_minute = f"{current_date}-{current_hour}-{current_minute}"
     request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
f"{_api_key}::{precise_minute}::request_count" - + await asyncio.sleep(1) assert ( parallel_request_handler.internal_usage_cache.get_cache( key=request_count_api_key @@ -497,7 +501,7 @@ async def test_normal_router_call(): current_minute = datetime.now().strftime("%M") precise_minute = f"{current_date}-{current_hour}-{current_minute}" request_count_api_key = f"{_api_key}::{precise_minute}::request_count" - + await asyncio.sleep(1) assert ( parallel_request_handler.internal_usage_cache.get_cache( key=request_count_api_key @@ -579,7 +583,7 @@ async def test_normal_router_tpm_limit(): precise_minute = f"{current_date}-{current_hour}-{current_minute}" request_count_api_key = f"{_api_key}::{precise_minute}::request_count" print("Test: Checking current_requests for precise_minute=", precise_minute) - + await asyncio.sleep(1) assert ( parallel_request_handler.internal_usage_cache.get_cache( key=request_count_api_key @@ -658,7 +662,7 @@ async def test_streaming_router_call(): current_minute = datetime.now().strftime("%M") precise_minute = f"{current_date}-{current_hour}-{current_minute}" request_count_api_key = f"{_api_key}::{precise_minute}::request_count" - + await asyncio.sleep(1) assert ( parallel_request_handler.internal_usage_cache.get_cache( key=request_count_api_key @@ -736,7 +740,7 @@ async def test_streaming_router_tpm_limit(): current_minute = datetime.now().strftime("%M") precise_minute = f"{current_date}-{current_hour}-{current_minute}" request_count_api_key = f"{_api_key}::{precise_minute}::request_count" - + await asyncio.sleep(1) assert ( parallel_request_handler.internal_usage_cache.get_cache( key=request_count_api_key @@ -814,7 +818,7 @@ async def test_bad_router_call(): current_minute = datetime.now().strftime("%M") precise_minute = f"{current_date}-{current_hour}-{current_minute}" request_count_api_key = f"{_api_key}::{precise_minute}::request_count" - + await asyncio.sleep(1) assert ( parallel_request_handler.internal_usage_cache.get_cache( # type: ignore key=request_count_api_key @@ -890,7 +894,7 @@ async def test_bad_router_tpm_limit(): current_minute = datetime.now().strftime("%M") precise_minute = f"{current_date}-{current_hour}-{current_minute}" request_count_api_key = f"{_api_key}::{precise_minute}::request_count" - + await asyncio.sleep(1) assert ( parallel_request_handler.internal_usage_cache.get_cache( key=request_count_api_key @@ -979,7 +983,7 @@ async def test_bad_router_tpm_limit_per_model(): current_minute = datetime.now().strftime("%M") precise_minute = f"{current_date}-{current_hour}-{current_minute}" request_count_api_key = f"{_api_key}::{model}::{precise_minute}::request_count" - + await asyncio.sleep(1) print( "internal usage cache: ", parallel_request_handler.internal_usage_cache.dual_cache.in_memory_cache.cache_dict, diff --git a/tests/local_testing/test_user_api_key_auth.py b/tests/local_testing/test_user_api_key_auth.py index 668d4cab4..36bb71eb9 100644 --- a/tests/local_testing/test_user_api_key_auth.py +++ b/tests/local_testing/test_user_api_key_auth.py @@ -139,6 +139,7 @@ async def test_check_blocked_team(): def test_returned_user_api_key_auth(user_role, expected_role): from litellm.proxy._types import LiteLLM_UserTable, LitellmUserRoles from litellm.proxy.auth.user_api_key_auth import _return_user_api_key_auth_obj + from datetime import datetime new_obj = _return_user_api_key_auth_obj( user_obj=LiteLLM_UserTable( @@ -148,6 +149,7 @@ def test_returned_user_api_key_auth(user_role, expected_role): parent_otel_span=None, valid_token_dict={}, 
route="/chat/completion", + start_time=datetime.now(), ) assert new_obj.user_role == expected_role