Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
fix - clean up in memory cache
commit fa554ae218 (parent 4fc8efd640)
1 changed file with 51 additions and 17 deletions
@@ -7,14 +7,20 @@
 #
 # Thank you users! We ❤️ you! - Krrish & Ishaan

-import litellm
-import time, logging, asyncio
-import json, traceback, ast, hashlib
-from typing import Optional, Literal, List, Union, Any, BinaryIO
+import ast
+import asyncio
+import hashlib
+import json
+import logging
+import time
+import traceback
+from typing import Any, BinaryIO, List, Literal, Optional, Union
+
 from openai._models import BaseModel as OpenAIObject
+
+import litellm
 from litellm._logging import verbose_logger
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes
 import traceback


 def print_verbose(print_statement):
@@ -57,10 +63,12 @@ class BaseCache:


 class InMemoryCache(BaseCache):
-    def __init__(self):
+    def __init__(self, default_ttl: Optional[float] = 60.0):
         # if users don't provide one, use the default litellm cache
-        self.cache_dict = {}
-        self.ttl_dict = {}
+        self.cache_dict: dict = {}
+        self.ttl_dict: dict = {}
+        self.default_ttl = default_ttl
+        self.last_cleaned = 0  # since this is in memory we need to periodically clean it up to not overuse the machine's RAM

     def set_cache(self, key, value, **kwargs):
         print_verbose("InMemoryCache: set_cache")
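For orientation: the constructor change above keeps two dicts in lockstep, cache_dict maps keys to values and ttl_dict maps the same keys to absolute expiry timestamps, while last_cleaned starts at 0 so the first async write always triggers a sweep. A minimal sketch of that bookkeeping (SketchCache and its simplified method bodies are invented for illustration, this is not the litellm implementation):

import time
from typing import Any, Optional


class SketchCache:
    # hypothetical, simplified stand-in for illustration only
    def __init__(self, default_ttl: Optional[float] = 60.0):
        self.cache_dict: dict = {}  # key -> cached value
        self.ttl_dict: dict = {}    # key -> absolute expiry (epoch seconds)
        self.default_ttl = default_ttl
        self.last_cleaned = 0       # 0 means a cleanup has never run

    def set_cache(self, key: str, value: Any, ttl: Optional[float] = None) -> None:
        self.cache_dict[key] = value
        self.ttl_dict[key] = time.time() + (ttl or self.default_ttl)

    def get_cache(self, key: str) -> Optional[Any]:
        # an entry past its expiry is treated as a miss on read;
        # reclaiming its memory is the job of the periodic sweep this commit adds
        if key in self.ttl_dict and time.time() > self.ttl_dict[key]:
            return None
        return self.cache_dict.get(key)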
@@ -70,6 +78,8 @@ class InMemoryCache(BaseCache):

     async def async_set_cache(self, key, value, **kwargs):
         self.set_cache(key=key, value=value, **kwargs)
+        if time.time() > self.last_cleaned:
+            asyncio.create_task(self.clean_up_in_memory_cache())

     async def async_set_cache_pipeline(self, cache_list, ttl=None):
         for cache_key, cache_value in cache_list:
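The sweep is scheduled fire-and-forget: asyncio.create_task puts clean_up_in_memory_cache on the running event loop and async_set_cache returns without awaiting it, so writes never block on cleanup. A self-contained sketch of the pattern (sweep, write, main, and background_tasks are invented names, and the strong-reference bookkeeping follows the asyncio docs' general advice rather than anything in this diff):

import asyncio

background_tasks: set = set()


async def sweep() -> None:
    # stand-in for the cache sweep; runs concurrently with the caller
    print("sweeping expired entries...")


async def write(key: str, value: str) -> None:
    # store the value (elided), then schedule the sweep without awaiting it
    task = asyncio.create_task(sweep())
    # the asyncio docs advise keeping a strong reference to the task,
    # since the event loop itself holds only a weak one
    background_tasks.add(task)
    task.add_done_callback(background_tasks.discard)


async def main() -> None:
    await write("k", "v")                    # returns immediately
    await asyncio.gather(*background_tasks)  # let the sweep finish before exit


asyncio.run(main())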
@@ -78,6 +88,9 @@ class InMemoryCache(BaseCache):
             else:
                 self.set_cache(key=cache_key, value=cache_value)

+        if time.time() > self.last_cleaned:
+            asyncio.create_task(self.clean_up_in_memory_cache())
+
     def get_cache(self, key, **kwargs):
         if key in self.cache_dict:
             if key in self.ttl_dict:
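async_set_cache_pipeline writes a batch of (key, value) pairs through set_cache, then runs the cleanup check once for the whole batch. A hedged usage sketch: the keys and values are invented, and since the hunk only shows the else branch, the assumption that ttl is forwarded to set_cache in the other branch is mine:

import asyncio

from litellm.caching import InMemoryCache  # module path is an assumption


async def main() -> None:
    cache = InMemoryCache(default_ttl=60.0)
    await cache.async_set_cache_pipeline(
        cache_list=[
            ("user:1", {"plan": "free"}),
            ("user:2", {"plan": "pro"}),
        ],
        ttl=120,  # assumed to be forwarded to set_cache for each pair
    )


asyncio.run(main())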
@@ -121,8 +134,26 @@ class InMemoryCache(BaseCache):
         init_value = await self.async_get_cache(key=key) or 0
         value = init_value + value
         await self.async_set_cache(key, value, **kwargs)
+
+        if time.time() > self.last_cleaned:
+            asyncio.create_task(self.clean_up_in_memory_cache())
+
         return value

+    async def clean_up_in_memory_cache(self):
+        """
+        Runs periodically to clean up the in-memory cache
+
+        - loop through all keys in cache, check if they are expired
+        - if yes, delete them
+        """
+        for key in list(self.cache_dict.keys()):
+            if key in self.ttl_dict:
+                if time.time() > self.ttl_dict[key]:
+                    self.cache_dict.pop(key, None)
+                    self.ttl_dict.pop(key, None)
+        self.last_cleaned = time.time()
+
     def flush_cache(self):
         self.cache_dict.clear()
         self.ttl_dict.clear()
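The sweep iterates over a snapshot of the keys (list(self.cache_dict.keys())) so entries can be popped mid-loop, drops every key whose expiry timestamp has passed, and stamps last_cleaned. The same logic on plain dicts, as a runnable demo (the example keys and timings are invented):

import time

cache_dict = {"stale": 1, "fresh": 2}
ttl_dict = {
    "stale": time.time() - 10,  # expired 10 seconds ago
    "fresh": time.time() + 60,  # live for another 60 seconds
}

# snapshot the keys so we can pop from the dicts while iterating
for key in list(cache_dict.keys()):
    if key in ttl_dict and time.time() > ttl_dict[key]:
        cache_dict.pop(key, None)
        ttl_dict.pop(key, None)

print(cache_dict)  # {'fresh': 2}: the expired entry's memory is reclaimed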
@@ -147,10 +178,12 @@ class RedisCache(BaseCache):
         namespace: Optional[str] = None,
         **kwargs,
     ):
-        from ._redis import get_redis_client, get_redis_connection_pool
-        from litellm._service_logger import ServiceLogging
+        import redis
+
+        from litellm._service_logger import ServiceLogging
+
+        from ._redis import get_redis_client, get_redis_connection_pool

         redis_kwargs = {}
         if host is not None:
             redis_kwargs["host"] = host
@@ -886,11 +919,10 @@ class RedisSemanticCache(BaseCache):

     def get_cache(self, key, **kwargs):
         print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
-        from redisvl.query import VectorQuery
         import numpy as np
+        from redisvl.query import VectorQuery

         # query

         # get the messages
         messages = kwargs["messages"]
         prompt = "".join(message["content"] for message in messages)
@@ -943,7 +975,8 @@ class RedisSemanticCache(BaseCache):

     async def async_set_cache(self, key, value, **kwargs):
         import numpy as np
-        from litellm.proxy.proxy_server import llm_router, llm_model_list
+
+        from litellm.proxy.proxy_server import llm_model_list, llm_router

         try:
             await self.index.acreate(overwrite=False)  # don't overwrite existing index
@@ -998,12 +1031,12 @@ class RedisSemanticCache(BaseCache):

     async def async_get_cache(self, key, **kwargs):
         print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
-        from redisvl.query import VectorQuery
         import numpy as np
-        from litellm.proxy.proxy_server import llm_router, llm_model_list
+        from redisvl.query import VectorQuery
+
+        from litellm.proxy.proxy_server import llm_model_list, llm_router

         # query

         # get the messages
         messages = kwargs["messages"]
         prompt = "".join(message["content"] for message in messages)
@@ -1161,7 +1194,8 @@ class S3Cache(BaseCache):
         self.set_cache(key=key, value=value, **kwargs)

     def get_cache(self, key, **kwargs):
-        import boto3, botocore
+        import boto3
+        import botocore

         try:
             key = self.key_prefix + key