fix - clean up in memory cache

Ishaan Jaff 2024-06-22 18:46:30 -07:00
parent 4fc8efd640
commit fa554ae218


@@ -7,14 +7,20 @@
 #
 # Thank you users! We ❤️ you! - Krrish & Ishaan
-import litellm
-import time, logging, asyncio
-import json, traceback, ast, hashlib
-from typing import Optional, Literal, List, Union, Any, BinaryIO
+import ast
+import asyncio
+import hashlib
+import json
+import logging
+import time
+import traceback
+from typing import Any, BinaryIO, List, Literal, Optional, Union
 from openai._models import BaseModel as OpenAIObject
+import litellm
 from litellm._logging import verbose_logger
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes
-import traceback
 def print_verbose(print_statement):
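This hunk, like the RedisCache, RedisSemanticCache, and S3Cache hunks further down, only splits, regroups, and alphabetizes imports; behavior is unchanged. The resulting order is what an isort-style pass produces. A minimal sketch, assuming the isort package is available (the file path is illustrative):

```python
import isort

# isort.code() returns the source with imports split onto separate lines,
# grouped (stdlib / third-party / first-party) and alphabetized.
with open("caching.py") as f:
    source = f.read()

print(isort.code(source))
```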
@@ -57,10 +63,12 @@ class BaseCache:
 class InMemoryCache(BaseCache):
-    def __init__(self):
+    def __init__(self, default_ttl: Optional[float] = 60.0):
+        # if users don't provider one, use the default litellm cache
-        self.cache_dict = {}
-        self.ttl_dict = {}
+        self.cache_dict: dict = {}
+        self.ttl_dict: dict = {}
+        self.default_ttl = default_ttl
+        self.last_cleaned = 0  # since this is in memory we need to periodically clean it up to not overuse the machines RAM
     def set_cache(self, key, value, **kwargs):
         print_verbose("InMemoryCache: set_cache")
@@ -70,6 +78,8 @@ class InMemoryCache(BaseCache):
     async def async_set_cache(self, key, value, **kwargs):
         self.set_cache(key=key, value=value, **kwargs)
+        if time.time() > self.last_cleaned:
+            asyncio.create_task(self.clean_up_in_memory_cache())
     async def async_set_cache_pipeline(self, cache_list, ttl=None):
         for cache_key, cache_value in cache_list:
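The cleanup is scheduled with `asyncio.create_task` so the write path never waits on the sweep. A minimal sketch of this fire-and-forget pattern (class and method names are illustrative):

```python
import asyncio
import time


class Cache:
    def __init__(self):
        self.last_cleaned = 0.0

    async def async_set_cache(self, key, value):
        ...  # store the value as usual
        # fire-and-forget: schedule the sweep without awaiting it,
        # so callers are not delayed by cleanup work
        if time.time() > self.last_cleaned:
            asyncio.create_task(self.clean_up_in_memory_cache())

    async def clean_up_in_memory_cache(self):
        ...  # drop expired keys here
        self.last_cleaned = time.time()
```

Note that `asyncio.create_task` needs a running event loop, which is why the sweep is only scheduled from the async code paths.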
@@ -78,6 +88,9 @@ class InMemoryCache(BaseCache):
             else:
                 self.set_cache(key=cache_key, value=cache_value)
+        if time.time() > self.last_cleaned:
+            asyncio.create_task(self.clean_up_in_memory_cache())
     def get_cache(self, key, **kwargs):
         if key in self.cache_dict:
             if key in self.ttl_dict:
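The `ttl_dict` check in `get_cache` is what makes entries expire lazily on read as well. A standalone sketch of that read-path check (illustrative, not the exact litellm implementation):

```python
import time


def get_cache(cache_dict: dict, ttl_dict: dict, key):
    if key in cache_dict:
        if key in ttl_dict and time.time() > ttl_dict[key]:
            # entry has expired: evict it and report a miss
            cache_dict.pop(key, None)
            ttl_dict.pop(key, None)
            return None
        return cache_dict[key]
    return None
```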
@@ -121,8 +134,26 @@ class InMemoryCache(BaseCache):
         init_value = await self.async_get_cache(key=key) or 0
         value = init_value + value
         await self.async_set_cache(key, value, **kwargs)
+        if time.time() > self.last_cleaned:
+            asyncio.create_task(self.clean_up_in_memory_cache())
         return value
+    async def clean_up_in_memory_cache(self):
+        """
+        Runs periodically to clean up the in-memory cache
+        - loop through all keys in cache, check if they are expired
+        - if yes, delete them
+        """
+        for key in list(self.cache_dict.keys()):
+            if key in self.ttl_dict:
+                if time.time() > self.ttl_dict[key]:
+                    self.cache_dict.pop(key, None)
+                    self.ttl_dict.pop(key, None)
+        self.last_cleaned = time.time()
     def flush_cache(self):
         self.cache_dict.clear()
         self.ttl_dict.clear()
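The new `clean_up_in_memory_cache` walks a snapshot of the keys, drops any whose expiry has passed, and records the sweep time in `last_cleaned`. Iterating over `list(self.cache_dict.keys())` rather than the dict itself matters, because deleting from a dict while iterating it directly raises `RuntimeError: dictionary changed size during iteration`. A small runnable sketch of the same sweep (key names are illustrative):

```python
import time

cache_dict = {"fresh": 1, "stale": 2}
ttl_dict = {"stale": time.time() - 5}  # already expired

# snapshot the keys first; mutating the dict while iterating it directly
# would raise RuntimeError
for key in list(cache_dict.keys()):
    if key in ttl_dict and time.time() > ttl_dict[key]:
        cache_dict.pop(key, None)
        ttl_dict.pop(key, None)

print(cache_dict)  # {'fresh': 1}
```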
@@ -147,10 +178,12 @@ class RedisCache(BaseCache):
         namespace: Optional[str] = None,
         **kwargs,
     ):
-        from ._redis import get_redis_client, get_redis_connection_pool
-        from litellm._service_logger import ServiceLogging
         import redis
+        from litellm._service_logger import ServiceLogging
+        from ._redis import get_redis_client, get_redis_connection_pool
         redis_kwargs = {}
         if host is not None:
             redis_kwargs["host"] = host
@@ -886,11 +919,10 @@ class RedisSemanticCache(BaseCache):
     def get_cache(self, key, **kwargs):
         print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
-        from redisvl.query import VectorQuery
         import numpy as np
+        from redisvl.query import VectorQuery
         # query
         # get the messages
         messages = kwargs["messages"]
         prompt = "".join(message["content"] for message in messages)
@@ -943,7 +975,8 @@
     async def async_set_cache(self, key, value, **kwargs):
         import numpy as np
-        from litellm.proxy.proxy_server import llm_router, llm_model_list
+        from litellm.proxy.proxy_server import llm_model_list, llm_router
         try:
             await self.index.acreate(overwrite=False)  # don't overwrite existing index
@@ -998,12 +1031,12 @@
     async def async_get_cache(self, key, **kwargs):
         print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
-        from redisvl.query import VectorQuery
         import numpy as np
-        from litellm.proxy.proxy_server import llm_router, llm_model_list
+        from redisvl.query import VectorQuery
+        from litellm.proxy.proxy_server import llm_model_list, llm_router
         # query
         # get the messages
         messages = kwargs["messages"]
         prompt = "".join(message["content"] for message in messages)
@@ -1161,7 +1194,8 @@ class S3Cache(BaseCache):
         self.set_cache(key=key, value=value, **kwargs)
     def get_cache(self, key, **kwargs):
-        import boto3, botocore
+        import boto3
+        import botocore
         try:
             key = self.key_prefix + key