Merge pull request #2775 from BerriAI/litellm_redis_user_api_key_cache_v3

fix(tpm_rpm_limiter.py): enable redis caching for tpm/rpm checks on keys/user/teams
This commit is contained in:
Krish Dholakia 2024-03-30 22:07:05 -07:00 committed by GitHub
commit 1356f6cd32
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 461 additions and 14 deletions

View file

@ -102,7 +102,7 @@ from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
import pydantic
from litellm.proxy._types import *
from litellm.caching import DualCache
from litellm.caching import DualCache, RedisCache
from litellm.proxy.health_check import perform_health_check
from litellm._logging import verbose_router_logger, verbose_proxy_logger
from litellm.proxy.auth.handle_jwt import JWTHandler
@ -281,6 +281,9 @@ otel_logging = False
prisma_client: Optional[PrismaClient] = None
custom_db_client: Optional[DBClient] = None
user_api_key_cache = DualCache()
redis_usage_cache: Optional[RedisCache] = (
None # redis cache used for tracking spend, tpm/rpm limits
)
user_custom_auth = None
user_custom_key_generate = None
use_background_health_checks = None
@ -299,7 +302,9 @@ disable_spend_logs = False
jwt_handler = JWTHandler()
prompt_injection_detection_obj: Optional[_OPTIONAL_PromptInjectionDetection] = None
### INITIALIZE GLOBAL LOGGING OBJECT ###
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
proxy_logging_obj = ProxyLogging(
user_api_key_cache=user_api_key_cache, redis_usage_cache=redis_usage_cache
)
### REDIS QUEUE ###
async_result = None
celery_app_conn = None
@ -909,6 +914,10 @@ async def user_api_key_auth(
models=valid_token.team_models,
)
user_api_key_cache.set_cache(
key=valid_token.team_id, value=_team_obj
) # save team table in cache - used for tpm/rpm limiting - tpm_rpm_limiter.py
_end_user_object = None
if "user" in request_data:
_id = "end_user_id:{}".format(request_data["user"])
@ -1905,7 +1914,7 @@ class ProxyConfig:
"""
Load config values into proxy global state
"""
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache
# Load existing config
config = await self.get_config(config_file_path=config_file_path)
@ -1967,6 +1976,7 @@ class ProxyConfig:
"password": cache_password,
}
)
# Assuming cache_type, cache_host, cache_port, and cache_password are strings
print( # noqa
f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}"
@ -1991,7 +2001,14 @@ class ProxyConfig:
cache_params[key] = litellm.get_secret(value)
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
litellm.cache = Cache(**cache_params)
if litellm.cache is not None and isinstance(
litellm.cache.cache, RedisCache
):
## INIT PROXY REDIS USAGE CLIENT ##
redis_usage_cache = litellm.cache.cache
print( # noqa
f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
)