# +-----------------------------------------------+
# |                                               |
# |           Give Feedback / Get Help            |
# | https://github.com/BerriAI/litellm/issues/new |
# |                                               |
# +-----------------------------------------------+
#
#  Thank you users! We ❤️ you! - Krrish & Ishaan

import ast
import hashlib
import json
import time
import traceback
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
from litellm.types.caching import *
from litellm.types.utils import all_litellm_params

from .base_cache import BaseCache
from .disk_cache import DiskCache
from .dual_cache import DualCache  # noqa
from .in_memory_cache import InMemoryCache
from .qdrant_semantic_cache import QdrantSemanticCache
from .redis_cache import RedisCache
from .redis_cluster_cache import RedisClusterCache
from .redis_semantic_cache import RedisSemanticCache
from .s3_cache import S3Cache


def print_verbose(print_statement):
    try:
        verbose_logger.debug(print_statement)
        if litellm.set_verbose:
            print(print_statement)  # noqa
    except Exception:
        pass


class CacheMode(str, Enum):
    default_on = "default_on"
    default_off = "default_off"


#### LiteLLM.Completion / Embedding Cache ####
class Cache:
    def __init__(
        self,
        type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
        mode: Optional[
            CacheMode
        ] = CacheMode.default_on,  # when default_on, the cache is always on; when default_off, caching is opt-in per request
        host: Optional[str] = None,
        port: Optional[str] = None,
        password: Optional[str] = None,
        namespace: Optional[str] = None,
        ttl: Optional[float] = None,
        default_in_memory_ttl: Optional[float] = None,
        default_in_redis_ttl: Optional[float] = None,
        similarity_threshold: Optional[float] = None,
        supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
            "completion",
            "acompletion",
            "embedding",
            "aembedding",
            "atranscription",
            "transcription",
            "atext_completion",
            "text_completion",
            "arerank",
            "rerank",
        ],
        # s3 Bucket, boto3 configuration
        s3_bucket_name: Optional[str] = None,
        s3_region_name: Optional[str] = None,
        s3_api_version: Optional[str] = None,
        s3_use_ssl: Optional[bool] = True,
        s3_verify: Optional[Union[bool, str]] = None,
        s3_endpoint_url: Optional[str] = None,
        s3_aws_access_key_id: Optional[str] = None,
        s3_aws_secret_access_key: Optional[str] = None,
        s3_aws_session_token: Optional[str] = None,
        s3_config: Optional[Any] = None,
        s3_path: Optional[str] = None,
        redis_semantic_cache_use_async=False,
        redis_semantic_cache_embedding_model="text-embedding-ada-002",
        redis_flush_size: Optional[int] = None,
        redis_startup_nodes: Optional[List] = None,
        disk_cache_dir=None,
        qdrant_api_base: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,
        qdrant_collection_name: Optional[str] = None,
        qdrant_quantization_config: Optional[str] = None,
        qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
        **kwargs,
    ):
        """
        Initializes the cache based on the given type.

        Args:
            type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".

            # Redis Cache Args
            host (str, optional): The host address for the Redis cache. Required if type is "redis".
            port (str, optional): The port number for the Redis cache. Required if type is "redis".
            password (str, optional): The password for the Redis cache. Required if type is "redis".
            namespace (str, optional): The namespace to prefix Redis cache keys with. Defaults to None.
            ttl (float, optional): The TTL (in seconds) for Redis cache entries.
            redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
            redis_startup_nodes (list, optional): The list of startup nodes for a Redis cluster. Defaults to None.

            # Qdrant Cache Args
            qdrant_api_base (str, optional): The URL of your Qdrant cluster. Required if type is "qdrant-semantic".
            qdrant_api_key (str, optional): The api_key for the local or cloud Qdrant cluster.
            qdrant_collection_name (str, optional): The name of your Qdrant collection. Required if type is "qdrant-semantic".
            similarity_threshold (float, optional): The similarity threshold for semantic caching. Required if type is "redis-semantic" or "qdrant-semantic".

            # Disk Cache Args
            disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.

            # S3 Cache Args
            s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
            s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
            s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
            s3_use_ssl (bool, optional): Whether to use SSL for the s3 cache. Defaults to True.
            s3_verify (bool | str, optional): Whether (or which CA bundle) to use for SSL verification for the s3 cache. Defaults to None.
            s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
            s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
            s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
            s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
            s3_config (dict, optional): The boto3 config for the s3 cache. Defaults to None.

            # Common Cache Args
            supported_call_types (list, optional): List of call types to cache for. Defaults to caching for all supported call types.
            **kwargs: Additional keyword arguments for the redis.Redis() cache

        Raises:
            ValueError: If an invalid cache type is provided.

        Returns:
            None. Cache is set as a litellm param
        """
        if type == LiteLLMCacheType.REDIS:
            if redis_startup_nodes:
                self.cache: BaseCache = RedisClusterCache(
                    host=host,
                    port=port,
                    password=password,
                    redis_flush_size=redis_flush_size,
                    startup_nodes=redis_startup_nodes,
                    **kwargs,
                )
            else:
                self.cache = RedisCache(
                    host=host,
                    port=port,
                    password=password,
                    redis_flush_size=redis_flush_size,
                    **kwargs,
                )
        elif type == LiteLLMCacheType.REDIS_SEMANTIC:
            self.cache = RedisSemanticCache(
                host=host,
                port=port,
                password=password,
                similarity_threshold=similarity_threshold,
                use_async=redis_semantic_cache_use_async,
                embedding_model=redis_semantic_cache_embedding_model,
                **kwargs,
            )
        elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
            self.cache = QdrantSemanticCache(
                qdrant_api_base=qdrant_api_base,
                qdrant_api_key=qdrant_api_key,
                collection_name=qdrant_collection_name,
                similarity_threshold=similarity_threshold,
                quantization_config=qdrant_quantization_config,
                embedding_model=qdrant_semantic_cache_embedding_model,
            )
        elif type == LiteLLMCacheType.LOCAL:
            self.cache = InMemoryCache()
        elif type == LiteLLMCacheType.S3:
            self.cache = S3Cache(
                s3_bucket_name=s3_bucket_name,
                s3_region_name=s3_region_name,
                s3_api_version=s3_api_version,
                s3_use_ssl=s3_use_ssl,
                s3_verify=s3_verify,
                s3_endpoint_url=s3_endpoint_url,
                s3_aws_access_key_id=s3_aws_access_key_id,
                s3_aws_secret_access_key=s3_aws_secret_access_key,
                s3_aws_session_token=s3_aws_session_token,
                s3_config=s3_config,
                s3_path=s3_path,
                **kwargs,
            )
        elif type == LiteLLMCacheType.DISK:
            self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
        if "cache" not in litellm.input_callback:
            litellm.input_callback.append("cache")
        if "cache" not in litellm.success_callback:
            litellm.logging_callback_manager.add_litellm_success_callback("cache")
        if "cache" not in litellm._async_success_callback:
            litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
        self.supported_call_types = supported_call_types  # default to ["completion", "acompletion", "embedding", "aembedding"]
        self.type = type
        self.namespace = namespace
        self.redis_flush_size = redis_flush_size
        self.ttl = ttl
        self.mode: CacheMode = mode or CacheMode.default_on

        if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
            self.ttl = default_in_memory_ttl

        if (
            self.type == LiteLLMCacheType.REDIS
            or self.type == LiteLLMCacheType.REDIS_SEMANTIC
        ) and default_in_redis_ttl is not None:
            self.ttl = default_in_redis_ttl

        if self.namespace is not None and isinstance(self.cache, RedisCache):
            self.cache.namespace = self.namespace

    def get_cache_key(self, **kwargs) -> str:
        """
        Get the cache key for the given arguments.

        Args:
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            str: The cache key generated from the arguments.
        """
        cache_key = ""
        # verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)

        preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
        if preset_cache_key is not None:
            verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
            return preset_cache_key

        combined_kwargs = ModelParamHelper._get_all_llm_api_params()
        litellm_param_kwargs = all_litellm_params
        for param in kwargs:
            if param in combined_kwargs:
                param_value: Optional[str] = self._get_param_value(param, kwargs)
                if param_value is not None:
                    cache_key += f"{str(param)}: {str(param_value)}"
            elif (
                param not in litellm_param_kwargs
            ):  # check if user passed in optional param - e.g. top_k
                if (
                    litellm.enable_caching_on_provider_specific_optional_params is True
                ):  # feature flagged for now
                    if kwargs[param] is None:
                        continue  # ignore None params
                    param_value = kwargs[param]
                    cache_key += f"{str(param)}: {str(param_value)}"

        verbose_logger.debug("\nCreated cache key: %s", cache_key)
        hashed_cache_key = Cache._get_hashed_cache_key(cache_key)
        hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs)
        self._set_preset_cache_key_in_kwargs(
            preset_cache_key=hashed_cache_key, **kwargs
        )
        return hashed_cache_key
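
    # Illustrative sketch (not executed): the cache key is a concatenation of the
    # recognized LLM API params and their values, which is then SHA-256 hashed and
    # optionally namespaced. Two calls with identical params map to the same key:
    #
    #   cache = Cache(type=LiteLLMCacheType.LOCAL)
    #   key_1 = cache.get_cache_key(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}])
    #   key_2 = cache.get_cache_key(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}])
    #   assert key_1 == key_2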

    def _get_param_value(
        self,
        param: str,
        kwargs: dict,
    ) -> Optional[str]:
        """
        Get the value for the given param from kwargs
        """
        if param == "model":
            return self._get_model_param_value(kwargs)
        elif param == "file":
            return self._get_file_param_value(kwargs)
        return kwargs[param]

    def _get_model_param_value(self, kwargs: dict) -> str:
        """
        Handles getting the value for the 'model' param from kwargs

        1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups
        2. Else, if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router()
        3. Else, use the `model` passed in kwargs
        """
        metadata: Dict = kwargs.get("metadata", {}) or {}
        litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
        metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {}
        model_group: Optional[str] = metadata.get(
            "model_group"
        ) or metadata_in_litellm_params.get("model_group")
        caching_group = self._get_caching_group(metadata, model_group)
        return caching_group or model_group or kwargs["model"]

    def _get_caching_group(
        self, metadata: dict, model_group: Optional[str]
    ) -> Optional[str]:
        caching_groups: Optional[List] = metadata.get("caching_groups", [])
        if caching_groups:
            for group in caching_groups:
                if model_group in group:
                    return str(group)
        return None
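
    # Illustrative sketch: with caching groups configured on a Router, the cache key
    # uses the group instead of the individual deployment, so hits are shared across
    # the group (model names below are hypothetical):
    #
    #   metadata = {"model_group": "gpt-4-eu", "caching_groups": [("gpt-4-eu", "gpt-4-us")]}
    #   # _get_model_param_value(...) -> "('gpt-4-eu', 'gpt-4-us')" for either deployment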

    def _get_file_param_value(self, kwargs: dict) -> str:
        """
        Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
        """
        file = kwargs.get("file")
        metadata = kwargs.get("metadata", {})
        litellm_params = kwargs.get("litellm_params", {})
        return (
            metadata.get("file_checksum")
            or getattr(file, "name", None)
            or metadata.get("file_name")
            or litellm_params.get("file_name")
        )

    def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
        """
        Get the preset cache key from kwargs["litellm_params"]

        We store and reuse a preset cache key for two reasons:

        1. optional params like max_tokens get transformed per provider (e.g. for bedrock -> max_new_tokens)
        2. to avoid doing duplicate / repeated work
        """
        if kwargs:
            if "litellm_params" in kwargs:
                return kwargs["litellm_params"].get("preset_cache_key", None)
        return None

    def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
        """
        Set the calculated cache key in kwargs

        This is used to avoid doing duplicate / repeated work

        Placed in kwargs["litellm_params"]
        """
        if kwargs:
            if "litellm_params" in kwargs:
                kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key

    @staticmethod
    def _get_hashed_cache_key(cache_key: str) -> str:
        """
        Get the hashed cache key for the given cache key.

        Uses hashlib to create a SHA-256 hash of the cache key.

        Args:
            cache_key (str): The cache key to hash.

        Returns:
            str: The hashed cache key.
        """
        hash_object = hashlib.sha256(cache_key.encode())
        # Hexadecimal representation of the hash
        hash_hex = hash_object.hexdigest()
        verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
        return hash_hex
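
    # Illustrative sketch: the hash is a plain SHA-256 hex digest of the concatenated
    # "param: value" string, e.g.
    #
    #   hashlib.sha256("model: gpt-3.5-turbo".encode()).hexdigest()
    #   # -> 64-character hex string, used as the storage key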

    def _add_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
        """
        If a redis namespace is provided, add it to the cache key

        Args:
            hash_hex (str): The hashed cache key.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The final hashed cache key with the redis namespace.
        """
        dynamic_cache_control: DynamicCacheControl = kwargs.get("cache", {})
        namespace = (
            dynamic_cache_control.get("namespace")
            or kwargs.get("metadata", {}).get("redis_namespace")
            or self.namespace
        )
        if namespace:
            hash_hex = f"{namespace}:{hash_hex}"
        verbose_logger.debug("Final hashed key: %s", hash_hex)
        return hash_hex
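
    # Illustrative sketch: a namespace simply prefixes the hashed key, so different
    # tenants or environments don't collide, e.g.
    #
    #   completion(..., cache={"namespace": "team-a"})
    #   # stored key -> "team-a:<sha256 hex of the param string>"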

    def generate_streaming_content(self, content):
        chunk_size = 5  # Adjust the chunk size as needed
        for i in range(0, len(content), chunk_size):
            yield {
                "choices": [
                    {
                        "delta": {
                            "role": "assistant",
                            "content": content[i : i + chunk_size],
                        }
                    }
                ]
            }
            time.sleep(0.02)

    def _get_cache_logic(
        self,
        cached_result: Optional[Any],
        max_age: Optional[float],
    ):
        """
        Common get cache logic across sync + async implementations
        """
        # Check if a timestamp was stored with the cached response
        if (
            cached_result is not None
            and isinstance(cached_result, dict)
            and "timestamp" in cached_result
        ):
            timestamp = cached_result["timestamp"]
            current_time = time.time()

            # Calculate age of the cached response
            response_age = current_time - timestamp

            # Check if the cached response is older than the max-age
            if max_age is not None and response_age > max_age:
                return None  # Cached response is too old

            # If the response is fresh, or there's no max-age requirement, return the cached response
            # cached_response may be stored as a JSON string; convert it back to a dict
            cached_response = cached_result.get("response")
            try:
                if isinstance(cached_response, dict):
                    pass
                else:
                    cached_response = json.loads(
                        cached_response  # type: ignore
                    )  # Convert string to dictionary
            except Exception:
                cached_response = ast.literal_eval(cached_response)  # type: ignore
            return cached_response
        return cached_result
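
    # Illustrative sketch: cached entries are stored as {"timestamp": ..., "response": ...},
    # so a per-request max-age check is just an age comparison, e.g.
    #
    #   completion(..., cache={"s-maxage": 60})
    #   # entries older than 60 seconds are treated as a cache miss and re-fetched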

    def get_cache(self, **kwargs):
        """
        Retrieves the cached result for the given arguments.

        Args:
            *args: args to litellm.completion() or embedding()
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            The cached result if it exists, otherwise None.
        """
        try:  # never block execution
            if self.should_use_cache(**kwargs) is not True:
                return
            messages = kwargs.get("messages", [])
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                cache_control_args: DynamicCacheControl = kwargs.get("cache", {})
                max_age = (
                    cache_control_args.get("s-maxage")
                    or cache_control_args.get("s-max-age")
                    or float("inf")
                )
                cached_result = self.cache.get_cache(cache_key, messages=messages)
                return self._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )
        except Exception:
            print_verbose(f"An exception occurred: {traceback.format_exc()}")
            return None
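
    # Illustrative usage sketch (assuming a configured litellm.cache): the completion
    # wrapper calls get_cache() before hitting the provider, roughly like
    #
    #   cached = litellm.cache.get_cache(model="gpt-3.5-turbo", messages=messages)
    #   if cached is not None:
    #       return cached  # served from cache, no API call made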

    async def async_get_cache(self, **kwargs):
        """
        Async get cache implementation.

        Used for embedding calls in async wrapper
        """

        try:  # never block execution
            if self.should_use_cache(**kwargs) is not True:
                return

            kwargs.get("messages", [])
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(
                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                )
                cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
                return self._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )
        except Exception:
            print_verbose(f"An exception occurred: {traceback.format_exc()}")
            return None

    def _add_cache_logic(self, result, **kwargs):
        """
        Common implementation across sync + async add_cache functions
        """
        try:
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                if isinstance(result, BaseModel):
                    result = result.model_dump_json()

                ## DEFAULT TTL ##
                if self.ttl is not None:
                    kwargs["ttl"] = self.ttl
                ## Get Cache-Controls ##
                _cache_kwargs = kwargs.get("cache", None)
                if isinstance(_cache_kwargs, dict):
                    for k, v in _cache_kwargs.items():
                        if k == "ttl":
                            kwargs["ttl"] = v

                cached_data = {"timestamp": time.time(), "response": result}
                return cache_key, cached_data, kwargs
            else:
                raise Exception("cache key is None")
        except Exception as e:
            raise e
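
    # Illustrative sketch: a per-request TTL passed via the `cache` kwarg overrides the
    # default TTL configured on the Cache object, e.g.
    #
    #   completion(..., cache={"ttl": 300})
    #   # this response is written with a 300s expiry regardless of Cache(ttl=...)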

    def add_cache(self, result, **kwargs):
        """
        Adds a result to the cache.

        Args:
            *args: args to litellm.completion() or embedding()
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            None
        """
        try:
            if self.should_use_cache(**kwargs) is not True:
                return
            cache_key, cached_data, kwargs = self._add_cache_logic(
                result=result, **kwargs
            )
            self.cache.set_cache(cache_key, cached_data, **kwargs)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")

    async def async_add_cache(self, result, **kwargs):
        """
        Async implementation of add_cache
        """
        try:
            if self.should_use_cache(**kwargs) is not True:
                return
            if self.type == "redis" and self.redis_flush_size is not None:
                # high traffic - fill in results in memory and then flush
                await self.batch_cache_write(result, **kwargs)
            else:
                cache_key, cached_data, kwargs = self._add_cache_logic(
                    result=result, **kwargs
                )

                await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")

    async def async_add_cache_pipeline(self, result, **kwargs):
        """
        Async implementation of add_cache for Embedding calls

        Does a bulk write, to prevent using too many clients
        """
        try:
            if self.should_use_cache(**kwargs) is not True:
                return

            # set default ttl if not set
            if self.ttl is not None:
                kwargs["ttl"] = self.ttl

            cache_list = []
            for idx, i in enumerate(kwargs["input"]):
                preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
                kwargs["cache_key"] = preset_cache_key
                embedding_response = result.data[idx]
                cache_key, cached_data, kwargs = self._add_cache_logic(
                    result=embedding_response,
                    **kwargs,
                )
                cache_list.append((cache_key, cached_data))

            await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
            # if async_set_cache_pipeline:
            #     await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
            # else:
            #     tasks = []
            #     for val in cache_list:
            #         tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
            #     await asyncio.gather(*tasks)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")
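
    # Illustrative sketch: for an embedding call with a batched input, each input string
    # gets its own cache entry keyed on that single input, so a later call can mix cache
    # hits and misses within one batch, e.g.
    #
    #   aembedding(model="text-embedding-ada-002", input=["a", "b"])
    #   # -> two (cache_key, cached_data) pairs written in one pipeline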

    def should_use_cache(self, **kwargs):
        """
        Returns True if we should use the cache for LLM API calls

        If the cache mode is default_on, this is always True.
        If the cache mode is default_off, this is only True when the user has opted in to using the cache.
        """
        if self.mode == CacheMode.default_on:
            return True

        # when mode == default_off -> Cache is opt in only
        _cache = kwargs.get("cache", None)
        verbose_logger.debug("should_use_cache: kwargs: %s; _cache: %s", kwargs, _cache)
        if _cache and isinstance(_cache, dict):
            if _cache.get("use-cache", False) is True:
                return True
        return False
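
    # Illustrative sketch: with Cache(mode=CacheMode.default_off), caching is opt-in
    # per request via the dynamic cache controls, e.g.
    #
    #   completion(..., cache={"use-cache": True})   # cached
    #   completion(...)                              # not cached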

    async def batch_cache_write(self, result, **kwargs):
        cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
        await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)

    async def ping(self):
        cache_ping = getattr(self.cache, "ping", None)
        if cache_ping:
            return await cache_ping()
        return None

    async def delete_cache_keys(self, keys):
        cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys", None)
        if cache_delete_cache_keys:
            return await cache_delete_cache_keys(keys)
        return None

    async def disconnect(self):
        if hasattr(self.cache, "disconnect"):
            await self.cache.disconnect()

    def _supports_async(self) -> bool:
        """
        Internal method to check if the cache type supports async get/set operations

        Only the S3 cache does NOT support async operations
        """
        if self.type and self.type == LiteLLMCacheType.S3:
            return False
        return True


def enable_cache(
    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
    host: Optional[str] = None,
    port: Optional[str] = None,
    password: Optional[str] = None,
    supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
        "completion",
        "acompletion",
        "embedding",
        "aembedding",
        "atranscription",
        "transcription",
        "atext_completion",
        "text_completion",
        "arerank",
        "rerank",
    ],
    **kwargs,
):
    """
    Enable cache with the specified configuration.

    Args:
        type (Optional[Literal["local", "redis", "redis-semantic", "qdrant-semantic", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
        host (Optional[str]): The host address of the cache server. Defaults to None.
        port (Optional[str]): The port number of the cache server. Defaults to None.
        password (Optional[str]): The password for the cache server. Defaults to None.
        supported_call_types (Optional[List[CachingSupportedCallTypes]]):
            The call types to cache for. Defaults to all supported call types.
        **kwargs: Additional keyword arguments.

    Returns:
        None

    Raises:
        None
    """
    print_verbose("LiteLLM: Enabling Cache")
    if "cache" not in litellm.input_callback:
        litellm.input_callback.append("cache")
    if "cache" not in litellm.success_callback:
        litellm.logging_callback_manager.add_litellm_success_callback("cache")
    if "cache" not in litellm._async_success_callback:
        litellm.logging_callback_manager.add_litellm_async_success_callback("cache")

    if litellm.cache is None:
        litellm.cache = Cache(
            type=type,
            host=host,
            port=port,
            password=password,
            supported_call_types=supported_call_types,
            **kwargs,
        )
    print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
    print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
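

# Illustrative usage sketch: enable a Redis-backed cache globally (connection values
# below are placeholders, not real credentials):
#
#   import litellm
#   from litellm.caching.caching import enable_cache, LiteLLMCacheType
#
#   enable_cache(type=LiteLLMCacheType.REDIS, host="localhost", port="6379", password="...")
#   # all supported litellm calls now read from / write to this cache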


def update_cache(
    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
    host: Optional[str] = None,
    port: Optional[str] = None,
    password: Optional[str] = None,
    supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
        "completion",
        "acompletion",
        "embedding",
        "aembedding",
        "atranscription",
        "transcription",
        "atext_completion",
        "text_completion",
        "arerank",
        "rerank",
    ],
    **kwargs,
):
    """
    Update the cache for LiteLLM.

    Args:
        type (Optional[Literal["local", "redis", "redis-semantic", "qdrant-semantic", "s3", "disk"]]): The type of cache. Defaults to "local".
        host (Optional[str]): The host of the cache. Defaults to None.
        port (Optional[str]): The port of the cache. Defaults to None.
        password (Optional[str]): The password for the cache. Defaults to None.
        supported_call_types (Optional[List[CachingSupportedCallTypes]]):
            The call types to cache for. Defaults to all supported call types.
        **kwargs: Additional keyword arguments for the cache.

    Returns:
        None
    """
    print_verbose("LiteLLM: Updating Cache")
    litellm.cache = Cache(
        type=type,
        host=host,
        port=port,
        password=password,
        supported_call_types=supported_call_types,
        **kwargs,
    )
    print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
    print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
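

# Illustrative sketch: unlike enable_cache(), update_cache() always replaces
# litellm.cache, so it can be used to switch backends at runtime, e.g.
#
#   update_cache(type=LiteLLMCacheType.LOCAL)   # swap an existing Redis cache for in-memory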


def disable_cache():
    """
    Disable the cache used by LiteLLM.

    This function disables the cache used by the LiteLLM module. It removes the cache-related callbacks from the input_callback, success_callback, and _async_success_callback lists. It also sets the litellm.cache attribute to None.

    Parameters:
        None

    Returns:
        None
    """
    from contextlib import suppress

    print_verbose("LiteLLM: Disabling Cache")
    with suppress(ValueError):
        litellm.input_callback.remove("cache")
        litellm.success_callback.remove("cache")
        litellm._async_success_callback.remove("cache")

    litellm.cache = None
    print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")