Litellm dev 11 02 2024 (#6561)

* fix(dual_cache.py): update in-memory check for redis batch get cache

Fixes latency delay for async_batch_redis_cache
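
A minimal sketch of the idea behind the fix, using standalone names rather than the actual DualCache API: once a set of keys has been fetched from Redis, skip re-querying them for a configurable expiry window (5s assumed here) and serve the in-memory copy instead.

    import time
    from typing import Any, Dict, List

    class BatchRedisGate:
        """Hypothetical stand-in for the per-key expiry gate."""

        def __init__(self, expiry: float = 5.0):
            self.expiry = expiry
            self.last_fetch: Dict[str, float] = {}  # key -> last Redis fetch time
            self.memory: Dict[str, Any] = {}        # stand-in for the in-memory cache

        def keys_to_fetch(self, keys: List[str], now: float) -> List[str]:
            # Only go to Redis for keys that are missing locally AND were not
            # fetched within the expiry window.
            return [
                k
                for k in keys
                if self.memory.get(k) is None
                and (k not in self.last_fetch or now - self.last_fetch[k] >= self.expiry)
            ]

        def record_fetch(self, fetched: Dict[str, Any], now: float) -> None:
            for k, v in fetched.items():
                self.memory[k] = v
                self.last_fetch[k] = now

    gate = BatchRedisGate(expiry=5.0)
    now = time.time()
    print(gate.keys_to_fetch(["a", "b"], now))        # ['a', 'b'] -> one Redis round trip
    gate.record_fetch({"a": 1, "b": 2}, now)
    print(gate.keys_to_fetch(["a", "b"], now + 1.0))  # [] -> served from memory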

* fix(service_logger.py): fix race condition causing otel service logging to be overwritten if service_callbacks is set

* feat(user_api_key_auth.py): add parent otel component for auth

allows us to isolate how much latency is added by auth checks
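
A minimal, self-contained sketch of the pattern (it assumes the opentelemetry-sdk package; the span name and attribute are illustrative, not LiteLLM's actual wiring): wrap the auth checks in their own span so their latency shows up as a distinct segment of the request trace.

    import time
    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider

    trace.set_tracer_provider(TracerProvider())
    tracer = trace.get_tracer("proxy.auth")

    def authenticate(api_key: str) -> bool:
        # Everything inside this span is attributable to auth, so slow DB or
        # cache lookups during key validation are easy to spot in the trace.
        with tracer.start_as_current_span("user_api_key_auth") as span:
            span.set_attribute("auth.route", "/chat/completions")
            time.sleep(0.01)  # stand-in for token / budget / role checks
            return api_key.startswith("sk-")

    print(authenticate("sk-12345"))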

* perf(parallel_request_limiter.py): move async_set_cache_pipeline (from max parallel request limiter) out of execution path (background task)

reduces latency by 200ms
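
A minimal sketch of the fire-and-forget pattern (the coroutine and key names are illustrative): schedule the pipelined cache write as a background task instead of awaiting it inside the request path, so the caller no longer pays for the Redis round trip.

    import asyncio

    async def slow_batch_set(cache_list, ttl):
        await asyncio.sleep(0.2)  # stand-in for the Redis pipeline write
        print(f"wrote {len(cache_list)} keys (ttl={ttl}s)")

    async def pre_call_hook():
        values_to_update = [("sk-123::request_count", {"current_requests": 1})]
        # Before: `await slow_batch_set(...)` added the full write latency to
        # every request. Now the write happens in the background.
        asyncio.create_task(slow_batch_set(values_to_update, ttl=60))
        return "request allowed"

    async def main():
        print(await pre_call_hook())  # returns immediately, write still in flight
        await asyncio.sleep(0.3)      # keep the demo loop alive until the write lands

    asyncio.run(main())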

* feat(user_api_key_auth.py): have user api key auth object return user tpm/rpm limits - reduces redis calls in downstream task (parallel_request_limiter)

Reduces latency by 400-800ms
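
A minimal sketch with illustrative dataclass names (not the actual UserAPIKeyAuth model): carry the user's tpm/rpm limits on the auth object produced during authentication, so the rate limiter reads them directly instead of re-fetching the user object from Redis/DB.

    import sys
    from dataclasses import dataclass
    from typing import Optional, Tuple

    @dataclass
    class AuthResult:
        api_key: str
        user_id: Optional[str] = None
        user_tpm_limit: Optional[int] = None
        user_rpm_limit: Optional[int] = None

    def user_limits(auth: AuthResult) -> Tuple[int, int]:
        # Fall back to "unlimited" when no explicit limits were set on the user.
        tpm = auth.user_tpm_limit if auth.user_tpm_limit is not None else sys.maxsize
        rpm = auth.user_rpm_limit if auth.user_rpm_limit is not None else sys.maxsize
        return tpm, rpm

    auth = AuthResult(api_key="sk-12345", user_id="u1", user_tpm_limit=9, user_rpm_limit=10)
    print(user_limits(auth))  # (9, 10) -- no extra cache/DB lookup in the limiter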

* fix(parallel_request_limiter.py): use batch get cache to reduce user/key/team usage object calls

reduces latency by 50-100ms
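
A minimal sketch with an in-memory stand-in for the dual cache: collect the key/user/team/end-user counter keys up front and fetch them with one batch call, so the limiter does one round trip instead of one per usage object.

    import asyncio
    from typing import Any, Dict, List, Optional

    class FakeUsageCache:
        def __init__(self, data: Dict[str, Any]):
            self.data = data
            self.round_trips = 0

        async def async_batch_get_cache(self, keys: List[Optional[str]]) -> List[Any]:
            self.round_trips += 1  # one call regardless of how many keys are requested
            return [self.data.get(k) if k is not None else None for k in keys]

    async def main():
        cache = FakeUsageCache(
            {
                "sk-123::request_count": {"current_rpm": 3},
                "user-1::request_count": {"current_rpm": 1},
            }
        )
        keys = ["sk-123::request_count", "user-1::request_count", "team-1::request_count", None]
        results = await cache.async_batch_get_cache(keys)
        print(results, "round trips:", cache.round_trips)

    asyncio.run(main())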

* fix: fix linting error

* fix(_service_logger.py): fix import

* fix(user_api_key_auth.py): fix service logging

* fix(dual_cache.py): don't pass 'self'

* fix: fix python3.8 error

* fix: fix init
Krish Dholakia 2024-11-04 07:48:20 +05:30 committed by GitHub
parent 587d5fe277
commit d88e8922d4
17 changed files with 303 additions and 157 deletions

View file

@ -173,6 +173,7 @@ cache: Optional[Cache] = (
)
default_in_memory_ttl: Optional[float] = None
default_redis_ttl: Optional[float] = None
default_redis_batch_cache_expiry: Optional[float] = None
model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0 # set the max budget across all providers

View file

@ -13,9 +13,13 @@ from .types.services import ServiceLoggerPayload, ServiceTypes
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.integrations.opentelemetry import OpenTelemetry
Span = _Span
OTELClass = OpenTelemetry
else:
Span = Any
OTELClass = Any
class ServiceLogging(CustomLogger):
@ -111,6 +115,7 @@ class ServiceLogging(CustomLogger):
"""
- For counting if the redis, postgres call is successful
"""
from litellm.integrations.opentelemetry import OpenTelemetry
if self.mock_testing:
self.mock_testing_async_success_hook += 1
@ -122,6 +127,7 @@ class ServiceLogging(CustomLogger):
duration=duration,
call_type=call_type,
)
for callback in litellm.service_callback:
if callback == "prometheus_system":
await self.init_prometheus_services_logger_if_none()
@ -139,8 +145,7 @@ class ServiceLogging(CustomLogger):
end_time=end_time,
event_metadata=event_metadata,
)
elif callback == "otel":
from litellm.integrations.opentelemetry import OpenTelemetry
elif callback == "otel" or isinstance(callback, OpenTelemetry):
from litellm.proxy.proxy_server import open_telemetry_logger
await self.init_otel_logger_if_none()
@ -214,6 +219,8 @@ class ServiceLogging(CustomLogger):
"""
- For counting if the redis, postgres call is unsuccessful
"""
from litellm.integrations.opentelemetry import OpenTelemetry
if self.mock_testing:
self.mock_testing_async_failure_hook += 1
@ -246,8 +253,7 @@ class ServiceLogging(CustomLogger):
end_time=end_time,
event_metadata=event_metadata,
)
elif callback == "otel":
from litellm.integrations.opentelemetry import OpenTelemetry
elif callback == "otel" or isinstance(callback, OpenTelemetry):
from litellm.proxy.proxy_server import open_telemetry_logger
await self.init_otel_logger_if_none()

View file

@ -8,8 +8,10 @@ Has 4 primary methods:
- async_get_cache
"""
import asyncio
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
import litellm
@ -40,6 +42,7 @@ class LimitedSizeOrderedDict(OrderedDict):
self.popitem(last=False)
super().__setitem__(key, value)
class DualCache(BaseCache):
"""
DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously.
@ -53,7 +56,7 @@ class DualCache(BaseCache):
redis_cache: Optional[RedisCache] = None,
default_in_memory_ttl: Optional[float] = None,
default_redis_ttl: Optional[float] = None,
default_redis_batch_cache_expiry: float = 1,
default_redis_batch_cache_expiry: Optional[float] = None,
default_max_redis_batch_cache_size: int = 100,
) -> None:
super().__init__()
@ -64,7 +67,11 @@ class DualCache(BaseCache):
self.last_redis_batch_access_time = LimitedSizeOrderedDict(
max_size=default_max_redis_batch_cache_size
)
self.redis_batch_cache_expiry = default_redis_batch_cache_expiry
self.redis_batch_cache_expiry = (
default_redis_batch_cache_expiry
or litellm.default_redis_batch_cache_expiry
or 5
)
self.default_in_memory_ttl = (
default_in_memory_ttl or litellm.default_in_memory_ttl
)
@ -156,52 +163,33 @@ class DualCache(BaseCache):
local_only: bool = False,
**kwargs,
):
received_args = locals()
received_args.pop("self")
def run_in_new_loop():
"""Run the coroutine in a new event loop within this thread."""
new_loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(new_loop)
return new_loop.run_until_complete(
self.async_batch_get_cache(**received_args)
)
finally:
new_loop.close()
asyncio.set_event_loop(None)
try:
result = [None for _ in range(len(keys))]
if self.in_memory_cache is not None:
in_memory_result = self.in_memory_cache.batch_get_cache(keys, **kwargs)
# First, try to get the current event loop
_ = asyncio.get_running_loop()
# If we're already in an event loop, run in a separate thread
# to avoid nested event loop issues
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
return future.result()
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only is False:
"""
- for the none values in the result
- check the redis cache
"""
# Track the last access time for these keys
current_time = time.time()
key_tuple = tuple(keys)
# Only hit Redis if the last access time was more than 5 seconds ago
if (
key_tuple not in self.last_redis_batch_access_time
or current_time - self.last_redis_batch_access_time[key_tuple]
>= self.redis_batch_cache_expiry
):
sublist_keys = [
key for key, value in zip(keys, result) if value is None
]
# If not found in in-memory cache, try fetching from Redis
redis_result = self.redis_cache.batch_get_cache(
sublist_keys, parent_otel_span=parent_otel_span
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
for key in redis_result:
self.in_memory_cache.set_cache(
key, redis_result[key], **kwargs
)
for key, value in redis_result.items():
result[keys.index(key)] = value
print_verbose(f"async batch get cache: cache result: {result}")
return result
except Exception:
verbose_logger.error(traceback.format_exc())
except RuntimeError:
# No running event loop, we can safely run in this thread
return run_in_new_loop()
async def async_get_cache(
self,
@ -244,6 +232,23 @@ class DualCache(BaseCache):
except Exception:
verbose_logger.error(traceback.format_exc())
def get_redis_batch_keys(
self,
current_time: float,
keys: List[str],
result: List[Any],
) -> List[str]:
sublist_keys = []
for key, value in zip(keys, result):
if value is None:
if (
key not in self.last_redis_batch_access_time
or current_time - self.last_redis_batch_access_time[key]
>= self.redis_batch_cache_expiry
):
sublist_keys.append(key)
return sublist_keys
async def async_batch_get_cache(
self,
keys: list,
@ -266,25 +271,16 @@ class DualCache(BaseCache):
- for the none values in the result
- check the redis cache
"""
# Track the last access time for these keys
current_time = time.time()
key_tuple = tuple(keys)
sublist_keys = self.get_redis_batch_keys(current_time, keys, result)
# Only hit Redis if the last access time was more than 5 seconds ago
if (
key_tuple not in self.last_redis_batch_access_time
or current_time - self.last_redis_batch_access_time[key_tuple]
>= self.redis_batch_cache_expiry
):
sublist_keys = [
key for key, value in zip(keys, result) if value is None
]
if len(sublist_keys) > 0:
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_batch_get_cache(
sublist_keys, parent_otel_span=parent_otel_span
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
for key, value in redis_result.items():
@ -292,6 +288,9 @@ class DualCache(BaseCache):
await self.in_memory_cache.async_set_cache(
key, redis_result[key], **kwargs
)
# Update the last access time for each key fetched from Redis
self.last_redis_batch_access_time[key] = current_time
for key, value in redis_result.items():
index = keys.index(key)
result[index] = value

View file

@ -732,7 +732,6 @@ class RedisCache(BaseCache):
"""
Use Redis for bulk read operations
"""
_redis_client = await self.init_async_client()
key_value_dict = {}
start_time = time.time()

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -3,8 +3,7 @@ model_list:
litellm_params:
model: claude-3-5-sonnet-20240620
api_key: os.environ/ANTHROPIC_API_KEY
api_base: "http://0.0.0.0:8000"
- model_name: my-fallback-openai-model
- model_name: claude-3-5-sonnet-aihubmix
litellm_params:
model: openai/claude-3-5-sonnet-20240620
input_cost_per_token: 0.000003 # 3$/M
@ -15,9 +14,35 @@ model_list:
litellm_params:
model: gemini/gemini-1.5-flash-002
# litellm_settings:
# fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
# callbacks: ["otel", "prometheus"]
# default_redis_batch_cache_expiry: 10
litellm_settings:
fallbacks: [{ "claude-3-5-sonnet-20240620": ["my-fallback-openai-model"] }]
callbacks: ["otel", "prometheus"]
cache: True
cache_params:
type: redis
# disable caching on the actual API call
supported_call_types: []
# see https://docs.litellm.ai/docs/proxy/prod#3-use-redis-porthost-password-not-redis_url
host: os.environ/REDIS_HOST
port: os.environ/REDIS_PORT
password: os.environ/REDIS_PASSWORD
# see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
# see https://docs.litellm.ai/docs/proxy/prometheus
callbacks: ['prometheus', 'otel']
# # see https://docs.litellm.ai/docs/proxy/logging#logging-proxy-inputoutput---sentry
failure_callback: ['sentry']
service_callback: ['prometheus_system']
# redact_user_api_key_info: true
router_settings:
routing_strategy: latency-based-routing
@ -29,4 +54,19 @@ router_settings:
ttl: 300
redis_host: os.environ/REDIS_HOST
redis_port: os.environ/REDIS_PORT
redis_password: os.environ/REDIS_PASSWORD
redis_password: os.environ/REDIS_PASSWORD
# see https://docs.litellm.ai/docs/proxy/prod#1-use-this-configyaml
general_settings:
master_key: os.environ/LITELLM_MASTER_KEY
database_url: os.environ/DATABASE_URL
disable_master_key_return: true
# alerting: ['slack', 'email']
alerting: ['email']
# Batch write spend updates every 60s
proxy_batch_write_at: 60
# see https://docs.litellm.ai/docs/proxy/caching#advanced---user-api-key-cache-ttl
# our api keys rarely change
user_api_key_cache_ttl: 3600

View file

@ -1419,6 +1419,8 @@ class UserAPIKeyAuth(
parent_otel_span: Optional[Span] = None
rpm_limit_per_model: Optional[Dict[str, int]] = None
tpm_limit_per_model: Optional[Dict[str, int]] = None
user_tpm_limit: Optional[int] = None
user_rpm_limit: Optional[int] = None
@model_validator(mode="before")
@classmethod

View file

@ -9,6 +9,7 @@ Run checks for:
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
import time
import traceback
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Literal, Optional

View file

@ -10,6 +10,7 @@ Returns a UserAPIKeyAuth object if the API key is valid
import asyncio
import json
import secrets
import time
import traceback
from datetime import datetime, timedelta, timezone
from typing import Optional, Tuple
@ -44,6 +45,7 @@ from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm._service_logger import ServiceLogging
from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import (
_cache_key_object,
@ -73,6 +75,10 @@ from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.auth.service_account_checks import service_account_checks
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.utils import _to_ns
from litellm.types.services import ServiceTypes
user_api_key_service_logger_obj = ServiceLogging() # used for tracking latency on OTEL
api_key_header = APIKeyHeader(
name=SpecialHeaders.openai_authorization.value,
@ -214,7 +220,7 @@ async def user_api_key_auth( # noqa: PLR0915
)
parent_otel_span: Optional[Span] = None
start_time = datetime.now()
try:
route: str = get_request_route(request=request)
# get the request body
@ -255,7 +261,7 @@ async def user_api_key_auth( # noqa: PLR0915
if open_telemetry_logger is not None:
parent_otel_span = open_telemetry_logger.tracer.start_span(
name="Received Proxy Server Request",
start_time=_to_ns(datetime.now()),
start_time=_to_ns(start_time),
context=open_telemetry_logger.get_traceparent_from_header(
headers=request.headers
),
@ -1165,6 +1171,7 @@ async def user_api_key_auth( # noqa: PLR0915
parent_otel_span=parent_otel_span,
valid_token_dict=valid_token_dict,
route=route,
start_time=start_time,
)
else:
raise Exception()
@ -1219,31 +1226,39 @@ def _return_user_api_key_auth_obj(
parent_otel_span: Optional[Span],
valid_token_dict: dict,
route: str,
start_time: datetime,
) -> UserAPIKeyAuth:
end_time = datetime.now()
user_api_key_service_logger_obj.service_success_hook(
service=ServiceTypes.AUTH,
call_type=route,
start_time=start_time,
end_time=end_time,
duration=end_time.timestamp() - start_time.timestamp(),
parent_otel_span=parent_otel_span,
)
retrieved_user_role = (
_get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
)
user_api_key_kwargs = {
"api_key": api_key,
"parent_otel_span": parent_otel_span,
"user_role": retrieved_user_role,
**valid_token_dict,
}
if user_obj is not None:
user_api_key_kwargs.update(
user_tpm_limit=user_obj.tpm_limit,
user_rpm_limit=user_obj.rpm_limit,
)
if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
return UserAPIKeyAuth(
api_key=api_key,
user_api_key_kwargs.update(
user_role=LitellmUserRoles.PROXY_ADMIN,
parent_otel_span=parent_otel_span,
**valid_token_dict,
)
elif _has_user_setup_sso() and route in LiteLLMRoutes.sso_only_routes.value:
return UserAPIKeyAuth(
api_key=api_key,
user_role=retrieved_user_role,
parent_otel_span=parent_otel_span,
**valid_token_dict,
)
return UserAPIKeyAuth(**user_api_key_kwargs)
else:
return UserAPIKeyAuth(
api_key=api_key,
user_role=retrieved_user_role,
parent_otel_span=parent_otel_span,
**valid_token_dict,
)
return UserAPIKeyAuth(**user_api_key_kwargs)
def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]):

View file

@ -1,7 +1,8 @@
import asyncio
import sys
import traceback
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, TypedDict, Union
from fastapi import HTTPException
from pydantic import BaseModel
@ -29,6 +30,14 @@ else:
InternalUsageCache = Any
class CacheObject(TypedDict):
current_global_requests: Optional[dict]
request_count_api_key: Optional[dict]
request_count_user_id: Optional[dict]
request_count_team_id: Optional[dict]
request_count_end_user_id: Optional[dict]
class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# Class variables or attributes
def __init__(self, internal_usage_cache: InternalUsageCache):
@ -51,14 +60,15 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
max_parallel_requests: int,
tpm_limit: int,
rpm_limit: int,
current: Optional[dict],
request_count_api_key: str,
rate_limit_type: Literal["user", "customer", "team"],
values_to_update_in_cache: List[Tuple[Any, Any]],
):
current = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
# current = await self.internal_usage_cache.async_get_cache(
# key=request_count_api_key,
# litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
# ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
if current is None:
if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
# base case
@ -117,6 +127,44 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
headers={"retry-after": str(self.time_to_next_minute())},
)
async def get_all_cache_objects(
self,
current_global_requests: Optional[str],
request_count_api_key: Optional[str],
request_count_user_id: Optional[str],
request_count_team_id: Optional[str],
request_count_end_user_id: Optional[str],
parent_otel_span: Optional[Span] = None,
) -> CacheObject:
keys = [
current_global_requests,
request_count_api_key,
request_count_user_id,
request_count_team_id,
request_count_end_user_id,
]
results = await self.internal_usage_cache.async_batch_get_cache(
keys=keys,
parent_otel_span=parent_otel_span,
)
if results is None:
return CacheObject(
current_global_requests=None,
request_count_api_key=None,
request_count_user_id=None,
request_count_team_id=None,
request_count_end_user_id=None,
)
return CacheObject(
current_global_requests=results[0],
request_count_api_key=results[1],
request_count_user_id=results[2],
request_count_team_id=results[3],
request_count_end_user_id=results[4],
)
async def async_pre_call_hook( # noqa: PLR0915
self,
user_api_key_dict: UserAPIKeyAuth,
@ -149,6 +197,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# Setup values
# ------------
new_val: Optional[dict] = None
if global_max_parallel_requests is not None:
# get value from cache
_key = "global_max_parallel_requests"
@ -179,15 +228,40 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
cache_objects: CacheObject = await self.get_all_cache_objects(
current_global_requests=(
"global_max_parallel_requests"
if global_max_parallel_requests is not None
else None
),
request_count_api_key=(
f"{api_key}::{precise_minute}::request_count"
if api_key is not None
else None
),
request_count_user_id=(
f"{user_api_key_dict.user_id}::{precise_minute}::request_count"
if user_api_key_dict.user_id is not None
else None
),
request_count_team_id=(
f"{user_api_key_dict.team_id}::{precise_minute}::request_count"
if user_api_key_dict.team_id is not None
else None
),
request_count_end_user_id=(
f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
if user_api_key_dict.end_user_id is not None
else None
),
parent_otel_span=user_api_key_dict.parent_otel_span,
)
if api_key is not None:
request_count_api_key = f"{api_key}::{precise_minute}::request_count"
# CHECK IF REQUEST ALLOWED for key
current = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
current = cache_objects["request_count_api_key"]
self.print_verbose(f"current: {current}")
if (
max_parallel_requests == sys.maxsize
@ -303,42 +377,28 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# check if REQUEST ALLOWED for user_id
user_id = user_api_key_dict.user_id
if user_id is not None:
_user_id_rate_limits = await self.get_internal_user_object(
user_id=user_id,
user_tpm_limit = user_api_key_dict.user_tpm_limit
user_rpm_limit = user_api_key_dict.user_rpm_limit
if user_tpm_limit is None:
user_tpm_limit = sys.maxsize
if user_rpm_limit is None:
user_rpm_limit = sys.maxsize
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
current=cache_objects["request_count_user_id"],
request_count_api_key=request_count_api_key,
tpm_limit=user_tpm_limit,
rpm_limit=user_rpm_limit,
rate_limit_type="user",
values_to_update_in_cache=values_to_update_in_cache,
)
# get user tpm/rpm limits
if (
_user_id_rate_limits is not None
and isinstance(_user_id_rate_limits, dict)
and (
_user_id_rate_limits.get("tpm_limit", None) is not None
or _user_id_rate_limits.get("rpm_limit", None) is not None
)
):
user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None)
user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None)
if user_tpm_limit is None:
user_tpm_limit = sys.maxsize
if user_rpm_limit is None:
user_rpm_limit = sys.maxsize
# now do the same tpm/rpm checks
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
request_count_api_key=request_count_api_key,
tpm_limit=user_tpm_limit,
rpm_limit=user_rpm_limit,
rate_limit_type="user",
values_to_update_in_cache=values_to_update_in_cache,
)
# TEAM RATE LIMITS
## get team tpm/rpm limits
@ -352,9 +412,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
if team_rpm_limit is None:
team_rpm_limit = sys.maxsize
# now do the same tpm/rpm checks
request_count_api_key = f"{team_id}::{precise_minute}::request_count"
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
@ -362,6 +420,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
data=data,
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a team
current=cache_objects["request_count_team_id"],
request_count_api_key=request_count_api_key,
tpm_limit=team_tpm_limit,
rpm_limit=team_rpm_limit,
@ -397,16 +456,19 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for an End-User
request_count_api_key=request_count_api_key,
current=cache_objects["request_count_end_user_id"],
tpm_limit=end_user_tpm_limit,
rpm_limit=end_user_rpm_limit,
rate_limit_type="customer",
values_to_update_in_cache=values_to_update_in_cache,
)
await self.internal_usage_cache.async_batch_set_cache(
cache_list=values_to_update_in_cache,
ttl=60,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
asyncio.create_task(
self.internal_usage_cache.async_batch_set_cache(
cache_list=values_to_update_in_cache,
ttl=60,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
) # don't block execution for cache updates
)
return
@ -481,8 +543,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
litellm_parent_otel_span=litellm_parent_otel_span,
) or {
"current_requests": 1,
"current_tpm": total_tokens,
"current_rpm": 1,
"current_tpm": 0,
"current_rpm": 0,
}
new_val = {
@ -517,8 +579,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
litellm_parent_otel_span=litellm_parent_otel_span,
) or {
"current_requests": 1,
"current_tpm": total_tokens,
"current_rpm": 1,
"current_tpm": 0,
"current_rpm": 0,
}
new_val = {

View file

@ -262,6 +262,18 @@ class InternalUsageCache:
**kwargs,
)
async def async_batch_get_cache(
self,
keys: list,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
):
return await self.dual_cache.async_batch_get_cache(
keys=keys,
parent_otel_span=parent_otel_span,
local_only=local_only,
)
async def async_increment_cache(
self,
key,
@ -442,6 +454,8 @@ class ProxyLogging:
litellm._async_success_callback.append(callback) # type: ignore
if callback not in litellm._async_failure_callback:
litellm._async_failure_callback.append(callback) # type: ignore
if callback not in litellm.service_callback:
litellm.service_callback.append(callback) # type: ignore
if (
len(litellm.input_callback) > 0

View file

@ -15,6 +15,7 @@ class ServiceTypes(str, enum.Enum):
BATCH_WRITE_TO_DB = "batch_write_to_db"
LITELLM = "self"
ROUTER = "router"
AUTH = "auth"
class ServiceLoggerPayload(BaseModel):

View file

@ -59,12 +59,15 @@ async def test_dual_cache_async_batch_get_cache():
redis_cache = RedisCache() # get credentials from environment
dual_cache = DualCache(in_memory_cache=in_memory_cache, redis_cache=redis_cache)
in_memory_cache.set_cache(key="test_value", value="hello world")
with patch.object(
dual_cache.redis_cache, "async_batch_get_cache", new=AsyncMock()
) as mock_redis_cache:
mock_redis_cache.return_value = {"test_value_2": None, "test_value": "hello"}
result = await dual_cache.async_batch_get_cache(keys=["test_value", "test_value_2"])
await dual_cache.async_batch_get_cache(keys=["test_value", "test_value_2"])
await dual_cache.async_batch_get_cache(keys=["test_value", "test_value_2"])
assert result[0] == "hello world"
assert result[1] == None
assert mock_redis_cache.call_count == 1
def test_dual_cache_batch_get_cache():

View file

@ -96,6 +96,7 @@ async def test_pre_call_hook():
key=request_count_api_key
)
)
await asyncio.sleep(1)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
@ -110,6 +111,7 @@ async def test_pre_call_hook_rpm_limits():
Test if error raised on hitting rpm limits
"""
_api_key = "sk-12345"
_api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1
)
@ -152,6 +154,7 @@ async def test_pre_call_hook_rpm_limits_retry_after():
Test if rate limit error, returns 'retry_after'
"""
_api_key = "sk-12345"
_api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1
)
@ -251,6 +254,7 @@ async def test_pre_call_hook_tpm_limits():
Test if error raised on hitting tpm limits
"""
_api_key = "sk-12345"
_api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=10
)
@ -306,9 +310,9 @@ async def test_pre_call_hook_user_tpm_limits():
local_cache.set_cache(key=user_id, value=user_obj)
_api_key = "sk-12345"
_api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key,
user_id=user_id,
api_key=_api_key, user_id=user_id, user_rpm_limit=10, user_tpm_limit=9
)
res = dict(user_api_key_dict)
print("dict user", res)
@ -372,7 +376,7 @@ async def test_success_call_hook():
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
await asyncio.sleep(1)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
@ -416,7 +420,7 @@ async def test_failure_call_hook():
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
await asyncio.sleep(1)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
@ -497,7 +501,7 @@ async def test_normal_router_call():
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
await asyncio.sleep(1)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
@ -579,7 +583,7 @@ async def test_normal_router_tpm_limit():
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
print("Test: Checking current_requests for precise_minute=", precise_minute)
await asyncio.sleep(1)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
@ -658,7 +662,7 @@ async def test_streaming_router_call():
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
await asyncio.sleep(1)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
@ -736,7 +740,7 @@ async def test_streaming_router_tpm_limit():
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
await asyncio.sleep(1)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
@ -814,7 +818,7 @@ async def test_bad_router_call():
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
await asyncio.sleep(1)
assert (
parallel_request_handler.internal_usage_cache.get_cache( # type: ignore
key=request_count_api_key
@ -890,7 +894,7 @@ async def test_bad_router_tpm_limit():
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
await asyncio.sleep(1)
assert (
parallel_request_handler.internal_usage_cache.get_cache(
key=request_count_api_key
@ -979,7 +983,7 @@ async def test_bad_router_tpm_limit_per_model():
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{_api_key}::{model}::{precise_minute}::request_count"
await asyncio.sleep(1)
print(
"internal usage cache: ",
parallel_request_handler.internal_usage_cache.dual_cache.in_memory_cache.cache_dict,

View file

@ -139,6 +139,7 @@ async def test_check_blocked_team():
def test_returned_user_api_key_auth(user_role, expected_role):
from litellm.proxy._types import LiteLLM_UserTable, LitellmUserRoles
from litellm.proxy.auth.user_api_key_auth import _return_user_api_key_auth_obj
from datetime import datetime
new_obj = _return_user_api_key_auth_obj(
user_obj=LiteLLM_UserTable(
@ -148,6 +149,7 @@ def test_returned_user_api_key_auth(user_role, expected_role):
parent_otel_span=None,
valid_token_dict={},
route="/chat/completion",
start_time=datetime.now(),
)
assert new_obj.user_role == expected_role