diff --git a/litellm/caching.py b/litellm/caching.py
index 6b58cf5276..dde41ad29e 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -7,14 +7,20 @@
 #
 # Thank you users! We ❤️ you! - Krrish & Ishaan
 
-import litellm
-import time, logging, asyncio
-import json, traceback, ast, hashlib
-from typing import Optional, Literal, List, Union, Any, BinaryIO
+import ast
+import asyncio
+import hashlib
+import json
+import logging
+import time
+import traceback
+from typing import Any, BinaryIO, List, Literal, Optional, Union
+
 from openai._models import BaseModel as OpenAIObject
+
+import litellm
 from litellm._logging import verbose_logger
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes
-import traceback
 
 
 def print_verbose(print_statement):
@@ -57,10 +63,12 @@ class BaseCache:
 
 
 class InMemoryCache(BaseCache):
-    def __init__(self):
+    def __init__(self, default_ttl: Optional[float] = 60.0):
         # if users don't provider one, use the default litellm cache
-        self.cache_dict = {}
-        self.ttl_dict = {}
+        self.cache_dict: dict = {}
+        self.ttl_dict: dict = {}
+        self.default_ttl = default_ttl
+        self.last_cleaned = 0  # since this is in memory we need to periodically clean it up to not overuse the machines RAM
 
     def set_cache(self, key, value, **kwargs):
         print_verbose("InMemoryCache: set_cache")
@@ -70,6 +78,8 @@ class InMemoryCache(BaseCache):
 
     async def async_set_cache(self, key, value, **kwargs):
         self.set_cache(key=key, value=value, **kwargs)
+        if time.time() > self.last_cleaned:
+            asyncio.create_task(self.clean_up_in_memory_cache())
 
     async def async_set_cache_pipeline(self, cache_list, ttl=None):
         for cache_key, cache_value in cache_list:
@@ -78,6 +88,9 @@ class InMemoryCache(BaseCache):
             else:
                 self.set_cache(key=cache_key, value=cache_value)
 
+        if time.time() > self.last_cleaned:
+            asyncio.create_task(self.clean_up_in_memory_cache())
+
     def get_cache(self, key, **kwargs):
         if key in self.cache_dict:
             if key in self.ttl_dict:
@@ -121,8 +134,26 @@ class InMemoryCache(BaseCache):
         init_value = await self.async_get_cache(key=key) or 0
         value = init_value + value
         await self.async_set_cache(key, value, **kwargs)
+
+        if time.time() > self.last_cleaned:
+            asyncio.create_task(self.clean_up_in_memory_cache())
+
         return value
 
+    async def clean_up_in_memory_cache(self):
+        """
+        Runs periodically to clean up the in-memory cache
+
+        - loop through all keys in cache, check if they are expired
+        - if yes, delete them
+        """
+        for key in list(self.cache_dict.keys()):
+            if key in self.ttl_dict:
+                if time.time() > self.ttl_dict[key]:
+                    self.cache_dict.pop(key, None)
+                    self.ttl_dict.pop(key, None)
+        self.last_cleaned = time.time()
+
     def flush_cache(self):
         self.cache_dict.clear()
         self.ttl_dict.clear()
@@ -147,10 +178,12 @@ class RedisCache(BaseCache):
         namespace: Optional[str] = None,
         **kwargs,
     ):
-        from ._redis import get_redis_client, get_redis_connection_pool
-        from litellm._service_logger import ServiceLogging
         import redis
 
+        from litellm._service_logger import ServiceLogging
+
+        from ._redis import get_redis_client, get_redis_connection_pool
+
         redis_kwargs = {}
         if host is not None:
             redis_kwargs["host"] = host
@@ -886,11 +919,10 @@ class RedisSemanticCache(BaseCache):
 
     def get_cache(self, key, **kwargs):
         print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
-        from redisvl.query import VectorQuery
         import numpy as np
+        from redisvl.query import VectorQuery
 
         # query
-
         # get the messages
         messages = kwargs["messages"]
         prompt = "".join(message["content"] for message in messages)
@@ -943,7 +975,8 @@ class RedisSemanticCache(BaseCache):
 
     async def async_set_cache(self, key, value, **kwargs):
         import numpy as np
-        from litellm.proxy.proxy_server import llm_router, llm_model_list
+
+        from litellm.proxy.proxy_server import llm_model_list, llm_router
 
         try:
             await self.index.acreate(overwrite=False)  # don't overwrite existing index
@@ -998,12 +1031,12 @@ class RedisSemanticCache(BaseCache):
 
     async def async_get_cache(self, key, **kwargs):
         print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
-        from redisvl.query import VectorQuery
         import numpy as np
-        from litellm.proxy.proxy_server import llm_router, llm_model_list
+        from redisvl.query import VectorQuery
+
+        from litellm.proxy.proxy_server import llm_model_list, llm_router
 
         # query
-
         # get the messages
         messages = kwargs["messages"]
         prompt = "".join(message["content"] for message in messages)
@@ -1161,7 +1194,8 @@ class S3Cache(BaseCache):
         self.set_cache(key=key, value=value, **kwargs)
 
     def get_cache(self, key, **kwargs):
-        import boto3, botocore
+        import boto3
+        import botocore
 
         try:
             key = self.key_prefix + key
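
For reference, a minimal standalone sketch (not part of the patch) of the TTL sweep that the new clean_up_in_memory_cache() performs; the TTLSketch class and its method names are illustrative assumptions, only the expiry logic mirrors the diff above.

# Illustrative sketch only -- assumes the same cache_dict/ttl_dict layout
# (ttl_dict maps key -> absolute expiry timestamp) as InMemoryCache.
import asyncio
import time
from typing import Any, Dict, Optional


class TTLSketch:
    def __init__(self) -> None:
        self.cache_dict: Dict[str, Any] = {}
        self.ttl_dict: Dict[str, float] = {}

    def set_cache(self, key: str, value: Any, ttl: Optional[float] = None) -> None:
        self.cache_dict[key] = value
        if ttl is not None:
            self.ttl_dict[key] = time.time() + ttl

    async def clean_up(self) -> None:
        # Same shape as clean_up_in_memory_cache(): walk a snapshot of the
        # keys and drop every entry whose expiry timestamp has passed.
        for key in list(self.cache_dict.keys()):
            expiry = self.ttl_dict.get(key)
            if expiry is not None and time.time() > expiry:
                self.cache_dict.pop(key, None)
                self.ttl_dict.pop(key, None)


async def main() -> None:
    cache = TTLSketch()
    cache.set_cache("short", 1, ttl=0.01)  # expires almost immediately
    cache.set_cache("long", 2)             # no TTL, survives the sweep
    await asyncio.sleep(0.05)
    await cache.clean_up()
    print(cache.cache_dict)                # {'long': 2}


asyncio.run(main())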