Mirror of https://github.com/BerriAI/litellm.git
Synced 2025-04-27 03:34:10 +00:00

fix - clean up in memory cache

This commit is contained in:
parent 98590af58a
commit c4ae06576b

1 changed file with 51 additions and 17 deletions
@@ -7,14 +7,20 @@
 #
 # Thank you users! We ❤️ you! - Krrish & Ishaan

-import litellm
-import time, logging, asyncio
-import json, traceback, ast, hashlib
-from typing import Optional, Literal, List, Union, Any, BinaryIO
+import ast
+import asyncio
+import hashlib
+import json
+import logging
+import time
+import traceback
+from typing import Any, BinaryIO, List, Literal, Optional, Union
+
 from openai._models import BaseModel as OpenAIObject
+
+import litellm
 from litellm._logging import verbose_logger
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes
-import traceback


 def print_verbose(print_statement):
@@ -57,10 +63,12 @@ class BaseCache:


 class InMemoryCache(BaseCache):
-    def __init__(self):
+    def __init__(self, default_ttl: Optional[float] = 60.0):
         # if users don't provider one, use the default litellm cache
-        self.cache_dict = {}
-        self.ttl_dict = {}
+        self.cache_dict: dict = {}
+        self.ttl_dict: dict = {}
+        self.default_ttl = default_ttl
+        self.last_cleaned = 0  # since this is in memory we need to periodically clean it up to not overuse the machines RAM

     def set_cache(self, key, value, **kwargs):
         print_verbose("InMemoryCache: set_cache")
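The annotated dicts and the new last_cleaned timestamp are the state the cleanup added later in this commit operates on; the sweep's check time.time() > self.ttl_dict[key] only works if ttl_dict holds absolute expiry timestamps recorded at write time. A minimal standalone sketch of that convention, assuming a ttl is converted to time.time() + ttl when an entry is stored (illustrative names, not litellm's API):

import time

cache_dict: dict = {}
ttl_dict: dict = {}

def put(key, value, ttl: float) -> None:
    # assumed convention: the value goes in cache_dict, the absolute expiry in ttl_dict
    cache_dict[key] = value
    ttl_dict[key] = time.time() + ttl

put("k", 42, ttl=0.05)
time.sleep(0.1)
print(time.time() > ttl_dict["k"])  # True -> the entry is expired and eligible for cleanup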
@@ -70,6 +78,8 @@ class InMemoryCache(BaseCache):

     async def async_set_cache(self, key, value, **kwargs):
         self.set_cache(key=key, value=value, **kwargs)
+        if time.time() > self.last_cleaned:
+            asyncio.create_task(self.clean_up_in_memory_cache())

     async def async_set_cache_pipeline(self, cache_list, ttl=None):
         for cache_key, cache_value in cache_list:
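The guard added here (and repeated in the pipeline and increment paths below) is an opportunistic trigger: on each write, compare the wall clock against last_cleaned and, if a sweep is due, schedule it with asyncio.create_task so the write path is not blocked. A minimal standalone sketch of that fire-and-forget pattern, with illustrative names rather than litellm's API:

import asyncio
import time

class Sweeper:
    def __init__(self):
        self.last_cleaned = 0.0  # timestamp of the most recent background sweep

    async def sweep(self):
        # stand-in for clean_up_in_memory_cache(): real code would purge expired keys here
        self.last_cleaned = time.time()

    async def write(self, key, value):
        # ... store key/value ...
        if time.time() > self.last_cleaned:
            # fire-and-forget: the write returns without waiting for the sweep to finish
            asyncio.create_task(self.sweep())

async def main():
    s = Sweeper()
    await s.write("k", "v")
    await asyncio.sleep(0)  # yield once so the scheduled sweep gets to run
    print(s.last_cleaned > 0)  # True

asyncio.run(main())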
@@ -78,6 +88,9 @@ class InMemoryCache(BaseCache):
             else:
                 self.set_cache(key=cache_key, value=cache_value)

+        if time.time() > self.last_cleaned:
+            asyncio.create_task(self.clean_up_in_memory_cache())
+
     def get_cache(self, key, **kwargs):
         if key in self.cache_dict:
             if key in self.ttl_dict:
@@ -121,8 +134,26 @@ class InMemoryCache(BaseCache):
         init_value = await self.async_get_cache(key=key) or 0
         value = init_value + value
         await self.async_set_cache(key, value, **kwargs)
+
+        if time.time() > self.last_cleaned:
+            asyncio.create_task(self.clean_up_in_memory_cache())
+
         return value

+    async def clean_up_in_memory_cache(self):
+        """
+        Runs periodically to clean up the in-memory cache
+
+        - loop through all keys in cache, check if they are expired
+        - if yes, delete them
+        """
+        for key in list(self.cache_dict.keys()):
+            if key in self.ttl_dict:
+                if time.time() > self.ttl_dict[key]:
+                    self.cache_dict.pop(key, None)
+                    self.ttl_dict.pop(key, None)
+        self.last_cleaned = time.time()
+
     def flush_cache(self):
         self.cache_dict.clear()
         self.ttl_dict.clear()
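The sweep itself is a plain expiry scan over the two dicts: iterate over a copied key list so entries can be popped mid-loop, drop anything whose expiry has passed from both cache_dict and ttl_dict, then record when the sweep ran. A runnable standalone sketch of the same logic (illustrative module-level names, not litellm's API):

import time

cache_dict = {"fresh": "a", "stale": "b"}
ttl_dict = {"fresh": time.time() + 60, "stale": time.time() - 1}  # absolute expiry times
last_cleaned = 0.0

for key in list(cache_dict.keys()):  # list() copies the keys so popping during iteration is safe
    if key in ttl_dict and time.time() > ttl_dict[key]:
        cache_dict.pop(key, None)  # pop with a default tolerates keys already removed
        ttl_dict.pop(key, None)
last_cleaned = time.time()

print(cache_dict)  # {'fresh': 'a'} - the expired entry has been cleaned up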
@@ -147,10 +178,12 @@ class RedisCache(BaseCache):
         namespace: Optional[str] = None,
         **kwargs,
     ):
-        from ._redis import get_redis_client, get_redis_connection_pool
-        from litellm._service_logger import ServiceLogging
         import redis

+        from litellm._service_logger import ServiceLogging
+
+        from ._redis import get_redis_client, get_redis_connection_pool
+
         redis_kwargs = {}
         if host is not None:
             redis_kwargs["host"] = host
@@ -886,11 +919,10 @@ class RedisSemanticCache(BaseCache):

     def get_cache(self, key, **kwargs):
         print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
-        from redisvl.query import VectorQuery
         import numpy as np
+        from redisvl.query import VectorQuery

         # query

         # get the messages
         messages = kwargs["messages"]
         prompt = "".join(message["content"] for message in messages)
@@ -943,7 +975,8 @@ class RedisSemanticCache(BaseCache):

     async def async_set_cache(self, key, value, **kwargs):
         import numpy as np
-        from litellm.proxy.proxy_server import llm_router, llm_model_list
+
+        from litellm.proxy.proxy_server import llm_model_list, llm_router

         try:
             await self.index.acreate(overwrite=False)  # don't overwrite existing index
@@ -998,12 +1031,12 @@ class RedisSemanticCache(BaseCache):

     async def async_get_cache(self, key, **kwargs):
         print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
-        from redisvl.query import VectorQuery
         import numpy as np
-        from litellm.proxy.proxy_server import llm_router, llm_model_list
+        from redisvl.query import VectorQuery
+
+        from litellm.proxy.proxy_server import llm_model_list, llm_router

         # query

         # get the messages
         messages = kwargs["messages"]
         prompt = "".join(message["content"] for message in messages)
@@ -1161,7 +1194,8 @@ class S3Cache(BaseCache):
         self.set_cache(key=key, value=value, **kwargs)

     def get_cache(self, key, **kwargs):
-        import boto3, botocore
+        import boto3
+        import botocore

         try:
             key = self.key_prefix + key
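End to end, the commit's effect can be exercised directly against InMemoryCache. A hedged usage sketch, assuming the class lives in litellm.caching (the changed file's path is not shown above) and that set_cache accepts a ttl keyword, as the existing ttl_dict handling implies:

import asyncio

from litellm.caching import InMemoryCache  # import path assumed; not shown in this diff

async def demo():
    cache = InMemoryCache()
    # ttl kwarg assumed to record an absolute expiry in ttl_dict
    await cache.async_set_cache("greeting", "hello", ttl=1)
    print(await cache.async_get_cache("greeting"))  # 'hello'

    await asyncio.sleep(1.1)                # let the entry expire
    await cache.clean_up_in_memory_cache()  # new in this commit: purges expired keys
    print(cache.get_cache("greeting"))      # None, the entry has been cleaned up

asyncio.run(demo())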