fix: support async redis caching

Krrish Dholakia 2024-01-12 21:46:41 +05:30
parent 817a3d29b7
commit 007870390d
6 changed files with 357 additions and 122 deletions

litellm/_redis.py

@@ -11,6 +11,7 @@
 import os
 import inspect
 import redis, litellm
+import redis.asyncio as async_redis
 from typing import List, Optional
@@ -67,7 +68,10 @@ def get_redis_url_from_environment():
     )


-def get_redis_client(**env_overrides):
+def _get_redis_client_logic(**env_overrides):
+    """
+    Common functionality across sync + async redis client implementations
+    """
     ### check if "os.environ/<key-name>" passed in
     for k, v in env_overrides.items():
         if isinstance(v, str) and v.startswith("os.environ/"):
@@ -85,9 +89,21 @@ def get_redis_client(**env_overrides):
         redis_kwargs.pop("port", None)
         redis_kwargs.pop("db", None)
         redis_kwargs.pop("password", None)
-        return redis.Redis.from_url(**redis_kwargs)
     elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
         raise ValueError("Either 'host' or 'url' must be specified for redis.")
     litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
+    return redis_kwargs
+
+
+def get_redis_client(**env_overrides):
+    redis_kwargs = _get_redis_client_logic(**env_overrides)
+    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        return redis.Redis.from_url(**redis_kwargs)
     return redis.Redis(**redis_kwargs)
+
+
+def get_redis_async_client(**env_overrides):
+    redis_kwargs = _get_redis_client_logic(**env_overrides)
+    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        return async_redis.Redis.from_url(**redis_kwargs)
+    return async_redis.Redis(socket_timeout=5, **redis_kwargs)
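
A quick sketch of how the new async factory might be exercised (illustrative only; the connection values are assumptions, not part of this commit):

```python
import asyncio

from litellm._redis import get_redis_async_client


async def main():
    # assumes a Redis server reachable at localhost:6379
    client = get_redis_async_client(host="localhost", port=6379)
    await client.set("greeting", "hello")
    print(await client.get("greeting"))  # b'hello'
    await client.close()


asyncio.run(main())
```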

litellm/caching.py

@@ -26,9 +26,18 @@ class BaseCache:
     def set_cache(self, key, value, **kwargs):
         raise NotImplementedError

+    async def async_set_cache(self, key, value, **kwargs):
+        raise NotImplementedError
+
     def get_cache(self, key, **kwargs):
         raise NotImplementedError

+    async def async_get_cache(self, key, **kwargs):
+        raise NotImplementedError
+
+    async def disconnect(self):
+        raise NotImplementedError
+

 class InMemoryCache(BaseCache):
     def __init__(self):
@@ -41,6 +50,9 @@ class InMemoryCache(BaseCache):
         if "ttl" in kwargs:
             self.ttl_dict[key] = time.time() + kwargs["ttl"]

+    async def async_set_cache(self, key, value, **kwargs):
+        self.set_cache(key=key, value=value, **kwargs)
+
     def get_cache(self, key, **kwargs):
         if key in self.cache_dict:
             if key in self.ttl_dict:
@@ -55,16 +67,21 @@ class InMemoryCache(BaseCache):
                 return cached_response
         return None

+    async def async_get_cache(self, key, **kwargs):
+        return self.get_cache(key=key, **kwargs)
+
     def flush_cache(self):
         self.cache_dict.clear()
         self.ttl_dict.clear()

+    async def disconnect(self):
+        pass
+

 class RedisCache(BaseCache):
-    def __init__(self, host=None, port=None, password=None, **kwargs):
-        import redis
-
-        # if users don't provider one, use the default litellm cache
+    # if users don't provider one, use the default litellm cache
+    def __init__(self, host=None, port=None, password=None, **kwargs):
         from ._redis import get_redis_client

         redis_kwargs = {}
@@ -76,8 +93,13 @@ class RedisCache(BaseCache):
             redis_kwargs["password"] = password

         redis_kwargs.update(kwargs)
         self.redis_client = get_redis_client(**redis_kwargs)
+        self.redis_kwargs = redis_kwargs
+
+    def init_async_client(self):
+        from ._redis import get_redis_async_client
+
+        return get_redis_async_client(**self.redis_kwargs)

     def set_cache(self, key, value, **kwargs):
         ttl = kwargs.get("ttl", None)
@@ -88,6 +110,34 @@ class RedisCache(BaseCache):
             # NON blocking - notify users Redis is throwing an exception
             logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)

+    async def async_set_cache(self, key, value, **kwargs):
+        async with self.init_async_client() as redis_client:
+            ttl = kwargs.get("ttl", None)
+            print_verbose(
+                f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
+            )
+            try:
+                await redis_client.set(name=key, value=str(value), ex=ttl)
+            except Exception as e:
+                # NON blocking - notify users Redis is throwing an exception
+                logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
+
+    def _get_cache_logic(self, cached_response: Any):
+        """
+        Common 'get_cache_logic' across sync + async redis client implementations
+        """
+        if cached_response is None:
+            return cached_response
+        # cached_response is in `b{} convert it to ModelResponse
+        cached_response = cached_response.decode("utf-8")  # Convert bytes to string
+        try:
+            cached_response = json.loads(
+                cached_response
+            )  # Convert string to dictionary
+        except:
+            cached_response = ast.literal_eval(cached_response)
+        return cached_response
+
     def get_cache(self, key, **kwargs):
         try:
             print_verbose(f"Get Redis Cache: key: {key}")
@@ -95,26 +145,33 @@ class RedisCache(BaseCache):
             print_verbose(
                 f"Got Redis Cache: key: {key}, cached_response {cached_response}"
             )
-            if cached_response != None:
-                # cached_response is in `b{} convert it to ModelResponse
-                cached_response = cached_response.decode(
-                    "utf-8"
-                )  # Convert bytes to string
-                try:
-                    cached_response = json.loads(
-                        cached_response
-                    )  # Convert string to dictionary
-                except:
-                    cached_response = ast.literal_eval(cached_response)
-                return cached_response
+            return self._get_cache_logic(cached_response=cached_response)
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
             traceback.print_exc()
             logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)

+    async def async_get_cache(self, key, **kwargs):
+        async with self.init_async_client() as redis_client:
+            try:
+                print_verbose(f"Get Redis Cache: key: {key}")
+                cached_response = await redis_client.get(key)
+                print_verbose(
+                    f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
+                )
+                response = self._get_cache_logic(cached_response=cached_response)
+                return response
+            except Exception as e:
+                # NON blocking - notify users Redis is throwing an exception
+                traceback.print_exc()
+                logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
+
     def flush_cache(self):
         self.redis_client.flushall()

+    async def disconnect(self):
+        pass
+

 class S3Cache(BaseCache):
     def __init__(
@@ -189,6 +246,9 @@ class S3Cache(BaseCache):
             # NON blocking - notify users S3 is throwing an exception
             print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")

+    async def async_set_cache(self, key, value, **kwargs):
+        self.set_cache(key=key, value=value, **kwargs)
+
     def get_cache(self, key, **kwargs):
         import boto3, botocore
@@ -229,6 +289,9 @@ class S3Cache(BaseCache):
             traceback.print_exc()
             print_verbose(f"S3 Caching: get_cache() - Got exception from S3: {e}")

+    async def async_get_cache(self, key, **kwargs):
+        return self.get_cache(key=key, **kwargs)
+
     def flush_cache(self):
         pass
@@ -468,6 +531,45 @@ class Cache:
             }
             time.sleep(0.02)

+    def _get_cache_logic(
+        self,
+        cached_result: Optional[Any],
+        max_age: Optional[float],
+    ):
+        """
+        Common get cache logic across sync + async implementations
+        """
+        # Check if a timestamp was stored with the cached response
+        if (
+            cached_result is not None
+            and isinstance(cached_result, dict)
+            and "timestamp" in cached_result
+        ):
+            timestamp = cached_result["timestamp"]
+            current_time = time.time()
+
+            # Calculate age of the cached response
+            response_age = current_time - timestamp
+
+            # Check if the cached response is older than the max-age
+            if max_age is not None and response_age > max_age:
+                return None  # Cached response is too old
+
+            # If the response is fresh, or there's no max-age requirement, return the cached response
+            # cached_response is in `b{} convert it to ModelResponse
+            cached_response = cached_result.get("response")
+            try:
+                if isinstance(cached_response, dict):
+                    pass
+                else:
+                    cached_response = json.loads(
+                        cached_response  # type: ignore
+                    )  # Convert string to dictionary
+            except:
+                cached_response = ast.literal_eval(cached_response)  # type: ignore
+            return cached_response
+        return cached_result
+
     def get_cache(self, *args, **kwargs):
         """
         Retrieves the cached result for the given arguments.
@@ -490,53 +592,40 @@ class Cache:
                     "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                 )
                 cached_result = self.cache.get_cache(cache_key)
-                # Check if a timestamp was stored with the cached response
-                if (
-                    cached_result is not None
-                    and isinstance(cached_result, dict)
-                    and "timestamp" in cached_result
-                ):
-                    timestamp = cached_result["timestamp"]
-                    current_time = time.time()
-
-                    # Calculate age of the cached response
-                    response_age = current_time - timestamp
-
-                    # Check if the cached response is older than the max-age
-                    if max_age is not None and response_age > max_age:
-                        print_verbose(
-                            f"Cached response for key {cache_key} is too old. Max-age: {max_age}s, Age: {response_age}s"
-                        )
-                        return None  # Cached response is too old
-
-                    # If the response is fresh, or there's no max-age requirement, return the cached response
-                    # cached_response is in `b{} convert it to ModelResponse
-                    cached_response = cached_result.get("response")
-                    try:
-                        if isinstance(cached_response, dict):
-                            pass
-                        else:
-                            cached_response = json.loads(
-                                cached_response
-                            )  # Convert string to dictionary
-                    except:
-                        cached_response = ast.literal_eval(cached_response)
-                    return cached_response
-                return cached_result
+                return self._get_cache_logic(
+                    cached_result=cached_result, max_age=max_age
+                )
         except Exception as e:
             print_verbose(f"An exception occurred: {traceback.format_exc()}")
             return None

-    def add_cache(self, result, *args, **kwargs):
-        """
-        Adds a result to the cache.
-
-        Args:
-            *args: args to litellm.completion() or embedding()
-            **kwargs: kwargs to litellm.completion() or embedding()
-
-        Returns:
-            None
-        """
+    async def async_get_cache(self, *args, **kwargs):
+        """
+        Async get cache implementation.
+
+        Used for embedding calls in async wrapper
+        """
+        try:  # never block execution
+            if "cache_key" in kwargs:
+                cache_key = kwargs["cache_key"]
+            else:
+                cache_key = self.get_cache_key(*args, **kwargs)
+            if cache_key is not None:
+                cache_control_args = kwargs.get("cache", {})
+                max_age = cache_control_args.get(
+                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
+                )
+                cached_result = await self.cache.async_get_cache(cache_key)
+                return self._get_cache_logic(
+                    cached_result=cached_result, max_age=max_age
+                )
+        except Exception as e:
+            print_verbose(f"An exception occurred: {traceback.format_exc()}")
+            return None
+
+    def _add_cache_logic(self, result, *args, **kwargs):
+        """
+        Common implementation across sync + async add_cache functions
+        """
         try:
             if "cache_key" in kwargs:
@@ -555,17 +644,49 @@ class Cache:
                         if k == "ttl":
                             kwargs["ttl"] = v
                 cached_data = {"timestamp": time.time(), "response": result}
-                self.cache.set_cache(cache_key, cached_data, **kwargs)
+                return cache_key, cached_data
+            else:
+                raise Exception("cache key is None")
+        except Exception as e:
+            raise e
+
+    def add_cache(self, result, *args, **kwargs):
+        """
+        Adds a result to the cache.
+
+        Args:
+            *args: args to litellm.completion() or embedding()
+            **kwargs: kwargs to litellm.completion() or embedding()
+
+        Returns:
+            None
+        """
+        try:
+            cache_key, cached_data = self._add_cache_logic(
+                result=result, *args, **kwargs
+            )
+            self.cache.set_cache(cache_key, cached_data, **kwargs)
         except Exception as e:
             print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
             traceback.print_exc()
             pass

-    async def _async_add_cache(self, result, *args, **kwargs):
-        self.add_cache(result, *args, **kwargs)
+    async def async_add_cache(self, result, *args, **kwargs):
+        """
+        Async implementation of add_cache
+        """
+        try:
+            cache_key, cached_data = self._add_cache_logic(
+                result=result, *args, **kwargs
+            )
+            await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
+        except Exception as e:
+            print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
+            traceback.print_exc()

-    async def _async_get_cache(self, *args, **kwargs):
-        return self.get_cache(*args, **kwargs)
+    async def disconnect(self):
+        if hasattr(self.cache, "disconnect"):
+            await self.cache.disconnect()

     def enable_cache(
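
Taken together, a minimal sketch of the new async cache surface (illustrative only; assumes a Redis server on localhost, which is not part of this commit):

```python
import asyncio

import litellm
from litellm.caching import Cache


async def demo():
    litellm.cache = Cache(type="redis", host="localhost", port="6379")
    # the underlying RedisCache handles the async read/write
    await litellm.cache.cache.async_set_cache("my-key", {"hello": "world"})
    print(await litellm.cache.cache.async_get_cache("my-key"))
    # release any connections held by the underlying cache
    await litellm.cache.disconnect()


asyncio.run(demo())
```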

litellm/proxy/proxy_cli.py

@@ -346,7 +346,7 @@ def run_server(
         import gunicorn.app.base
     except:
         raise ImportError(
-            "Uvicorn, gunicorn needs to be imported. Run - `pip 'litellm[proxy]'`"
+            "uvicorn, gunicorn needs to be imported. Run - `pip install 'litellm[proxy]'`"
         )

     if config is not None:
@@ -427,36 +427,40 @@ def run_server(
             f"\033[1;34mSee all Router/Swagger docs on http://0.0.0.0:{port} \033[0m\n"
         )  # noqa

-        # Gunicorn Application Class
-        class StandaloneApplication(gunicorn.app.base.BaseApplication):
-            def __init__(self, app, options=None):
-                self.options = options or {}  # gunicorn options
-                self.application = app  # FastAPI app
-                super().__init__()
-
-            def load_config(self):
-                # note: This Loads the gunicorn config - has nothing to do with LiteLLM Proxy config
-                config = {
-                    key: value
-                    for key, value in self.options.items()
-                    if key in self.cfg.settings and value is not None
-                }
-                for key, value in config.items():
-                    self.cfg.set(key.lower(), value)
-
-            def load(self):
-                # gunicorn app function
-                return self.application
-
-        gunicorn_options = {
-            "bind": f"{host}:{port}",
-            "workers": num_workers,  # default is 1
-            "worker_class": "uvicorn.workers.UvicornWorker",
-            "preload": True,  # Add the preload flag
-        }
-        from litellm.proxy.proxy_server import app
-
-        StandaloneApplication(app=app, options=gunicorn_options).run()  # Run gunicorn
+        uvicorn.run(
+            "litellm.proxy.proxy_server:app", host=host, port=port, workers=num_workers
+        )
+
+        # # Gunicorn Application Class
+        # class StandaloneApplication(gunicorn.app.base.BaseApplication):
+        #     def __init__(self, app, options=None):
+        #         self.options = options or {}  # gunicorn options
+        #         self.application = app  # FastAPI app
+        #         super().__init__()
+
+        #     def load_config(self):
+        #         # note: This Loads the gunicorn config - has nothing to do with LiteLLM Proxy config
+        #         config = {
+        #             key: value
+        #             for key, value in self.options.items()
+        #             if key in self.cfg.settings and value is not None
+        #         }
+        #         for key, value in config.items():
+        #             self.cfg.set(key.lower(), value)

+        #     def load(self):
+        #         # gunicorn app function
+        #         return self.application

+        # gunicorn_options = {
+        #     "bind": f"{host}:{port}",
+        #     "workers": num_workers,  # default is 1
+        #     "worker_class": "uvicorn.workers.UvicornWorker",
+        #     "preload": True,  # Add the preload flag
+        # }
+        # from litellm.proxy.proxy_server import app
+
+        # StandaloneApplication(app=app, options=gunicorn_options).run()  # Run gunicorn

 if __name__ == "__main__":

litellm/proxy/proxy_server.py

@@ -7,6 +7,20 @@ import secrets, subprocess
 import hashlib, uuid
 import warnings
 import importlib
+import warnings
+
+
+def showwarning(message, category, filename, lineno, file=None, line=None):
+    traceback_info = f"{filename}:{lineno}: {category.__name__}: {message}\n"
+    if file is not None:
+        file.write(traceback_info)
+
+
+warnings.showwarning = showwarning
+warnings.filterwarnings("default", category=UserWarning)
+
+# Your client code here

 messages: list = []
 sys.path.insert(
@@ -2510,10 +2524,12 @@ async def get_routes():

 @router.on_event("shutdown")
 async def shutdown_event():
     global prisma_client, master_key, user_custom_auth
-    if prisma_client:
+    if prisma_client is not None:
         verbose_proxy_logger.debug("Disconnecting from Prisma")
         await prisma_client.disconnect()

+    if litellm.cache is not None:
+        await litellm.cache.disconnect()
+
     ## RESET CUSTOM VARIABLES ##
     cleanup_router_config_variables()
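
For reference, the same teardown pattern in a minimal standalone FastAPI app (a sketch, not code from this commit; the app and Redis settings are hypothetical):

```python
import litellm
from fastapi import FastAPI
from litellm.caching import Cache

app = FastAPI()


@app.on_event("startup")
async def startup_event():
    # assumes a Redis server on localhost:6379
    litellm.cache = Cache(type="redis", host="localhost", port="6379")


@app.on_event("shutdown")
async def shutdown_event():
    # mirror the proxy: release cache connections alongside other resources
    if litellm.cache is not None:
        await litellm.cache.disconnect()
```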

litellm/tests/test_caching.py

@@ -266,8 +266,9 @@ async def test_embedding_caching_azure_individual_items():
     """
     Tests caching for individual items in an embedding list
-    Assert if the same embeddingresponse object is returned for the duplicate item in 2 embedding list calls
+    - Cache an item
+    - call aembedding(..) with the item + 1 unique item
+    - compare to a 2nd aembedding (...) with 2 unique items
     ```
     embedding_1 = ["hey how's it going", "I'm doing well"]
     embedding_val_1 = embedding(...)
@@ -280,31 +281,98 @@ async def test_embedding_caching_azure_individual_items():
     """
     litellm.cache = Cache()
     common_msg = f"hey how's it going {uuid.uuid4()}"
-    embedding_1 = [common_msg, "I'm doing well"]
-    embedding_2 = [common_msg, "I'm fine"]
+    common_msg_2 = f"hey how's it going {uuid.uuid4()}"
+    embedding_1 = [
+        common_msg,
+        f"I'm fine {uuid.uuid4()}",
+        common_msg,
+        common_msg,
+        common_msg,
+    ] * 20
+    embedding_2 = [
+        common_msg,
+        f"I'm fine {uuid.uuid4()}",
+        common_msg,
+        common_msg,
+        common_msg,
+    ] * 20
+    embedding_3 = [
+        common_msg_2,
+        common_msg_2,
+        common_msg_2,
+        common_msg_2,
+        f"I'm fine {uuid.uuid4()}",
+    ] * 20  # make sure azure doesn't return cached 'i'm fine' responses

     embedding_val_1 = await aembedding(
         model="azure/azure-embedding-model", input=embedding_1, caching=True
     )

+    second_response_start_time = time.time()
     embedding_val_2 = await aembedding(
         model="azure/azure-embedding-model", input=embedding_2, caching=True
     )
-    print(f"embedding_val_2: {embedding_val_2}")
-    if (
-        embedding_val_2["data"][0]["embedding"]
-        != embedding_val_1["data"][0]["embedding"]
-    ):
-        print(f"embedding1: {embedding_val_1}")
-        print(f"embedding2: {embedding_val_2}")
-        pytest.fail("Error occurred: Embedding caching failed")
-    if (
-        embedding_val_2["data"][1]["embedding"]
-        == embedding_val_1["data"][1]["embedding"]
-    ):
-        print(f"embedding1: {embedding_val_1}")
-        print(f"embedding2: {embedding_val_2}")
-        pytest.fail("Error occurred: Embedding caching failed")
+    if embedding_val_2 is not None:
+        second_response_end_time = time.time()
+        second_response_time = second_response_end_time - second_response_start_time
+
+    third_response_start_time = time.time()
+    embedding_val_3 = await aembedding(
+        model="azure/azure-embedding-model", input=embedding_3, cache={"no-cache": True}
+    )
+    if embedding_val_3 is not None:
+        third_response_end_time = time.time()
+        third_response_time = third_response_end_time - third_response_start_time
+
+    print(f"second_response_time: {second_response_time}")
+    print(f"third_response_time: {third_response_time}")
+
+    assert (
+        second_response_time < third_response_time - 0.5
+    )  # make sure it's actually faster
+
+    raise Exception(f"it works {second_response_time} < {third_response_time}")
+
+
+@pytest.mark.asyncio
+async def test_redis_cache_basic():
+    """
+    Init redis client
+    - write to client
+    - read from client
+    """
+    litellm.set_verbose = False
+
+    random_number = random.randint(
+        1, 100000
+    )  # add a random number to ensure it's always adding / reading from cache
+    messages = [
+        {"role": "user", "content": f"write a one sentence poem about: {random_number}"}
+    ]
+    litellm.cache = Cache(
+        type="redis",
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+        password=os.environ["REDIS_PASSWORD"],
+    )
+    response1 = completion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+    )
+
+    cache_key = litellm.cache.get_cache_key(
+        model="gpt-3.5-turbo",
+        messages=messages,
+    )
+    print(f"cache_key: {cache_key}")
+    litellm.cache.add_cache(result=response1, cache_key=cache_key)
+    print(f"cache key pre async get: {cache_key}")
+    stored_val = await litellm.cache.async_get_cache(
+        model="gpt-3.5-turbo",
+        messages=messages,
+    )
+    print(f"stored_val: {stored_val}")
+    assert stored_val["id"] == response1.id
+    raise Exception("it worked!")


 def test_redis_cache_completion():

litellm/utils.py

@@ -2214,8 +2214,13 @@ def client(original_function):
                 )
                 # if caching is false, don't run this
                 final_embedding_cached_response = None
                 if (
-                    (kwargs.get("caching", None) is None and litellm.cache is not None)
+                    (
+                        kwargs.get("caching", None) is None
+                        and kwargs.get("cache", None) is None
+                        and litellm.cache is not None
+                    )
                     or kwargs.get("caching", False) == True
                     or (
                         kwargs.get("cache", None) is not None
@@ -2234,12 +2239,13 @@ def client(original_function):
                         kwargs["input"], list
                     ):
                         tasks = []
-                        embedding_kwargs = copy.deepcopy(kwargs)
                         for idx, i in enumerate(kwargs["input"]):
-                            embedding_kwargs["input"] = i
+                            preset_cache_key = litellm.cache.get_cache_key(
+                                *args, **{**kwargs, "input": i}
+                            )
                             tasks.append(
-                                litellm.cache._async_get_cache(
-                                    *args, **embedding_kwargs
+                                litellm.cache.async_get_cache(
+                                    cache_key=preset_cache_key
                                 )
                             )
                         cached_result = await asyncio.gather(*tasks)
@@ -2445,24 +2451,28 @@ def client(original_function):
                         if isinstance(result, EmbeddingResponse) and isinstance(
                             kwargs["input"], list
                         ):
-                            embedding_kwargs = copy.deepcopy(kwargs)
                             for idx, i in enumerate(kwargs["input"]):
+                                preset_cache_key = litellm.cache.get_cache_key(
+                                    *args, **{**kwargs, "input": i}
+                                )
                                 embedding_response = result.data[idx]
-                                embedding_kwargs["input"] = i
                                 asyncio.create_task(
-                                    litellm.cache._async_add_cache(
-                                        embedding_response, *args, **embedding_kwargs
+                                    litellm.cache.async_add_cache(
+                                        embedding_response,
+                                        *args,
+                                        cache_key=preset_cache_key,
                                     )
                                 )
+                                # pass
                         else:
                             asyncio.create_task(
-                                litellm.cache._async_add_cache(
+                                litellm.cache.async_add_cache(
                                     result.json(), *args, **kwargs
                                 )
                             )
                     else:
                         asyncio.create_task(
-                            litellm.cache._async_add_cache(result, *args, **kwargs)
+                            litellm.cache.async_add_cache(result, *args, **kwargs)
                         )

                 # LOG SUCCESS - handle streaming success logging in the _next_ object
                 print_verbose(
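
One way to see why the wrapper now computes a per-item `preset_cache_key`: each element of an embedding batch is cached under its own key, so a repeated item can hit the cache even when the batches around it differ. A small sketch of that key derivation (illustrative only; the model name and inputs are placeholders, using the default in-memory cache):

```python
import litellm
from litellm.caching import Cache

litellm.cache = Cache()  # default in-memory cache

kwargs = {"model": "azure/azure-embedding-model", "input": ["alpha", "beta", "alpha"]}
for item in kwargs["input"]:
    # same per-item key derivation the async wrapper uses
    key = litellm.cache.get_cache_key(**{**kwargs, "input": item})
    print(item, "->", key)  # both "alpha" items map to the same key
```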