fix - clean up in memory cache

This commit is contained in:
Ishaan Jaff 2024-06-22 18:46:30 -07:00
parent 98590af58a
commit c4ae06576b

View file

@@ -7,14 +7,20 @@
# #
# Thank you users! We ❤️ you! - Krrish & Ishaan # Thank you users! We ❤️ you! - Krrish & Ishaan
import litellm import ast
import time, logging, asyncio import asyncio
import json, traceback, ast, hashlib import hashlib
from typing import Optional, Literal, List, Union, Any, BinaryIO import json
import logging
import time
import traceback
from typing import Any, BinaryIO, List, Literal, Optional, Union
from openai._models import BaseModel as OpenAIObject from openai._models import BaseModel as OpenAIObject
import litellm
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.types.services import ServiceLoggerPayload, ServiceTypes from litellm.types.services import ServiceLoggerPayload, ServiceTypes
import traceback
def print_verbose(print_statement): def print_verbose(print_statement):
@@ -57,10 +63,12 @@ class BaseCache:
class InMemoryCache(BaseCache): class InMemoryCache(BaseCache):
def __init__(self): def __init__(self, default_ttl: Optional[float] = 60.0):
# if users don't provider one, use the default litellm cache # if users don't provider one, use the default litellm cache
self.cache_dict = {} self.cache_dict: dict = {}
self.ttl_dict = {} self.ttl_dict: dict = {}
self.default_ttl = default_ttl
self.last_cleaned = 0 # since this is in memory we need to periodically clean it up to not overuse the machines RAM
def set_cache(self, key, value, **kwargs): def set_cache(self, key, value, **kwargs):
print_verbose("InMemoryCache: set_cache") print_verbose("InMemoryCache: set_cache")
@@ -70,6 +78,8 @@ class InMemoryCache(BaseCache):
async def async_set_cache(self, key, value, **kwargs): async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs) self.set_cache(key=key, value=value, **kwargs)
if time.time() > self.last_cleaned:
asyncio.create_task(self.clean_up_in_memory_cache())
async def async_set_cache_pipeline(self, cache_list, ttl=None): async def async_set_cache_pipeline(self, cache_list, ttl=None):
for cache_key, cache_value in cache_list: for cache_key, cache_value in cache_list:
@@ -78,6 +88,9 @@ class InMemoryCache(BaseCache):
else: else:
self.set_cache(key=cache_key, value=cache_value) self.set_cache(key=cache_key, value=cache_value)
if time.time() > self.last_cleaned:
asyncio.create_task(self.clean_up_in_memory_cache())
def get_cache(self, key, **kwargs): def get_cache(self, key, **kwargs):
if key in self.cache_dict: if key in self.cache_dict:
if key in self.ttl_dict: if key in self.ttl_dict:
@@ -121,8 +134,26 @@ class InMemoryCache(BaseCache):
init_value = await self.async_get_cache(key=key) or 0 init_value = await self.async_get_cache(key=key) or 0
value = init_value + value value = init_value + value
await self.async_set_cache(key, value, **kwargs) await self.async_set_cache(key, value, **kwargs)
if time.time() > self.last_cleaned:
asyncio.create_task(self.clean_up_in_memory_cache())
return value return value
async def clean_up_in_memory_cache(self):
    """
    Evict expired entries from the in-memory cache.

    Walks every key currently in ``cache_dict``; any key whose TTL
    deadline in ``ttl_dict`` has already passed is removed from both
    dicts. Finally records in ``last_cleaned`` when the next cleanup
    is due (call sites spawn this coroutine as a task whenever
    ``time.time() > self.last_cleaned``).
    """
    now = time.time()
    for key in list(self.cache_dict.keys()):
        ttl_deadline = self.ttl_dict.get(key)
        if ttl_deadline is not None and now > ttl_deadline:
            # Expired — evict from both maps; pop(..., None) tolerates the
            # entry having been removed concurrently by another task.
            self.cache_dict.pop(key, None)
            self.ttl_dict.pop(key, None)
    # BUG FIX: the original set last_cleaned to the *current* time, so the
    # call-site guard `time.time() > self.last_cleaned` passed again on the
    # very next write and a cleanup task was spawned per write, defeating
    # the "periodic" intent. Schedule the next sweep one default_ttl in the
    # future instead (fall back to "now" when default_ttl is None/0).
    self.last_cleaned = time.time() + (self.default_ttl or 0)
def flush_cache(self): def flush_cache(self):
self.cache_dict.clear() self.cache_dict.clear()
self.ttl_dict.clear() self.ttl_dict.clear()
@@ -147,10 +178,12 @@ class RedisCache(BaseCache):
namespace: Optional[str] = None, namespace: Optional[str] = None,
**kwargs, **kwargs,
): ):
from ._redis import get_redis_client, get_redis_connection_pool
from litellm._service_logger import ServiceLogging
import redis import redis
from litellm._service_logger import ServiceLogging
from ._redis import get_redis_client, get_redis_connection_pool
redis_kwargs = {} redis_kwargs = {}
if host is not None: if host is not None:
redis_kwargs["host"] = host redis_kwargs["host"] = host
@@ -886,11 +919,10 @@ class RedisSemanticCache(BaseCache):
def get_cache(self, key, **kwargs): def get_cache(self, key, **kwargs):
print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}") print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
from redisvl.query import VectorQuery
import numpy as np import numpy as np
from redisvl.query import VectorQuery
# query # query
# get the messages # get the messages
messages = kwargs["messages"] messages = kwargs["messages"]
prompt = "".join(message["content"] for message in messages) prompt = "".join(message["content"] for message in messages)
@@ -943,7 +975,8 @@ class RedisSemanticCache(BaseCache):
async def async_set_cache(self, key, value, **kwargs): async def async_set_cache(self, key, value, **kwargs):
import numpy as np import numpy as np
from litellm.proxy.proxy_server import llm_router, llm_model_list
from litellm.proxy.proxy_server import llm_model_list, llm_router
try: try:
await self.index.acreate(overwrite=False) # don't overwrite existing index await self.index.acreate(overwrite=False) # don't overwrite existing index
@@ -998,12 +1031,12 @@ class RedisSemanticCache(BaseCache):
async def async_get_cache(self, key, **kwargs): async def async_get_cache(self, key, **kwargs):
print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}") print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
from redisvl.query import VectorQuery
import numpy as np import numpy as np
from litellm.proxy.proxy_server import llm_router, llm_model_list from redisvl.query import VectorQuery
from litellm.proxy.proxy_server import llm_model_list, llm_router
# query # query
# get the messages # get the messages
messages = kwargs["messages"] messages = kwargs["messages"]
prompt = "".join(message["content"] for message in messages) prompt = "".join(message["content"] for message in messages)
@@ -1161,7 +1194,8 @@ class S3Cache(BaseCache):
self.set_cache(key=key, value=value, **kwargs) self.set_cache(key=key, value=value, **kwargs)
def get_cache(self, key, **kwargs): def get_cache(self, key, **kwargs):
import boto3, botocore import boto3
import botocore
try: try:
key = self.key_prefix + key key = self.key_prefix + key