diff --git a/litellm/__init__.py b/litellm/__init__.py
index 6a5898ddb..eb59f6d6b 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -173,6 +173,7 @@ cache: Optional[Cache] = (
)
default_in_memory_ttl: Optional[float] = None
default_redis_ttl: Optional[float] = None
+default_redis_batch_cache_expiry: Optional[float] = None  # seconds a batched Redis lookup is reused by DualCache before Redis is queried again
model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0 # set the max budget across all providers
diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py
index 4db645e66..d4aad68bb 100644
--- a/litellm/_service_logger.py
+++ b/litellm/_service_logger.py
@@ -13,9 +13,13 @@ from .types.services import ServiceLoggerPayload, ServiceTypes
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
+ from litellm.integrations.opentelemetry import OpenTelemetry
+
Span = _Span
+ OTELClass = OpenTelemetry
else:
Span = Any
+ OTELClass = Any
class ServiceLogging(CustomLogger):
@@ -111,6 +115,7 @@ class ServiceLogging(CustomLogger):
"""
- For counting if the redis, postgres call is successful
"""
+ from litellm.integrations.opentelemetry import OpenTelemetry
if self.mock_testing:
self.mock_testing_async_success_hook += 1
@@ -122,6 +127,7 @@ class ServiceLogging(CustomLogger):
duration=duration,
call_type=call_type,
)
+
for callback in litellm.service_callback:
if callback == "prometheus_system":
await self.init_prometheus_services_logger_if_none()
@@ -139,8 +145,7 @@ class ServiceLogging(CustomLogger):
end_time=end_time,
event_metadata=event_metadata,
)
- elif callback == "otel":
- from litellm.integrations.opentelemetry import OpenTelemetry
+ elif callback == "otel" or isinstance(callback, OpenTelemetry):
from litellm.proxy.proxy_server import open_telemetry_logger
await self.init_otel_logger_if_none()
@@ -214,6 +219,8 @@ class ServiceLogging(CustomLogger):
"""
- For counting if the redis, postgres call is unsuccessful
"""
+ from litellm.integrations.opentelemetry import OpenTelemetry
+
if self.mock_testing:
self.mock_testing_async_failure_hook += 1
@@ -246,8 +253,7 @@ class ServiceLogging(CustomLogger):
end_time=end_time,
event_metadata=event_metadata,
)
- elif callback == "otel":
- from litellm.integrations.opentelemetry import OpenTelemetry
+ elif callback == "otel" or isinstance(callback, OpenTelemetry):
from litellm.proxy.proxy_server import open_telemetry_logger
await self.init_otel_logger_if_none()
diff --git a/litellm/caching/dual_cache.py b/litellm/caching/dual_cache.py
index ef168f65f..a55a1a577 100644
--- a/litellm/caching/dual_cache.py
+++ b/litellm/caching/dual_cache.py
@@ -8,8 +8,10 @@ Has 4 primary methods:
- async_get_cache
"""
+import asyncio
import time
import traceback
+from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
import litellm
@@ -40,6 +42,7 @@ class LimitedSizeOrderedDict(OrderedDict):
self.popitem(last=False)
super().__setitem__(key, value)
+
class DualCache(BaseCache):
"""
DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously.
@@ -53,7 +56,7 @@ class DualCache(BaseCache):
redis_cache: Optional[RedisCache] = None,
default_in_memory_ttl: Optional[float] = None,
default_redis_ttl: Optional[float] = None,
- default_redis_batch_cache_expiry: float = 1,
+ default_redis_batch_cache_expiry: Optional[float] = None,
default_max_redis_batch_cache_size: int = 100,
) -> None:
super().__init__()
@@ -64,7 +67,11 @@ class DualCache(BaseCache):
self.last_redis_batch_access_time = LimitedSizeOrderedDict(
max_size=default_max_redis_batch_cache_size
)
- self.redis_batch_cache_expiry = default_redis_batch_cache_expiry
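+        # Resolution order: explicit constructor arg, then litellm.default_redis_batch_cache_expiry, then 5 seconds.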
+ self.redis_batch_cache_expiry = (
+ default_redis_batch_cache_expiry
+ or litellm.default_redis_batch_cache_expiry
+ or 5
+ )
self.default_in_memory_ttl = (
default_in_memory_ttl or litellm.default_in_memory_ttl
)
@@ -156,52 +163,33 @@ class DualCache(BaseCache):
local_only: bool = False,
**kwargs,
):
+ received_args = locals()
+ received_args.pop("self")
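+        # Capture this sync call's arguments so they can be forwarded unchanged to the async implementation.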
+
+ def run_in_new_loop():
+ """Run the coroutine in a new event loop within this thread."""
+ new_loop = asyncio.new_event_loop()
+ try:
+ asyncio.set_event_loop(new_loop)
+ return new_loop.run_until_complete(
+ self.async_batch_get_cache(**received_args)
+ )
+ finally:
+ new_loop.close()
+ asyncio.set_event_loop(None)
+
try:
- result = [None for _ in range(len(keys))]
- if self.in_memory_cache is not None:
- in_memory_result = self.in_memory_cache.batch_get_cache(keys, **kwargs)
+ # First, try to get the current event loop
+ _ = asyncio.get_running_loop()
+ # If we're already in an event loop, run in a separate thread
+ # to avoid nested event loop issues
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(run_in_new_loop)
+ return future.result()
- if in_memory_result is not None:
- result = in_memory_result
-
- if None in result and self.redis_cache is not None and local_only is False:
- """
- - for the none values in the result
- - check the redis cache
- """
- # Track the last access time for these keys
- current_time = time.time()
- key_tuple = tuple(keys)
-
- # Only hit Redis if the last access time was more than 5 seconds ago
- if (
- key_tuple not in self.last_redis_batch_access_time
- or current_time - self.last_redis_batch_access_time[key_tuple]
- >= self.redis_batch_cache_expiry
- ):
-
- sublist_keys = [
- key for key, value in zip(keys, result) if value is None
- ]
- # If not found in in-memory cache, try fetching from Redis
- redis_result = self.redis_cache.batch_get_cache(
- sublist_keys, parent_otel_span=parent_otel_span
- )
- if redis_result is not None:
- # Update in-memory cache with the value from Redis
- for key in redis_result:
- self.in_memory_cache.set_cache(
- key, redis_result[key], **kwargs
- )
-
-
- for key, value in redis_result.items():
- result[keys.index(key)] = value
-
- print_verbose(f"async batch get cache: cache result: {result}")
- return result
- except Exception:
- verbose_logger.error(traceback.format_exc())
+ except RuntimeError:
+ # No running event loop, we can safely run in this thread
+ return run_in_new_loop()
async def async_get_cache(
self,
@@ -244,6 +232,23 @@ class DualCache(BaseCache):
except Exception:
verbose_logger.error(traceback.format_exc())
+ def get_redis_batch_keys(
+ self,
+ current_time: float,
+ keys: List[str],
+ result: List[Any],
+ ) -> List[str]:
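+        # Return only the keys still unresolved in-memory whose last Redis batch lookup is older than the expiry window.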
+ sublist_keys = []
+ for key, value in zip(keys, result):
+ if value is None:
+ if (
+ key not in self.last_redis_batch_access_time
+ or current_time - self.last_redis_batch_access_time[key]
+ >= self.redis_batch_cache_expiry
+ ):
+ sublist_keys.append(key)
+ return sublist_keys
+
async def async_batch_get_cache(
self,
keys: list,
@@ -266,25 +271,16 @@ class DualCache(BaseCache):
- for the none values in the result
- check the redis cache
"""
- # Track the last access time for these keys
current_time = time.time()
- key_tuple = tuple(keys)
+ sublist_keys = self.get_redis_batch_keys(current_time, keys, result)
# Only hit Redis if the last access time was more than 5 seconds ago
- if (
- key_tuple not in self.last_redis_batch_access_time
- or current_time - self.last_redis_batch_access_time[key_tuple]
- >= self.redis_batch_cache_expiry
- ):
- sublist_keys = [
- key for key, value in zip(keys, result) if value is None
- ]
+ if len(sublist_keys) > 0:
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_batch_get_cache(
sublist_keys, parent_otel_span=parent_otel_span
)
-
if redis_result is not None:
# Update in-memory cache with the value from Redis
for key, value in redis_result.items():
@@ -292,6 +288,9 @@ class DualCache(BaseCache):
await self.in_memory_cache.async_set_cache(
key, redis_result[key], **kwargs
)
+ # Update the last access time for each key fetched from Redis
+ self.last_redis_batch_access_time[key] = current_time
+
for key, value in redis_result.items():
index = keys.index(key)
result[index] = value
diff --git a/litellm/caching/redis_cache.py b/litellm/caching/redis_cache.py
index 042a083a4..40bb49f44 100644
--- a/litellm/caching/redis_cache.py
+++ b/litellm/caching/redis_cache.py
@@ -732,7 +732,6 @@ class RedisCache(BaseCache):
"""
Use Redis for bulk read operations
"""
-
_redis_client = await self.init_async_client()
key_value_dict = {}
start_time = time.time()
diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404.html
deleted file mode 100644
index eb9614a33..000000000
--- a/litellm/proxy/_experimental/out/404.html
+++ /dev/null
@@ -1 +0,0 @@
-404: This page could not be found.LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/model_hub.html b/litellm/proxy/_experimental/out/model_hub.html
deleted file mode 100644
index 656b238bf..000000000
--- a/litellm/proxy/_experimental/out/model_hub.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html
deleted file mode 100644
index 4f05d163e..000000000
--- a/litellm/proxy/_experimental/out/onboarding.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 3271d11d9..b9315670a 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -3,8 +3,7 @@ model_list:
litellm_params:
model: claude-3-5-sonnet-20240620
api_key: os.environ/ANTHROPIC_API_KEY
- api_base: "http://0.0.0.0:8000"
- - model_name: my-fallback-openai-model
+ - model_name: claude-3-5-sonnet-aihubmix
litellm_params:
model: openai/claude-3-5-sonnet-20240620
input_cost_per_token: 0.000003 # 3$/M
@@ -15,9 +14,35 @@ model_list:
litellm_params:
model: gemini/gemini-1.5-flash-002
+# litellm_settings:
+# fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
+# callbacks: ["otel", "prometheus"]
+# default_redis_batch_cache_expiry: 10
+
+
litellm_settings:
- fallbacks: [{ "claude-3-5-sonnet-20240620": ["my-fallback-openai-model"] }]
- callbacks: ["otel", "prometheus"]
+ cache: True
+ cache_params:
+ type: redis
+
+ # disable caching on the actual API call
+ supported_call_types: []
+
+ # see https://docs.litellm.ai/docs/proxy/prod#3-use-redis-porthost-password-not-redis_url
+ host: os.environ/REDIS_HOST
+ port: os.environ/REDIS_PORT
+ password: os.environ/REDIS_PASSWORD
+
+ # see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
+ # see https://docs.litellm.ai/docs/proxy/prometheus
+ callbacks: ['prometheus', 'otel']
+
+  # see https://docs.litellm.ai/docs/proxy/logging#logging-proxy-inputoutput---sentry
+ failure_callback: ['sentry']
+ service_callback: ['prometheus_system']
+
+ # redact_user_api_key_info: true
+
router_settings:
routing_strategy: latency-based-routing
@@ -29,4 +54,19 @@ router_settings:
ttl: 300
redis_host: os.environ/REDIS_HOST
redis_port: os.environ/REDIS_PORT
- redis_password: os.environ/REDIS_PASSWORD
\ No newline at end of file
+ redis_password: os.environ/REDIS_PASSWORD
+
+# see https://docs.litellm.ai/docs/proxy/prod#1-use-this-configyaml
+general_settings:
+ master_key: os.environ/LITELLM_MASTER_KEY
+ database_url: os.environ/DATABASE_URL
+ disable_master_key_return: true
+ # alerting: ['slack', 'email']
+ alerting: ['email']
+
+ # Batch write spend updates every 60s
+ proxy_batch_write_at: 60
+
+ # see https://docs.litellm.ai/docs/proxy/caching#advanced---user-api-key-cache-ttl
+ # our api keys rarely change
+ user_api_key_cache_ttl: 3600
\ No newline at end of file
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index ae50326ca..9aebd9071 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1419,6 +1419,8 @@ class UserAPIKeyAuth(
parent_otel_span: Optional[Span] = None
rpm_limit_per_model: Optional[Dict[str, int]] = None
tpm_limit_per_model: Optional[Dict[str, int]] = None
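+    # user-level tpm/rpm limits, populated at auth time so downstream hooks don't need a separate user lookup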
+ user_tpm_limit: Optional[int] = None
+ user_rpm_limit: Optional[int] = None
@model_validator(mode="before")
@classmethod
diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index 87a7b9ce2..b3f249d6f 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -9,6 +9,7 @@ Run checks for:
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
import time
+import traceback
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Literal, Optional
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index f6c3de22c..995a95f79 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -10,6 +10,7 @@ Returns a UserAPIKeyAuth object if the API key is valid
import asyncio
import json
import secrets
+import time
import traceback
from datetime import datetime, timedelta, timezone
from typing import Optional, Tuple
@@ -44,6 +45,7 @@ from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
+from litellm._service_logger import ServiceLogging
from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import (
_cache_key_object,
@@ -73,6 +75,10 @@ from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.auth.service_account_checks import service_account_checks
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.utils import _to_ns
+from litellm.types.services import ServiceTypes
+
+user_api_key_service_logger_obj = ServiceLogging() # used for tracking latency on OTEL
+
api_key_header = APIKeyHeader(
name=SpecialHeaders.openai_authorization.value,
@@ -214,7 +220,7 @@ async def user_api_key_auth( # noqa: PLR0915
)
parent_otel_span: Optional[Span] = None
-
+ start_time = datetime.now()
try:
route: str = get_request_route(request=request)
# get the request body
@@ -255,7 +261,7 @@ async def user_api_key_auth( # noqa: PLR0915
if open_telemetry_logger is not None:
parent_otel_span = open_telemetry_logger.tracer.start_span(
name="Received Proxy Server Request",
- start_time=_to_ns(datetime.now()),
+ start_time=_to_ns(start_time),
context=open_telemetry_logger.get_traceparent_from_header(
headers=request.headers
),
@@ -1165,6 +1171,7 @@ async def user_api_key_auth( # noqa: PLR0915
parent_otel_span=parent_otel_span,
valid_token_dict=valid_token_dict,
route=route,
+ start_time=start_time,
)
else:
raise Exception()
@@ -1219,31 +1226,39 @@ def _return_user_api_key_auth_obj(
parent_otel_span: Optional[Span],
valid_token_dict: dict,
route: str,
+ start_time: datetime,
) -> UserAPIKeyAuth:
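+    # Log auth latency through the service logger so configured service callbacks (e.g. Prometheus / OTEL) can record it.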
+ end_time = datetime.now()
+ user_api_key_service_logger_obj.service_success_hook(
+ service=ServiceTypes.AUTH,
+ call_type=route,
+ start_time=start_time,
+ end_time=end_time,
+ duration=end_time.timestamp() - start_time.timestamp(),
+ parent_otel_span=parent_otel_span,
+ )
retrieved_user_role = (
_get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
)
+
+ user_api_key_kwargs = {
+ "api_key": api_key,
+ "parent_otel_span": parent_otel_span,
+ "user_role": retrieved_user_role,
+ **valid_token_dict,
+ }
+ if user_obj is not None:
+ user_api_key_kwargs.update(
+ user_tpm_limit=user_obj.tpm_limit,
+ user_rpm_limit=user_obj.rpm_limit,
+ )
if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
- return UserAPIKeyAuth(
- api_key=api_key,
+ user_api_key_kwargs.update(
user_role=LitellmUserRoles.PROXY_ADMIN,
- parent_otel_span=parent_otel_span,
- **valid_token_dict,
- )
- elif _has_user_setup_sso() and route in LiteLLMRoutes.sso_only_routes.value:
- return UserAPIKeyAuth(
- api_key=api_key,
- user_role=retrieved_user_role,
- parent_otel_span=parent_otel_span,
- **valid_token_dict,
)
+ return UserAPIKeyAuth(**user_api_key_kwargs)
else:
- return UserAPIKeyAuth(
- api_key=api_key,
- user_role=retrieved_user_role,
- parent_otel_span=parent_otel_span,
- **valid_token_dict,
- )
+ return UserAPIKeyAuth(**user_api_key_kwargs)
def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]):
diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index 75fbb68e2..4d2913912 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -1,7 +1,8 @@
+import asyncio
import sys
import traceback
from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, TypedDict, Union
from fastapi import HTTPException
from pydantic import BaseModel
@@ -29,6 +30,14 @@ else:
InternalUsageCache = Any
+class CacheObject(TypedDict):
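+    # Snapshot of the per-scope rate-limit counters read in one batched cache call; None means no cached entry (or scope not applicable).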
+ current_global_requests: Optional[dict]
+ request_count_api_key: Optional[dict]
+ request_count_user_id: Optional[dict]
+ request_count_team_id: Optional[dict]
+ request_count_end_user_id: Optional[dict]
+
+
class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# Class variables or attributes
def __init__(self, internal_usage_cache: InternalUsageCache):
@@ -51,14 +60,15 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
max_parallel_requests: int,
tpm_limit: int,
rpm_limit: int,
+ current: Optional[dict],
request_count_api_key: str,
rate_limit_type: Literal["user", "customer", "team"],
values_to_update_in_cache: List[Tuple[Any, Any]],
):
- current = await self.internal_usage_cache.async_get_cache(
- key=request_count_api_key,
- litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
- ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
+        # 'current' is now passed in by the caller, which fetches all counters in one batched cache read
+        # instead of this hook issuing its own async_get_cache lookup.
+        # Expected shape: {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
if current is None:
if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
# base case
@@ -117,6 +127,44 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
headers={"retry-after": str(self.time_to_next_minute())},
)
+ async def get_all_cache_objects(
+ self,
+ current_global_requests: Optional[str],
+ request_count_api_key: Optional[str],
+ request_count_user_id: Optional[str],
+ request_count_team_id: Optional[str],
+ request_count_end_user_id: Optional[str],
+ parent_otel_span: Optional[Span] = None,
+ ) -> CacheObject:
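+        # Fetch all five counter values with a single batched cache call instead of five separate lookups.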
+ keys = [
+ current_global_requests,
+ request_count_api_key,
+ request_count_user_id,
+ request_count_team_id,
+ request_count_end_user_id,
+ ]
+ results = await self.internal_usage_cache.async_batch_get_cache(
+ keys=keys,
+ parent_otel_span=parent_otel_span,
+ )
+
+ if results is None:
+ return CacheObject(
+ current_global_requests=None,
+ request_count_api_key=None,
+ request_count_user_id=None,
+ request_count_team_id=None,
+ request_count_end_user_id=None,
+ )
+
+ return CacheObject(
+ current_global_requests=results[0],
+ request_count_api_key=results[1],
+ request_count_user_id=results[2],
+ request_count_team_id=results[3],
+ request_count_end_user_id=results[4],
+ )
+
async def async_pre_call_hook( # noqa: PLR0915
self,
user_api_key_dict: UserAPIKeyAuth,
@@ -149,6 +197,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# Setup values
# ------------
new_val: Optional[dict] = None
+
if global_max_parallel_requests is not None:
# get value from cache
_key = "global_max_parallel_requests"
@@ -179,15 +228,40 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
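+        # Read every per-scope counter (global, key, user, team, end-user) up front in one batched call;
+        # the checks below use this snapshot instead of individual cache reads.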
+ cache_objects: CacheObject = await self.get_all_cache_objects(
+ current_global_requests=(
+ "global_max_parallel_requests"
+ if global_max_parallel_requests is not None
+ else None
+ ),
+ request_count_api_key=(
+ f"{api_key}::{precise_minute}::request_count"
+ if api_key is not None
+ else None
+ ),
+ request_count_user_id=(
+ f"{user_api_key_dict.user_id}::{precise_minute}::request_count"
+ if user_api_key_dict.user_id is not None
+ else None
+ ),
+ request_count_team_id=(
+ f"{user_api_key_dict.team_id}::{precise_minute}::request_count"
+ if user_api_key_dict.team_id is not None
+ else None
+ ),
+ request_count_end_user_id=(
+ f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
+ if user_api_key_dict.end_user_id is not None
+ else None
+ ),
+ parent_otel_span=user_api_key_dict.parent_otel_span,
+ )
if api_key is not None:
request_count_api_key = f"{api_key}::{precise_minute}::request_count"
# CHECK IF REQUEST ALLOWED for key
- current = await self.internal_usage_cache.async_get_cache(
- key=request_count_api_key,
- litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
- ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
+ current = cache_objects["request_count_api_key"]
self.print_verbose(f"current: {current}")
if (
max_parallel_requests == sys.maxsize
@@ -303,42 +377,28 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# check if REQUEST ALLOWED for user_id
user_id = user_api_key_dict.user_id
if user_id is not None:
- _user_id_rate_limits = await self.get_internal_user_object(
- user_id=user_id,
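+            # user limits now come from the auth object (set during user_api_key_auth) instead of a per-request user lookup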
+ user_tpm_limit = user_api_key_dict.user_tpm_limit
+ user_rpm_limit = user_api_key_dict.user_rpm_limit
+ if user_tpm_limit is None:
+ user_tpm_limit = sys.maxsize
+ if user_rpm_limit is None:
+ user_rpm_limit = sys.maxsize
+
+ request_count_api_key = f"{user_id}::{precise_minute}::request_count"
+ # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
+ await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
+ cache=cache,
+ data=data,
+ call_type=call_type,
+ max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
+ current=cache_objects["request_count_user_id"],
+ request_count_api_key=request_count_api_key,
+ tpm_limit=user_tpm_limit,
+ rpm_limit=user_rpm_limit,
+ rate_limit_type="user",
+ values_to_update_in_cache=values_to_update_in_cache,
)
- # get user tpm/rpm limits
- if (
- _user_id_rate_limits is not None
- and isinstance(_user_id_rate_limits, dict)
- and (
- _user_id_rate_limits.get("tpm_limit", None) is not None
- or _user_id_rate_limits.get("rpm_limit", None) is not None
- )
- ):
- user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None)
- user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None)
- if user_tpm_limit is None:
- user_tpm_limit = sys.maxsize
- if user_rpm_limit is None:
- user_rpm_limit = sys.maxsize
-
- # now do the same tpm/rpm checks
- request_count_api_key = f"{user_id}::{precise_minute}::request_count"
-
- # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
- await self.check_key_in_limits(
- user_api_key_dict=user_api_key_dict,
- cache=cache,
- data=data,
- call_type=call_type,
- max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
- request_count_api_key=request_count_api_key,
- tpm_limit=user_tpm_limit,
- rpm_limit=user_rpm_limit,
- rate_limit_type="user",
- values_to_update_in_cache=values_to_update_in_cache,
- )
# TEAM RATE LIMITS
## get team tpm/rpm limits
@@ -352,9 +412,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
if team_rpm_limit is None:
team_rpm_limit = sys.maxsize
- # now do the same tpm/rpm checks
request_count_api_key = f"{team_id}::{precise_minute}::request_count"
-
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
@@ -362,6 +420,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
data=data,
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a team
+ current=cache_objects["request_count_team_id"],
request_count_api_key=request_count_api_key,
tpm_limit=team_tpm_limit,
rpm_limit=team_rpm_limit,
@@ -397,16 +456,19 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for an End-User
request_count_api_key=request_count_api_key,
+ current=cache_objects["request_count_end_user_id"],
tpm_limit=end_user_tpm_limit,
rpm_limit=end_user_rpm_limit,
rate_limit_type="customer",
values_to_update_in_cache=values_to_update_in_cache,
)
- await self.internal_usage_cache.async_batch_set_cache(
- cache_list=values_to_update_in_cache,
- ttl=60,
- litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+ asyncio.create_task(
+ self.internal_usage_cache.async_batch_set_cache(
+ cache_list=values_to_update_in_cache,
+ ttl=60,
+ litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+ ) # don't block execution for cache updates
)
return
@@ -481,8 +543,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
litellm_parent_otel_span=litellm_parent_otel_span,
) or {
"current_requests": 1,
- "current_tpm": total_tokens,
- "current_rpm": 1,
+ "current_tpm": 0,
+ "current_rpm": 0,
}
new_val = {
@@ -517,8 +579,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
litellm_parent_otel_span=litellm_parent_otel_span,
) or {
"current_requests": 1,
- "current_tpm": total_tokens,
- "current_rpm": 1,
+ "current_tpm": 0,
+ "current_rpm": 0,
}
new_val = {
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 8919da978..82831b3b2 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -262,6 +262,18 @@ class InternalUsageCache:
**kwargs,
)
+ async def async_batch_get_cache(
+ self,
+ keys: list,
+ parent_otel_span: Optional[Span] = None,
+ local_only: bool = False,
+ ):
+ return await self.dual_cache.async_batch_get_cache(
+ keys=keys,
+ parent_otel_span=parent_otel_span,
+ local_only=local_only,
+ )
+
async def async_increment_cache(
self,
key,
@@ -442,6 +454,8 @@ class ProxyLogging:
litellm._async_success_callback.append(callback) # type: ignore
if callback not in litellm._async_failure_callback:
litellm._async_failure_callback.append(callback) # type: ignore
+ if callback not in litellm.service_callback:
+ litellm.service_callback.append(callback) # type: ignore
if (
len(litellm.input_callback) > 0
diff --git a/litellm/types/services.py b/litellm/types/services.py
index 08259c741..5f690f328 100644
--- a/litellm/types/services.py
+++ b/litellm/types/services.py
@@ -15,6 +15,7 @@ class ServiceTypes(str, enum.Enum):
BATCH_WRITE_TO_DB = "batch_write_to_db"
LITELLM = "self"
ROUTER = "router"
+ AUTH = "auth"
class ServiceLoggerPayload(BaseModel):
diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py
index 1116840b5..479c1204e 100644
--- a/tests/local_testing/test_caching.py
+++ b/tests/local_testing/test_caching.py
@@ -59,12 +59,15 @@ async def test_dual_cache_async_batch_get_cache():
redis_cache = RedisCache() # get credentials from environment
dual_cache = DualCache(in_memory_cache=in_memory_cache, redis_cache=redis_cache)
- in_memory_cache.set_cache(key="test_value", value="hello world")
+ with patch.object(
+ dual_cache.redis_cache, "async_batch_get_cache", new=AsyncMock()
+ ) as mock_redis_cache:
+ mock_redis_cache.return_value = {"test_value_2": None, "test_value": "hello"}
- result = await dual_cache.async_batch_get_cache(keys=["test_value", "test_value_2"])
+ await dual_cache.async_batch_get_cache(keys=["test_value", "test_value_2"])
+ await dual_cache.async_batch_get_cache(keys=["test_value", "test_value_2"])
- assert result[0] == "hello world"
- assert result[1] == None
+ assert mock_redis_cache.call_count == 1
def test_dual_cache_batch_get_cache():
diff --git a/tests/local_testing/test_parallel_request_limiter.py b/tests/local_testing/test_parallel_request_limiter.py
index d0a9f9843..9bb2589aa 100644
--- a/tests/local_testing/test_parallel_request_limiter.py
+++ b/tests/local_testing/test_parallel_request_limiter.py
@@ -96,6 +96,7 @@ async def test_pre_call_hook():
key=request_count_api_key
)
)
+    await asyncio.sleep(1)  # cache writes now happen in a background task, so give them a moment to complete
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
@@ -110,6 +111,7 @@ async def test_pre_call_hook_rpm_limits():
Test if error raised on hitting rpm limits
"""
_api_key = "sk-12345"
+ _api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1
)
@@ -152,6 +154,7 @@ async def test_pre_call_hook_rpm_limits_retry_after():
Test if rate limit error, returns 'retry_after'
"""
_api_key = "sk-12345"
+ _api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1
)
@@ -251,6 +254,7 @@ async def test_pre_call_hook_tpm_limits():
Test if error raised on hitting tpm limits
"""
_api_key = "sk-12345"
+ _api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=10
)
@@ -306,9 +310,9 @@ async def test_pre_call_hook_user_tpm_limits():
local_cache.set_cache(key=user_id, value=user_obj)
_api_key = "sk-12345"
+ _api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(
- api_key=_api_key,
- user_id=user_id,
+ api_key=_api_key, user_id=user_id, user_rpm_limit=10, user_tpm_limit=9
)
res = dict(user_api_key_dict)
print("dict user", res)
@@ -372,7 +376,7 @@ async def test_success_call_hook():
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
-
+ await asyncio.sleep(1)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
@@ -416,7 +420,7 @@ async def test_failure_call_hook():
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
-
+ await asyncio.sleep(1)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
@@ -497,7 +501,7 @@ async def test_normal_router_call():
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
-
+ await asyncio.sleep(1)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
@@ -579,7 +583,7 @@ async def test_normal_router_tpm_limit():
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
print("Test: Checking current_requests for precise_minute=", precise_minute)
-
+ await asyncio.sleep(1)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
@@ -658,7 +662,7 @@ async def test_streaming_router_call():
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
-
+ await asyncio.sleep(1)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
@@ -736,7 +740,7 @@ async def test_streaming_router_tpm_limit():
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
-
+ await asyncio.sleep(1)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
@@ -814,7 +818,7 @@ async def test_bad_router_call():
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
-
+ await asyncio.sleep(1)
assert (
parallel_request_handler.internal_usage_cache.get_cache( # type: ignore
key=request_count_api_key
@@ -890,7 +894,7 @@ async def test_bad_router_tpm_limit():
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
-
+ await asyncio.sleep(1)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
@@ -979,7 +983,7 @@ async def test_bad_router_tpm_limit_per_model():
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{model}::{precise_minute}::request_count"
-
+ await asyncio.sleep(1)
print(
"internal usage cache: ",
parallel_request_handler.internal_usage_cache.dual_cache.in_memory_cache.cache_dict,
diff --git a/tests/local_testing/test_user_api_key_auth.py b/tests/local_testing/test_user_api_key_auth.py
index 668d4cab4..36bb71eb9 100644
--- a/tests/local_testing/test_user_api_key_auth.py
+++ b/tests/local_testing/test_user_api_key_auth.py
@@ -139,6 +139,7 @@ async def test_check_blocked_team():
def test_returned_user_api_key_auth(user_role, expected_role):
from litellm.proxy._types import LiteLLM_UserTable, LitellmUserRoles
from litellm.proxy.auth.user_api_key_auth import _return_user_api_key_auth_obj
+ from datetime import datetime
new_obj = _return_user_api_key_auth_obj(
user_obj=LiteLLM_UserTable(
@@ -148,6 +149,7 @@ def test_returned_user_api_key_auth(user_role, expected_role):
parent_otel_span=None,
valid_token_dict={},
route="/chat/completion",
+ start_time=datetime.now(),
)
assert new_obj.user_role == expected_role