forked from phoenix/litellm-mirror
Litellm dev 11 02 2024 (#6561)
* fix(dual_cache.py): update in-memory check for redis batch get cache. Fixes latency delay for async_batch_redis_cache
* fix(service_logger.py): fix race condition causing otel service logging to be overwritten if service_callbacks set
* feat(user_api_key_auth.py): add parent otel component for auth. Allows us to isolate how much latency is added by auth checks
* perf(parallel_request_limiter.py): move async_set_cache_pipeline (from max parallel request limiter) out of execution path (background task). Reduces latency by 200ms
* feat(user_api_key_auth.py): have user api key auth object return user tpm/rpm limits. Reduces redis calls in downstream task (parallel_request_limiter), reducing latency by 400-800ms
* fix(parallel_request_limiter.py): use batch get cache to reduce user/key/team usage object calls. Reduces latency by 50-100ms
* fix: fix linting error
* fix(_service_logger.py): fix import
* fix(user_api_key_auth.py): fix service logging
* fix(dual_cache.py): don't pass 'self'
* fix: fix python3.8 error
* fix: fix init
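The first dual_cache.py item above is about throttling how often a batch read falls through to Redis. The following is a minimal sketch of that idea, not LiteLLM's actual implementation: the `redis_batch_cache_expiry`, `last_redis_batch_access_time`, and `LimitedSizeOrderedDict` names mirror the dual_cache.py diff further down, while `ThrottledBatchCache` and the plain-dict "Redis" stand-in are purely illustrative.

import time
from collections import OrderedDict
from typing import Any, Dict, List, Optional


class LimitedSizeOrderedDict(OrderedDict):
    """Ordered dict that evicts its oldest entry once max_size is reached."""

    def __init__(self, *args, max_size: int = 100, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_size = max_size

    def __setitem__(self, key, value):
        if len(self) >= self.max_size:
            self.popitem(last=False)  # drop the oldest tracked key
        super().__setitem__(key, value)


class ThrottledBatchCache:
    """Sketch: serve batch reads from memory and only fall through to the
    backing store for keys last fetched more than `redis_batch_cache_expiry`
    seconds ago."""

    def __init__(self, in_memory: Dict[str, Any], redis: Dict[str, Any],
                 redis_batch_cache_expiry: float = 5.0):
        self.in_memory = in_memory
        self.redis = redis  # stand-in for a real Redis client
        self.redis_batch_cache_expiry = redis_batch_cache_expiry
        self.last_redis_batch_access_time = LimitedSizeOrderedDict(max_size=100)

    def get_redis_batch_keys(self, current_time: float, keys: List[str],
                             result: List[Optional[Any]]) -> List[str]:
        """Only re-query keys that are missing locally and not recently fetched."""
        sublist = []
        for key, value in zip(keys, result):
            last = self.last_redis_batch_access_time.get(key)
            if value is None and (
                last is None or current_time - last >= self.redis_batch_cache_expiry
            ):
                sublist.append(key)
        return sublist

    def batch_get(self, keys: List[str]) -> List[Optional[Any]]:
        result = [self.in_memory.get(k) for k in keys]
        now = time.time()
        for key in self.get_redis_batch_keys(now, keys, result):
            value = self.redis.get(key)  # one lookup per key in this sketch
            self.last_redis_batch_access_time[key] = now
            if value is not None:
                self.in_memory[key] = value
                result[keys.index(key)] = value
        return result

The actual change in the diff is to track the last Redis access per key (via get_redis_batch_keys) rather than per whole key tuple, so one stale key no longer forces the entire batch back to Redis.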
This commit is contained in:
parent
587d5fe277
commit
d88e8922d4
17 changed files with 303 additions and 157 deletions
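Two of the latency wins come from keeping slow work off the request path: the parallel request limiter now reads all of its per-key/user/team counters with one batched cache call, and schedules its cache writes as a background task instead of awaiting them. Below is a hedged sketch of the fire-and-forget write pattern; the stub class and names are placeholders, not the proxy's real objects.

import asyncio
from typing import Any, List, Tuple


class UsageCacheStub:
    """Placeholder for the proxy's internal usage cache."""

    async def async_batch_set_cache(self, cache_list: List[Tuple[str, Any]], ttl: int) -> None:
        await asyncio.sleep(0.05)  # simulate a slow Redis pipeline write
        print(f"wrote {len(cache_list)} keys with ttl={ttl}")


async def pre_call_hook(cache: UsageCacheStub, values_to_update: List[Tuple[str, Any]]) -> None:
    # Fire-and-forget: schedule the pipeline write instead of awaiting it,
    # so the request continues without paying the round-trip cost.
    asyncio.create_task(cache.async_batch_set_cache(cache_list=values_to_update, ttl=60))
    # ... request handling continues immediately ...


async def main() -> None:
    cache = UsageCacheStub()
    await pre_call_hook(cache, [("api_key::request_count", {"current_requests": 1})])
    await asyncio.sleep(0.1)  # keep the loop alive long enough for the task to finish


if __name__ == "__main__":
    asyncio.run(main())

In a long-lived service you would normally keep a reference to the created task (or attach a done-callback) so it is not garbage collected and its exceptions are not silently dropped.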
@@ -173,6 +173,7 @@ cache: Optional[Cache] = (
)
default_in_memory_ttl: Optional[float] = None
default_redis_ttl: Optional[float] = None
default_redis_batch_cache_expiry: Optional[float] = None
model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0 # set the max budget across all providers

@@ -13,9 +13,13 @@ from .types.services import ServiceLoggerPayload, ServiceTypes
if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.integrations.opentelemetry import OpenTelemetry

    Span = _Span
    OTELClass = OpenTelemetry
else:
    Span = Any
    OTELClass = Any


class ServiceLogging(CustomLogger):

@@ -111,6 +115,7 @@ class ServiceLogging(CustomLogger):
        """
        - For counting if the redis, postgres call is successful
        """
        from litellm.integrations.opentelemetry import OpenTelemetry

        if self.mock_testing:
            self.mock_testing_async_success_hook += 1

@@ -122,6 +127,7 @@ class ServiceLogging(CustomLogger):
            duration=duration,
            call_type=call_type,
        )

        for callback in litellm.service_callback:
            if callback == "prometheus_system":
                await self.init_prometheus_services_logger_if_none()

@@ -139,8 +145,7 @@ class ServiceLogging(CustomLogger):
                    end_time=end_time,
                    event_metadata=event_metadata,
                )
            elif callback == "otel":
                from litellm.integrations.opentelemetry import OpenTelemetry
            elif callback == "otel" or isinstance(callback, OpenTelemetry):
                from litellm.proxy.proxy_server import open_telemetry_logger

                await self.init_otel_logger_if_none()

@@ -214,6 +219,8 @@ class ServiceLogging(CustomLogger):
        """
        - For counting if the redis, postgres call is unsuccessful
        """
        from litellm.integrations.opentelemetry import OpenTelemetry

        if self.mock_testing:
            self.mock_testing_async_failure_hook += 1

@@ -246,8 +253,7 @@ class ServiceLogging(CustomLogger):
                    end_time=end_time,
                    event_metadata=event_metadata,
                )
            elif callback == "otel":
                from litellm.integrations.opentelemetry import OpenTelemetry
            elif callback == "otel" or isinstance(callback, OpenTelemetry):
                from litellm.proxy.proxy_server import open_telemetry_logger

                await self.init_otel_logger_if_none()

@@ -8,8 +8,10 @@ Has 4 primary methods:
- async_get_cache
"""

import asyncio
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, List, Optional, Tuple

import litellm

@@ -40,6 +42,7 @@ class LimitedSizeOrderedDict(OrderedDict):
            self.popitem(last=False)
        super().__setitem__(key, value)


class DualCache(BaseCache):
    """
    DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously.

@@ -53,7 +56,7 @@ class DualCache(BaseCache):
        redis_cache: Optional[RedisCache] = None,
        default_in_memory_ttl: Optional[float] = None,
        default_redis_ttl: Optional[float] = None,
        default_redis_batch_cache_expiry: float = 1,
        default_redis_batch_cache_expiry: Optional[float] = None,
        default_max_redis_batch_cache_size: int = 100,
    ) -> None:
        super().__init__()

@@ -64,7 +67,11 @@ class DualCache(BaseCache):
        self.last_redis_batch_access_time = LimitedSizeOrderedDict(
            max_size=default_max_redis_batch_cache_size
        )
        self.redis_batch_cache_expiry = default_redis_batch_cache_expiry
        self.redis_batch_cache_expiry = (
            default_redis_batch_cache_expiry
            or litellm.default_redis_batch_cache_expiry
            or 5
        )
        self.default_in_memory_ttl = (
            default_in_memory_ttl or litellm.default_in_memory_ttl
        )

@@ -156,52 +163,33 @@ class DualCache(BaseCache):
        local_only: bool = False,
        **kwargs,
    ):
        received_args = locals()
        received_args.pop("self")

        def run_in_new_loop():
            """Run the coroutine in a new event loop within this thread."""
            new_loop = asyncio.new_event_loop()
            try:
                asyncio.set_event_loop(new_loop)
                return new_loop.run_until_complete(
                    self.async_batch_get_cache(**received_args)
                )
            finally:
                new_loop.close()
                asyncio.set_event_loop(None)

        try:
            result = [None for _ in range(len(keys))]
            if self.in_memory_cache is not None:
                in_memory_result = self.in_memory_cache.batch_get_cache(keys, **kwargs)
            # First, try to get the current event loop
            _ = asyncio.get_running_loop()
            # If we're already in an event loop, run in a separate thread
            # to avoid nested event loop issues
            with ThreadPoolExecutor(max_workers=1) as executor:
                future = executor.submit(run_in_new_loop)
                return future.result()

            if in_memory_result is not None:
                result = in_memory_result

            if None in result and self.redis_cache is not None and local_only is False:
                """
                - for the none values in the result
                - check the redis cache
                """
                # Track the last access time for these keys
                current_time = time.time()
                key_tuple = tuple(keys)

                # Only hit Redis if the last access time was more than 5 seconds ago
                if (
                    key_tuple not in self.last_redis_batch_access_time
                    or current_time - self.last_redis_batch_access_time[key_tuple]
                    >= self.redis_batch_cache_expiry
                ):
                    sublist_keys = [
                        key for key, value in zip(keys, result) if value is None
                    ]
                    # If not found in in-memory cache, try fetching from Redis
                    redis_result = self.redis_cache.batch_get_cache(
                        sublist_keys, parent_otel_span=parent_otel_span
                    )
                    if redis_result is not None:
                        # Update in-memory cache with the value from Redis
                        for key in redis_result:
                            self.in_memory_cache.set_cache(
                                key, redis_result[key], **kwargs
                            )

                for key, value in redis_result.items():
                    result[keys.index(key)] = value

            print_verbose(f"async batch get cache: cache result: {result}")
            return result
        except Exception:
            verbose_logger.error(traceback.format_exc())
        except RuntimeError:
            # No running event loop, we can safely run in this thread
            return run_in_new_loop()

    async def async_get_cache(
        self,

@@ -244,6 +232,23 @@ class DualCache(BaseCache):
        except Exception:
            verbose_logger.error(traceback.format_exc())

    def get_redis_batch_keys(
        self,
        current_time: float,
        keys: List[str],
        result: List[Any],
    ) -> List[str]:
        sublist_keys = []
        for key, value in zip(keys, result):
            if value is None:
                if (
                    key not in self.last_redis_batch_access_time
                    or current_time - self.last_redis_batch_access_time[key]
                    >= self.redis_batch_cache_expiry
                ):
                    sublist_keys.append(key)
        return sublist_keys

    async def async_batch_get_cache(
        self,
        keys: list,

@@ -266,25 +271,16 @@ class DualCache(BaseCache):
            - for the none values in the result
            - check the redis cache
            """
            # Track the last access time for these keys
            current_time = time.time()
            key_tuple = tuple(keys)
            sublist_keys = self.get_redis_batch_keys(current_time, keys, result)

            # Only hit Redis if the last access time was more than 5 seconds ago
            if (
                key_tuple not in self.last_redis_batch_access_time
                or current_time - self.last_redis_batch_access_time[key_tuple]
                >= self.redis_batch_cache_expiry
            ):
                sublist_keys = [
                    key for key, value in zip(keys, result) if value is None
                ]
            if len(sublist_keys) > 0:
                # If not found in in-memory cache, try fetching from Redis
                redis_result = await self.redis_cache.async_batch_get_cache(
                    sublist_keys, parent_otel_span=parent_otel_span
                )

                if redis_result is not None:
                    # Update in-memory cache with the value from Redis
                    for key, value in redis_result.items():

@@ -292,6 +288,9 @@ class DualCache(BaseCache):
                        await self.in_memory_cache.async_set_cache(
                            key, redis_result[key], **kwargs
                        )
                        # Update the last access time for each key fetched from Redis
                        self.last_redis_batch_access_time[key] = current_time

                for key, value in redis_result.items():
                    index = keys.index(key)
                    result[index] = value

@@ -732,7 +732,6 @@
        """
        Use Redis for bulk read operations
        """

        _redis_client = await self.init_async_client()
        key_value_dict = {}
        start_time = time.time()

File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long

@@ -3,8 +3,7 @@ model_list:
    litellm_params:
      model: claude-3-5-sonnet-20240620
      api_key: os.environ/ANTHROPIC_API_KEY
      api_base: "http://0.0.0.0:8000"
  - model_name: my-fallback-openai-model
  - model_name: claude-3-5-sonnet-aihubmix
    litellm_params:
      model: openai/claude-3-5-sonnet-20240620
      input_cost_per_token: 0.000003 # 3$/M

@@ -15,9 +14,35 @@ model_list:
    litellm_params:
      model: gemini/gemini-1.5-flash-002

# litellm_settings:
#   fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
#   callbacks: ["otel", "prometheus"]
#   default_redis_batch_cache_expiry: 10


litellm_settings:
  fallbacks: [{ "claude-3-5-sonnet-20240620": ["my-fallback-openai-model"] }]
  callbacks: ["otel", "prometheus"]
  cache: True
  cache_params:
    type: redis

    # disable caching on the actual API call
    supported_call_types: []

    # see https://docs.litellm.ai/docs/proxy/prod#3-use-redis-porthost-password-not-redis_url
    host: os.environ/REDIS_HOST
    port: os.environ/REDIS_PORT
    password: os.environ/REDIS_PASSWORD

  # see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
  # see https://docs.litellm.ai/docs/proxy/prometheus
  callbacks: ['prometheus', 'otel']

  # # see https://docs.litellm.ai/docs/proxy/logging#logging-proxy-inputoutput---sentry
  failure_callback: ['sentry']
  service_callback: ['prometheus_system']

  # redact_user_api_key_info: true


router_settings:
  routing_strategy: latency-based-routing

@@ -29,4 +54,19 @@ router_settings:
    ttl: 300
  redis_host: os.environ/REDIS_HOST
  redis_port: os.environ/REDIS_PORT
  redis_password: os.environ/REDIS_PASSWORD
  redis_password: os.environ/REDIS_PASSWORD

# see https://docs.litellm.ai/docs/proxy/prod#1-use-this-configyaml
general_settings:
  master_key: os.environ/LITELLM_MASTER_KEY
  database_url: os.environ/DATABASE_URL
  disable_master_key_return: true
  # alerting: ['slack', 'email']
  alerting: ['email']

  # Batch write spend updates every 60s
  proxy_batch_write_at: 60

  # see https://docs.litellm.ai/docs/proxy/caching#advanced---user-api-key-cache-ttl
  # our api keys rarely change
  user_api_key_cache_ttl: 3600

@@ -1419,6 +1419,8 @@ class UserAPIKeyAuth(
    parent_otel_span: Optional[Span] = None
    rpm_limit_per_model: Optional[Dict[str, int]] = None
    tpm_limit_per_model: Optional[Dict[str, int]] = None
    user_tpm_limit: Optional[int] = None
    user_rpm_limit: Optional[int] = None

    @model_validator(mode="before")
    @classmethod

@@ -9,6 +9,7 @@ Run checks for:
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""

import time
import traceback
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Literal, Optional

@@ -10,6 +10,7 @@ Returns a UserAPIKeyAuth object if the API key is valid
import asyncio
import json
import secrets
import time
import traceback
from datetime import datetime, timedelta, timezone
from typing import Optional, Tuple

@@ -44,6 +45,7 @@ from pydantic import BaseModel

import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm._service_logger import ServiceLogging
from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import (
    _cache_key_object,

@@ -73,6 +75,10 @@ from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.auth.service_account_checks import service_account_checks
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.utils import _to_ns
from litellm.types.services import ServiceTypes

user_api_key_service_logger_obj = ServiceLogging() # used for tracking latency on OTEL


api_key_header = APIKeyHeader(
    name=SpecialHeaders.openai_authorization.value,

@@ -214,7 +220,7 @@ async def user_api_key_auth( # noqa: PLR0915
    )

    parent_otel_span: Optional[Span] = None

    start_time = datetime.now()
    try:
        route: str = get_request_route(request=request)
        # get the request body

@@ -255,7 +261,7 @@ async def user_api_key_auth( # noqa: PLR0915
        if open_telemetry_logger is not None:
            parent_otel_span = open_telemetry_logger.tracer.start_span(
                name="Received Proxy Server Request",
                start_time=_to_ns(datetime.now()),
                start_time=_to_ns(start_time),
                context=open_telemetry_logger.get_traceparent_from_header(
                    headers=request.headers
                ),

@@ -1165,6 +1171,7 @@ async def user_api_key_auth( # noqa: PLR0915
                parent_otel_span=parent_otel_span,
                valid_token_dict=valid_token_dict,
                route=route,
                start_time=start_time,
            )
        else:
            raise Exception()

@@ -1219,31 +1226,39 @@ def _return_user_api_key_auth_obj(
    parent_otel_span: Optional[Span],
    valid_token_dict: dict,
    route: str,
    start_time: datetime,
) -> UserAPIKeyAuth:
    end_time = datetime.now()
    user_api_key_service_logger_obj.service_success_hook(
        service=ServiceTypes.AUTH,
        call_type=route,
        start_time=start_time,
        end_time=end_time,
        duration=end_time.timestamp() - start_time.timestamp(),
        parent_otel_span=parent_otel_span,
    )
    retrieved_user_role = (
        _get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
    )

    user_api_key_kwargs = {
        "api_key": api_key,
        "parent_otel_span": parent_otel_span,
        "user_role": retrieved_user_role,
        **valid_token_dict,
    }
    if user_obj is not None:
        user_api_key_kwargs.update(
            user_tpm_limit=user_obj.tpm_limit,
            user_rpm_limit=user_obj.rpm_limit,
        )
    if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
        return UserAPIKeyAuth(
            api_key=api_key,
        user_api_key_kwargs.update(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            parent_otel_span=parent_otel_span,
            **valid_token_dict,
        )
    elif _has_user_setup_sso() and route in LiteLLMRoutes.sso_only_routes.value:
        return UserAPIKeyAuth(
            api_key=api_key,
            user_role=retrieved_user_role,
            parent_otel_span=parent_otel_span,
            **valid_token_dict,
        )
        return UserAPIKeyAuth(**user_api_key_kwargs)
    else:
        return UserAPIKeyAuth(
            api_key=api_key,
            user_role=retrieved_user_role,
            parent_otel_span=parent_otel_span,
            **valid_token_dict,
        )
    return UserAPIKeyAuth(**user_api_key_kwargs)


def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]):

@@ -1,7 +1,8 @@
import asyncio
import sys
import traceback
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, TypedDict, Union

from fastapi import HTTPException
from pydantic import BaseModel

@@ -29,6 +30,14 @@ else:
    InternalUsageCache = Any


class CacheObject(TypedDict):
    current_global_requests: Optional[dict]
    request_count_api_key: Optional[dict]
    request_count_user_id: Optional[dict]
    request_count_team_id: Optional[dict]
    request_count_end_user_id: Optional[dict]


class _PROXY_MaxParallelRequestsHandler(CustomLogger):
    # Class variables or attributes
    def __init__(self, internal_usage_cache: InternalUsageCache):

@@ -51,14 +60,15 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
        max_parallel_requests: int,
        tpm_limit: int,
        rpm_limit: int,
        current: Optional[dict],
        request_count_api_key: str,
        rate_limit_type: Literal["user", "customer", "team"],
        values_to_update_in_cache: List[Tuple[Any, Any]],
    ):
        current = await self.internal_usage_cache.async_get_cache(
            key=request_count_api_key,
            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
        ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
        # current = await self.internal_usage_cache.async_get_cache(
        #     key=request_count_api_key,
        #     litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
        # ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
        if current is None:
            if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
                # base case

@@ -117,6 +127,44 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                headers={"retry-after": str(self.time_to_next_minute())},
            )

    async def get_all_cache_objects(
        self,
        current_global_requests: Optional[str],
        request_count_api_key: Optional[str],
        request_count_user_id: Optional[str],
        request_count_team_id: Optional[str],
        request_count_end_user_id: Optional[str],
        parent_otel_span: Optional[Span] = None,
    ) -> CacheObject:
        keys = [
            current_global_requests,
            request_count_api_key,
            request_count_user_id,
            request_count_team_id,
            request_count_end_user_id,
        ]
        results = await self.internal_usage_cache.async_batch_get_cache(
            keys=keys,
            parent_otel_span=parent_otel_span,
        )

        if results is None:
            return CacheObject(
                current_global_requests=None,
                request_count_api_key=None,
                request_count_user_id=None,
                request_count_team_id=None,
                request_count_end_user_id=None,
            )

        return CacheObject(
            current_global_requests=results[0],
            request_count_api_key=results[1],
            request_count_user_id=results[2],
            request_count_team_id=results[3],
            request_count_end_user_id=results[4],
        )

    async def async_pre_call_hook( # noqa: PLR0915
        self,
        user_api_key_dict: UserAPIKeyAuth,

@@ -149,6 +197,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
        # Setup values
        # ------------
        new_val: Optional[dict] = None

        if global_max_parallel_requests is not None:
            # get value from cache
            _key = "global_max_parallel_requests"

@@ -179,15 +228,40 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
        current_minute = datetime.now().strftime("%M")
        precise_minute = f"{current_date}-{current_hour}-{current_minute}"

        cache_objects: CacheObject = await self.get_all_cache_objects(
            current_global_requests=(
                "global_max_parallel_requests"
                if global_max_parallel_requests is not None
                else None
            ),
            request_count_api_key=(
                f"{api_key}::{precise_minute}::request_count"
                if api_key is not None
                else None
            ),
            request_count_user_id=(
                f"{user_api_key_dict.user_id}::{precise_minute}::request_count"
                if user_api_key_dict.user_id is not None
                else None
            ),
            request_count_team_id=(
                f"{user_api_key_dict.team_id}::{precise_minute}::request_count"
                if user_api_key_dict.team_id is not None
                else None
            ),
            request_count_end_user_id=(
                f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
                if user_api_key_dict.end_user_id is not None
                else None
            ),
            parent_otel_span=user_api_key_dict.parent_otel_span,
        )
        if api_key is not None:
            request_count_api_key = f"{api_key}::{precise_minute}::request_count"

            # CHECK IF REQUEST ALLOWED for key

            current = await self.internal_usage_cache.async_get_cache(
                key=request_count_api_key,
                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
            current = cache_objects["request_count_api_key"]
            self.print_verbose(f"current: {current}")
            if (
                max_parallel_requests == sys.maxsize

@@ -303,42 +377,28 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
        # check if REQUEST ALLOWED for user_id
        user_id = user_api_key_dict.user_id
        if user_id is not None:
            _user_id_rate_limits = await self.get_internal_user_object(
                user_id=user_id,
            user_tpm_limit = user_api_key_dict.user_tpm_limit
            user_rpm_limit = user_api_key_dict.user_rpm_limit
            if user_tpm_limit is None:
                user_tpm_limit = sys.maxsize
            if user_rpm_limit is None:
                user_rpm_limit = sys.maxsize

            request_count_api_key = f"{user_id}::{precise_minute}::request_count"
            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
            await self.check_key_in_limits(
                user_api_key_dict=user_api_key_dict,
                cache=cache,
                data=data,
                call_type=call_type,
                max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
                current=cache_objects["request_count_user_id"],
                request_count_api_key=request_count_api_key,
                tpm_limit=user_tpm_limit,
                rpm_limit=user_rpm_limit,
                rate_limit_type="user",
                values_to_update_in_cache=values_to_update_in_cache,
            )
            # get user tpm/rpm limits
            if (
                _user_id_rate_limits is not None
                and isinstance(_user_id_rate_limits, dict)
                and (
                    _user_id_rate_limits.get("tpm_limit", None) is not None
                    or _user_id_rate_limits.get("rpm_limit", None) is not None
                )
            ):
                user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None)
                user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None)
                if user_tpm_limit is None:
                    user_tpm_limit = sys.maxsize
                if user_rpm_limit is None:
                    user_rpm_limit = sys.maxsize

            # now do the same tpm/rpm checks
            request_count_api_key = f"{user_id}::{precise_minute}::request_count"

            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
            await self.check_key_in_limits(
                user_api_key_dict=user_api_key_dict,
                cache=cache,
                data=data,
                call_type=call_type,
                max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
                request_count_api_key=request_count_api_key,
                tpm_limit=user_tpm_limit,
                rpm_limit=user_rpm_limit,
                rate_limit_type="user",
                values_to_update_in_cache=values_to_update_in_cache,
            )

        # TEAM RATE LIMITS
        ## get team tpm/rpm limits

@@ -352,9 +412,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
            if team_rpm_limit is None:
                team_rpm_limit = sys.maxsize

            # now do the same tpm/rpm checks
            request_count_api_key = f"{team_id}::{precise_minute}::request_count"

            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
            await self.check_key_in_limits(
                user_api_key_dict=user_api_key_dict,

@@ -362,6 +420,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                data=data,
                call_type=call_type,
                max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a team
                current=cache_objects["request_count_team_id"],
                request_count_api_key=request_count_api_key,
                tpm_limit=team_tpm_limit,
                rpm_limit=team_rpm_limit,

@@ -397,16 +456,19 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                call_type=call_type,
                max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for an End-User
                request_count_api_key=request_count_api_key,
                current=cache_objects["request_count_end_user_id"],
                tpm_limit=end_user_tpm_limit,
                rpm_limit=end_user_rpm_limit,
                rate_limit_type="customer",
                values_to_update_in_cache=values_to_update_in_cache,
            )

        await self.internal_usage_cache.async_batch_set_cache(
            cache_list=values_to_update_in_cache,
            ttl=60,
            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
        asyncio.create_task(
            self.internal_usage_cache.async_batch_set_cache(
                cache_list=values_to_update_in_cache,
                ttl=60,
                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            ) # don't block execution for cache updates
        )

        return

@@ -481,8 +543,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
            litellm_parent_otel_span=litellm_parent_otel_span,
        ) or {
            "current_requests": 1,
            "current_tpm": total_tokens,
            "current_rpm": 1,
            "current_tpm": 0,
            "current_rpm": 0,
        }

        new_val = {

@@ -517,8 +579,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
            litellm_parent_otel_span=litellm_parent_otel_span,
        ) or {
            "current_requests": 1,
            "current_tpm": total_tokens,
            "current_rpm": 1,
            "current_tpm": 0,
            "current_rpm": 0,
        }

        new_val = {

@@ -262,6 +262,18 @@ class InternalUsageCache:
            **kwargs,
        )

    async def async_batch_get_cache(
        self,
        keys: list,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
    ):
        return await self.dual_cache.async_batch_get_cache(
            keys=keys,
            parent_otel_span=parent_otel_span,
            local_only=local_only,
        )

    async def async_increment_cache(
        self,
        key,

@@ -442,6 +454,8 @@ class ProxyLogging:
                litellm._async_success_callback.append(callback) # type: ignore
            if callback not in litellm._async_failure_callback:
                litellm._async_failure_callback.append(callback) # type: ignore
            if callback not in litellm.service_callback:
                litellm.service_callback.append(callback) # type: ignore

        if (
            len(litellm.input_callback) > 0

@@ -15,6 +15,7 @@ class ServiceTypes(str, enum.Enum):
    BATCH_WRITE_TO_DB = "batch_write_to_db"
    LITELLM = "self"
    ROUTER = "router"
    AUTH = "auth"


class ServiceLoggerPayload(BaseModel):

@@ -59,12 +59,15 @@ async def test_dual_cache_async_batch_get_cache():
    redis_cache = RedisCache() # get credentials from environment
    dual_cache = DualCache(in_memory_cache=in_memory_cache, redis_cache=redis_cache)

    in_memory_cache.set_cache(key="test_value", value="hello world")
    with patch.object(
        dual_cache.redis_cache, "async_batch_get_cache", new=AsyncMock()
    ) as mock_redis_cache:
        mock_redis_cache.return_value = {"test_value_2": None, "test_value": "hello"}

        result = await dual_cache.async_batch_get_cache(keys=["test_value", "test_value_2"])
        await dual_cache.async_batch_get_cache(keys=["test_value", "test_value_2"])
        await dual_cache.async_batch_get_cache(keys=["test_value", "test_value_2"])

        assert result[0] == "hello world"
        assert result[1] == None
        assert mock_redis_cache.call_count == 1


def test_dual_cache_batch_get_cache():

@@ -96,6 +96,7 @@ async def test_pre_call_hook():
            key=request_count_api_key
        )
    )
    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key

@@ -110,6 +111,7 @@ async def test_pre_call_hook_rpm_limits():
    Test if error raised on hitting rpm limits
    """
    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    user_api_key_dict = UserAPIKeyAuth(
        api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1
    )

@@ -152,6 +154,7 @@ async def test_pre_call_hook_rpm_limits_retry_after():
    Test if rate limit error, returns 'retry_after'
    """
    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    user_api_key_dict = UserAPIKeyAuth(
        api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1
    )

@@ -251,6 +254,7 @@ async def test_pre_call_hook_tpm_limits():
    Test if error raised on hitting tpm limits
    """
    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    user_api_key_dict = UserAPIKeyAuth(
        api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=10
    )

@@ -306,9 +310,9 @@ async def test_pre_call_hook_user_tpm_limits():
    local_cache.set_cache(key=user_id, value=user_obj)

    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    user_api_key_dict = UserAPIKeyAuth(
        api_key=_api_key,
        user_id=user_id,
        api_key=_api_key, user_id=user_id, user_rpm_limit=10, user_tpm_limit=9
    )
    res = dict(user_api_key_dict)
    print("dict user", res)

@@ -372,7 +376,7 @@ async def test_success_call_hook():
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"

    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key

@@ -416,7 +420,7 @@ async def test_failure_call_hook():
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"

    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key

@@ -497,7 +501,7 @@ async def test_normal_router_call():
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"

    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key

@@ -579,7 +583,7 @@ async def test_normal_router_tpm_limit():
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
    print("Test: Checking current_requests for precise_minute=", precise_minute)

    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key

@@ -658,7 +662,7 @@ async def test_streaming_router_call():
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"

    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key

@@ -736,7 +740,7 @@ async def test_streaming_router_tpm_limit():
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"

    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key

@@ -814,7 +818,7 @@ async def test_bad_router_call():
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"

    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache( # type: ignore
            key=request_count_api_key

@@ -890,7 +894,7 @@ async def test_bad_router_tpm_limit():
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"

    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key

@@ -979,7 +983,7 @@ async def test_bad_router_tpm_limit_per_model():
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{model}::{precise_minute}::request_count"

    await asyncio.sleep(1)
    print(
        "internal usage cache: ",
        parallel_request_handler.internal_usage_cache.dual_cache.in_memory_cache.cache_dict,

@@ -139,6 +139,7 @@ async def test_check_blocked_team():
def test_returned_user_api_key_auth(user_role, expected_role):
    from litellm.proxy._types import LiteLLM_UserTable, LitellmUserRoles
    from litellm.proxy.auth.user_api_key_auth import _return_user_api_key_auth_obj
    from datetime import datetime

    new_obj = _return_user_api_key_auth_obj(
        user_obj=LiteLLM_UserTable(

@@ -148,6 +149,7 @@ def test_returned_user_api_key_auth(user_role, expected_role):
        parent_otel_span=None,
        valid_token_dict={},
        route="/chat/completion",
        start_time=datetime.now(),
    )

    assert new_obj.user_role == expected_role