litellm-mirror/litellm/caching/in_memory_cache.py
Krish Dholakia 8ee32291e0
Squashed commit of the following: (#9709)
commit b12a9892b7
Author: Krrish Dholakia <krrishdholakia@gmail.com>
Date:   Wed Apr 2 08:09:56 2025 -0700

    fix(utils.py): don't modify openai_token_counter

commit 294de31803
Author: Krrish Dholakia <krrishdholakia@gmail.com>
Date:   Mon Mar 24 21:22:40 2025 -0700

    fix: fix linting error

commit cb6e9fbe40
Author: Krrish Dholakia <krrishdholakia@gmail.com>
Date:   Mon Mar 24 19:52:45 2025 -0700

    refactor: complete migration

commit bfc159172d
Author: Krrish Dholakia <krrishdholakia@gmail.com>
Date:   Mon Mar 24 19:09:59 2025 -0700

    refactor: refactor more constants

commit 43ffb6a558
Author: Krrish Dholakia <krrishdholakia@gmail.com>
Date:   Mon Mar 24 18:45:24 2025 -0700

    fix: test

commit 04dbe4310c
Author: Krrish Dholakia <krrishdholakia@gmail.com>
Date:   Mon Mar 24 18:28:58 2025 -0700

    refactor: refactor: move more constants into constants.py

commit 3c26284aff
Author: Krrish Dholakia <krrishdholakia@gmail.com>
Date:   Mon Mar 24 18:14:46 2025 -0700

    refactor: migrate hardcoded constants out of __init__.py

commit c11e0de69d
Author: Krrish Dholakia <krrishdholakia@gmail.com>
Date:   Mon Mar 24 18:11:21 2025 -0700

    build: migrate all constants into constants.py

commit 7882bdc787
Author: Krrish Dholakia <krrishdholakia@gmail.com>
Date:   Mon Mar 24 18:07:37 2025 -0700

    build: initial test banning hardcoded numbers in repo
2025-04-02 21:24:54 -07:00

204 lines
7 KiB
Python

"""
In-Memory Cache implementation
Core methods (the class also defines batch, increment, TTL and housekeeping helpers):
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
import json
import sys
import time
from typing import Any, List, Optional
from pydantic import BaseModel
from litellm.constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
from .base_cache import BaseCache
class InMemoryCache(BaseCache):
    """
    Process-local cache backed by two plain dicts.

    ``cache_dict`` maps key -> cached value and ``ttl_dict`` maps key -> the
    absolute ``time.time()`` value at which that entry expires. Expired
    entries are removed lazily: when they are read, and when a write finds
    the cache at or above ``max_size_in_memory`` entries.
    """

    def __init__(
        self,
        max_size_in_memory: Optional[int] = 200,
        default_ttl: Optional[
            int
        ] = 600,  # default ttl is 10 minutes. At maximum litellm rate limiting logic requires objects to be in memory for 1 minute
        max_size_per_item: Optional[int] = 1024,  # 1MB = 1024KB
    ):
        """
        Args:
            max_size_in_memory: Number of entries at which a write triggers
                ``evict_cache``. Falsy values (``None``/``0``) fall back to 200.
            default_ttl: Lifetime in seconds applied when a write does not
                pass ``ttl``. Falsy values fall back to 600.
            max_size_per_item: Largest accepted item size, in KB. Falsy values
                fall back to ``MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB``.
        """
        self.max_size_in_memory = max_size_in_memory or 200
        self.default_ttl = default_ttl or 600
        self.max_size_per_item = (
            max_size_per_item or MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
        )  # 1MB = 1024KB

        # in-memory cache
        self.cache_dict: dict = {}  # key -> cached value
        self.ttl_dict: dict = {}  # key -> absolute expiry timestamp

    def check_value_size(self, value: Any):
        """
        Check whether ``value`` is small enough to be cached.

        Returns True if the measured size is within ``max_size_per_item`` (KB),
        False otherwise. Any exception raised while measuring is treated as
        "too large" and yields False.
        """
        try:
            # Fast path for common primitive types that are typically small
            if (
                isinstance(value, (bool, int, float, str))
                and len(str(value))
                < self.max_size_per_item * MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
            ):  # Conservative estimate
                return True

            # Direct size check for bytes objects
            if isinstance(value, bytes):
                return sys.getsizeof(value) / 1024 <= self.max_size_per_item

            # NOTE(review): every Python object exposes ``__sizeof__``, so this
            # branch always returns and the fallback below is never reached.
            # ``__sizeof__`` is shallow (it does not follow references), so
            # containers are measured without their contents. Left as-is because
            # a deep check would add serialization cost to every write.
            if hasattr(value, "__sizeof__"):  # Use __sizeof__ if available
                size = value.__sizeof__() / 1024
                return size <= self.max_size_per_item

            # Fallback for complex types
            if isinstance(value, BaseModel) and hasattr(
                value, "model_dump"
            ):  # Pydantic v2
                value = value.model_dump()
            elif hasattr(value, "isoformat"):  # datetime objects
                return True  # datetime strings are always small

            # Only convert to JSON if absolutely necessary
            if not isinstance(value, (str, bytes)):
                value = json.dumps(value, default=str)

            return sys.getsizeof(value) / 1024 <= self.max_size_per_item
        except Exception:
            return False

    def evict_cache(self):
        """
        Remove every entry whose expiry time has passed.

        Eviction policy:
        - any key in ttl_dict that is expired is removed from both ttl_dict
          and cache_dict.

        Consequences:
        - 1. When no ttl is given on write, the entry is kept for at least
             ``default_ttl`` seconds.
        - 2. When a ttl is given, the entry is kept for at least that long.
        - 3. Only expired entries are removed, so if nothing has expired the
             cache can hold more than ``max_size_in_memory`` entries.
        """
        now = time.time()
        # Iterate over a snapshot so entries can be removed while looping.
        for key, expires_at in list(self.ttl_dict.items()):
            if now > expires_at:
                # Drop both references so the cached object can be reclaimed;
                # retaining references to unused objects is a common source of
                # memory growth.
                # https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/
                self.cache_dict.pop(key, None)
                self.ttl_dict.pop(key, None)

    def set_cache(self, key, value, **kwargs):
        """
        Store ``value`` under ``key``.

        A non-None ``ttl`` keyword (seconds) sets the lifetime; otherwise
        ``default_ttl`` is used. Values rejected by ``check_value_size`` are
        silently not stored (an existing entry for ``key`` is left untouched).
        """
        if len(self.cache_dict) >= self.max_size_in_memory:
            # only evict when cache is full
            self.evict_cache()
        if not self.check_value_size(value):
            return

        self.cache_dict[key] = value
        ttl = kwargs.get("ttl")
        lifetime = ttl if ttl is not None else self.default_ttl
        self.ttl_dict[key] = time.time() + lifetime

    async def async_set_cache(self, key, value, **kwargs):
        """Async wrapper around ``set_cache``."""
        self.set_cache(key=key, value=value, **kwargs)

    async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
        """
        Store every ``(key, value)`` pair from ``cache_list``.

        ``ttl`` is forwarded to ``set_cache``, which applies ``default_ttl``
        when it is None. Extra keyword arguments are accepted but not used.
        """
        for cache_key, cache_value in cache_list:
            self.set_cache(key=cache_key, value=cache_value, ttl=ttl)

    async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]):
        """
        Add every element of ``value`` to the set stored under ``key``.

        A new set is started when nothing (or a falsy value) is cached for
        ``key``. The updated set is written back with ``ttl``, and the input
        list ``value`` is returned.
        """
        # get the value
        members = self.get_cache(key=key) or set()
        for val in value:
            members.add(val)
        self.set_cache(key, members, ttl=ttl)
        return value

    def get_cache(self, key, **kwargs):
        """
        Return the value cached under ``key``, or None if absent or expired.

        An expired entry is removed from both ``cache_dict`` and ``ttl_dict``.
        The stored value is passed through ``json.loads``; if that fails for
        any reason the stored value is returned unchanged.
        """
        if key not in self.cache_dict:
            return None

        expires_at = self.ttl_dict.get(key)
        if expires_at is not None and time.time() > expires_at:
            self.cache_dict.pop(key, None)
            # Also drop the expiry record so stale entries do not accumulate.
            self.ttl_dict.pop(key, None)
            return None

        original_cached_response = self.cache_dict[key]
        try:
            return json.loads(original_cached_response)
        except Exception:
            return original_cached_response

    def batch_get_cache(self, keys: list, **kwargs):
        """Return ``get_cache`` results for ``keys``, in the same order."""
        return [self.get_cache(key=k, **kwargs) for k in keys]

    def increment_cache(self, key, value: int, **kwargs) -> int:
        """
        Add ``value`` to the number cached under ``key`` and store the result.

        A missing, expired or falsy cached value counts as 0. Keyword
        arguments are forwarded to ``set_cache``. Returns the new total.
        """
        # get the value
        init_value = self.get_cache(key=key) or 0
        value = init_value + value
        self.set_cache(key, value, **kwargs)
        return value

    async def async_get_cache(self, key, **kwargs):
        """Async wrapper around ``get_cache``."""
        return self.get_cache(key=key, **kwargs)

    async def async_batch_get_cache(self, keys: list, **kwargs):
        """Async counterpart of ``batch_get_cache``."""
        return [self.get_cache(key=k, **kwargs) for k in keys]

    async def async_increment(self, key, value: float, **kwargs) -> float:
        """
        Async counterpart of ``increment_cache``.

        A missing, expired or falsy cached value counts as 0. Returns the new
        total after storing it.
        """
        # get the value
        init_value = await self.async_get_cache(key=key) or 0
        value = init_value + value
        await self.async_set_cache(key, value, **kwargs)
        return value

    def flush_cache(self):
        """Remove every entry and every expiry record."""
        self.cache_dict.clear()
        self.ttl_dict.clear()

    async def disconnect(self):
        """No-op."""
        pass

    def delete_cache(self, key):
        """Remove ``key`` and its expiry record; missing keys are ignored."""
        self.cache_dict.pop(key, None)
        self.ttl_dict.pop(key, None)

    async def async_get_ttl(self, key: str) -> Optional[int]:
        """
        Get the remaining TTL of a key in in-memory cache

        Returns the whole number of seconds left before ``key`` expires
        (truncated, so 0 means less than one second remains), or None when
        the key has no expiry record or has already expired.
        """
        expires_at = self.ttl_dict.get(key, None)
        if expires_at is None:
            return None
        # ttl_dict stores absolute expiry timestamps; convert to time left.
        remaining = expires_at - time.time()
        if remaining <= 0:
            return None
        return int(remaining)