forked from phoenix/litellm-mirror
fix(caching.py): use bulk writes and a BlockingConnectionPool for reads from Redis
parent 007870390d
commit 01df37d8cf

3 changed files with 109 additions and 19 deletions
litellm/_redis.py

@@ -106,4 +106,16 @@ def get_redis_async_client(**env_overrides):
     redis_kwargs = _get_redis_client_logic(**env_overrides)
     if "url" in redis_kwargs and redis_kwargs["url"] is not None:
         return async_redis.Redis.from_url(**redis_kwargs)
-    return async_redis.Redis(socket_timeout=5, **redis_kwargs)
+    return async_redis.Redis(
+        socket_timeout=5,
+        **redis_kwargs,
+    )
+
+
+def get_redis_connection_pool(**env_overrides):
+    redis_kwargs = _get_redis_client_logic(**env_overrides)
+    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        return async_redis.BlockingConnectionPool.from_url(
+            timeout=5, url=redis_kwargs["url"]
+        )
+    return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)
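
A minimal, self-contained sketch of the pattern this hunk introduces, assuming redis-py's asyncio API (`redis.asyncio`) and a local Redis server; the host/port and `max_connections` values are illustrative placeholders, not part of the diff:

```python
import asyncio

import redis.asyncio as async_redis


async def main():
    # A BlockingConnectionPool caps the number of open sockets; when all
    # connections are in use, callers block up to `timeout` seconds for a
    # free one instead of opening more.
    pool = async_redis.BlockingConnectionPool(
        host="localhost", port=6379, timeout=5, max_connections=50
    )
    client = async_redis.Redis(connection_pool=pool)
    await client.set("greeting", "hello")
    print(await client.get("greeting"))  # b"hello"
    await pool.disconnect()  # tear down all pooled connections


asyncio.run(main())
```

Creating the pool once and handing it to every client is what keeps concurrent cache reads from exhausting Redis connections.
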
litellm/caching.py

@@ -8,7 +8,7 @@
 # Thank you users! We ❤️ you! - Krrish & Ishaan

 import litellm
-import time, logging
+import time, logging, asyncio
 import json, traceback, ast, hashlib
 from typing import Optional, Literal, List, Union, Any
 from openai._models import BaseModel as OpenAIObject
@@ -82,7 +82,7 @@ class RedisCache(BaseCache):
     # if users don't provider one, use the default litellm cache

     def __init__(self, host=None, port=None, password=None, **kwargs):
-        from ._redis import get_redis_client
+        from ._redis import get_redis_client, get_redis_connection_pool

         redis_kwargs = {}
         if host is not None:
@@ -95,11 +95,20 @@ class RedisCache(BaseCache):
         redis_kwargs.update(kwargs)
         self.redis_client = get_redis_client(**redis_kwargs)
         self.redis_kwargs = redis_kwargs
+        self.async_redis_conn_pool = get_redis_connection_pool()
+        print_verbose(
+            f"Number of available connections init: {self.async_redis_conn_pool.pool.qsize()}"
+        )

     def init_async_client(self):
         from ._redis import get_redis_async_client

-        return get_redis_async_client(**self.redis_kwargs)
+        print_verbose(
+            f"Number of available connections client_init: {self.async_redis_conn_pool.pool.qsize()}"
+        )
+        return get_redis_async_client(
+            connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
+        )

     def set_cache(self, key, value, **kwargs):
         ttl = kwargs.get("ttl", None)
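
The `pool.qsize()` logging above reads the pool's internal `asyncio` queue of idle connections. A small sketch of that introspection; note that `pool.pool` is a private attribute of `BlockingConnectionPool` on redis-py 4.x (later releases changed the internals), so treat this as debug-only:

```python
import redis.asyncio as async_redis

pool = async_redis.BlockingConnectionPool(max_connections=10, timeout=5)

# The internal LifoQueue is pre-filled with one slot per connection,
# so qsize() reports how many connections are currently available.
print(f"available connections: {pool.pool.qsize()}")  # 10 before any use
```
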
@@ -111,16 +120,52 @@ class RedisCache(BaseCache):
             logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)

     async def async_set_cache(self, key, value, **kwargs):
-        async with self.init_async_client() as redis_client:
+        _redis_client = self.init_async_client()
+        async with _redis_client as redis_client:
             ttl = kwargs.get("ttl", None)
             print_verbose(
                 f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
             )
             try:
-                await redis_client.set(name=key, value=str(value), ex=ttl)
+                await redis_client.set(name=key, value=json.dumps(value), ex=ttl)
             except Exception as e:
                 # NON blocking - notify users Redis is throwing an exception
                 logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
+        print_verbose(
+            f"Number of available connections set_cache complete: {self.async_redis_conn_pool.pool.qsize()}"
+        )
+
+    async def async_set_cache_pipeline(self, cache_list, ttl=None):
+        """
+        Use Redis Pipelines for bulk write operations
+        """
+        _redis_client = self.init_async_client()
+        try:
+            async with _redis_client as redis_client:
+                async with redis_client.pipeline(transaction=True) as pipe:
+                    # Iterate through each key-value pair in the cache_list and set them in the pipeline.
+                    for cache_key, cache_value in cache_list:
+                        print_verbose(
+                            f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
+                        )
+                        # Set the value with a TTL if it's provided.
+                        if ttl is not None:
+                            pipe.setex(cache_key, ttl, json.dumps(cache_value))
+                        else:
+                            pipe.set(cache_key, json.dumps(cache_value))
+                    # Execute the pipeline and return the results.
+                    results = await pipe.execute()
+            print_verbose(
+                f"Number of available connections set_cache complete: {self.async_redis_conn_pool.pool.qsize()}"
+            )
+
+            print_verbose(f"pipeline results: {results}")
+            # Optionally, you could process 'results' to make sure that all set operations were successful.
+            return results
+        except Exception as e:
+            print_verbose(f"Error occurred in pipeline write - {str(e)}")
+            # NON blocking - notify users Redis is throwing an exception
+            logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)

     def _get_cache_logic(self, cached_response: Any):
         """
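
As a standalone illustration of the bulk-write pattern `async_set_cache_pipeline` relies on, a sketch using redis-py's asyncio pipeline (connection details are placeholders; assumes a local Redis server):

```python
import asyncio
import json

import redis.asyncio as async_redis


async def bulk_set(pairs, ttl=None):
    client = async_redis.Redis(host="localhost", port=6379)
    async with client.pipeline(transaction=True) as pipe:
        # Commands are buffered locally; nothing is sent to Redis yet.
        for key, value in pairs:
            if ttl is not None:
                pipe.setex(key, ttl, json.dumps(value))
            else:
                pipe.set(key, json.dumps(value))
        # A single MULTI/EXEC round trip sends the whole batch.
        results = await pipe.execute()
    await client.close()  # aclose() on redis-py >= 5
    return results


print(asyncio.run(bulk_set([("k1", {"a": 1}), ("k2", {"b": 2})], ttl=60)))
```

Compared with awaiting each `set` individually, the pipeline costs one round trip for the whole batch, which is what lets a full embedding batch be cached without holding many connections open.
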
@@ -152,7 +197,8 @@ class RedisCache(BaseCache):
             logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)

     async def async_get_cache(self, key, **kwargs):
-        async with self.init_async_client() as redis_client:
+        _redis_client = self.init_async_client()
+        async with _redis_client as redis_client:
             try:
                 print_verbose(f"Get Redis Cache: key: {key}")
                 cached_response = await redis_client.get(key)
@@ -166,6 +212,10 @@ class RedisCache(BaseCache):
                 traceback.print_exc()
                 logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)

+        print_verbose(
+            f"Number of available connections get_cache complete: {self.async_redis_conn_pool.pool.qsize()}"
+        )
+
     def flush_cache(self):
         self.redis_client.flushall()

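
Both read paths wrap the client in `async with`, which releases the client's connection back to the shared pool when the block exits. A sketch of that usage (placeholder connection details, local Redis assumed):

```python
import asyncio

import redis.asyncio as async_redis

pool = async_redis.BlockingConnectionPool(host="localhost", port=6379, timeout=5)


async def get_cached(key):
    # Redis clients are async context managers: on exit the connection
    # goes back to the pool; the pool itself stays alive for reuse.
    async with async_redis.Redis(connection_pool=pool) as redis_client:
        return await redis_client.get(key)  # None on a cache miss


print(asyncio.run(get_cached("greeting")))
```
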
@@ -684,6 +734,39 @@ class Cache:
             print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
             traceback.print_exc()

+    async def async_add_cache_pipeline(self, result, *args, **kwargs):
+        """
+        Async implementation of add_cache for Embedding calls
+
+        Does a bulk write, to prevent using too many clients
+        """
+        try:
+            cache_list = []
+            for idx, i in enumerate(kwargs["input"]):
+                preset_cache_key = litellm.cache.get_cache_key(
+                    *args, **{**kwargs, "input": i}
+                )
+                embedding_response = result.data[idx]
+                cache_key, cached_data = self._add_cache_logic(
+                    result=embedding_response,
+                    cache_key=preset_cache_key,
+                    *args,
+                    **kwargs,
+                )
+                cache_list.append((cache_key, cached_data))
+            if hasattr(self.cache, "async_set_cache_pipeline"):
+                await self.cache.async_set_cache_pipeline(cache_list=cache_list)
+            else:
+                tasks = []
+                for val in cache_list:
+                    tasks.append(
+                        self.cache.async_set_cache(cache_key, cached_data, **kwargs)
+                    )
+                await asyncio.gather(*tasks)
+        except Exception as e:
+            print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
+            traceback.print_exc()
+
     async def disconnect(self):
         if hasattr(self.cache, "disconnect"):
             await self.cache.disconnect()
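
The `else` branch above fans the writes out as one task per pair. A self-contained sketch of that fan-out pattern; note the sketch passes the loop variable to each call, whereas the diff reuses `cache_key`/`cached_data` from the enclosing scope (so every task would write the last pair):

```python
import asyncio


async def set_one(key, value):
    # Stand-in for RedisCache.async_set_cache.
    await asyncio.sleep(0)  # simulate the network round trip
    print(f"wrote {key}={value}")


async def fan_out(cache_list):
    # One coroutine per pair; gather runs them concurrently and
    # waits for all of them to finish.
    tasks = [set_one(key, value) for key, value in cache_list]
    await asyncio.gather(*tasks)


asyncio.run(fan_out([("k1", 1), ("k2", 2)]))
```
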
litellm/utils.py

@@ -2346,6 +2346,9 @@ def client(original_function):
                        kwargs["input"] = remaining_list

                        if len(non_null_list) > 0:
+                            print_verbose(
+                                f"EMBEDDING CACHE HIT! - {len(non_null_list)}"
+                            )
                            final_embedding_cached_response = EmbeddingResponse(
                                model=kwargs.get("model"),
                                data=[None] * len(original_kwargs_input),
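
To make the partial-hit flow concrete, a toy sketch (names and data are illustrative, not litellm's): cached embeddings keep their original positions in `data`, and only the misses go back out as the remaining input:

```python
cached = {"apple": [0.1, 0.2], "cherry": [0.3, 0.4]}  # pretend cache contents
inputs = ["apple", "banana", "cherry"]

data = [cached.get(text) for text in inputs]  # None marks a cache miss
remaining_list = [t for t, d in zip(inputs, data) if d is None]

print(remaining_list)  # ['banana'] -> the only input sent to the embedding API
print(data)            # [[0.1, 0.2], None, [0.3, 0.4]]
```
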
@@ -2451,19 +2454,11 @@ def client(original_function):
                    if isinstance(result, EmbeddingResponse) and isinstance(
                        kwargs["input"], list
                    ):
-                        for idx, i in enumerate(kwargs["input"]):
-                            preset_cache_key = litellm.cache.get_cache_key(
-                                *args, **{**kwargs, "input": i}
-                            )
-                            embedding_response = result.data[idx]
-                            asyncio.create_task(
-                                litellm.cache.async_add_cache(
-                                    embedding_response,
-                                    *args,
-                                    cache_key=preset_cache_key,
-                                )
-                            )
-                            # pass
+                        asyncio.create_task(
+                            litellm.cache.async_add_cache_pipeline(
+                                result, *args, **kwargs
+                            )
+                        )
                    else:
                        asyncio.create_task(
                            litellm.cache.async_add_cache(
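
The write is scheduled with `asyncio.create_task` so the caller gets its response back without waiting on Redis. A minimal sketch of that fire-and-forget pattern (in a long-lived server the loop keeps running; the final `await` here only keeps the short demo alive long enough for the task to finish):

```python
import asyncio


async def write_cache(result):
    await asyncio.sleep(0.1)  # stand-in for the Redis pipeline write
    print(f"cached: {result}")


async def handler():
    # Schedule the cache write without awaiting it: the response is
    # returned immediately while the write runs in the background.
    task = asyncio.create_task(write_cache("embedding-response"))
    print("response returned to caller")
    await task  # demo only: prevent the loop from closing early


asyncio.run(handler())
```
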