(feat) caching - add supported call types

commit 9ee16bc962
parent 67518387f1
Author: ishaan-jaff
Date:   2023-12-14 22:27:14 +05:30

3 changed files with 15 additions and 12 deletions

litellm/__init__.py

@@ -10,7 +10,7 @@ success_callback: List[Union[str, Callable]] = []
 failure_callback: List[Union[str, Callable]] = []
 callbacks: List[Callable] = []
 _async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
-_async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
+_async_success_callback: List[Union[str, Callable]] = [] # internal variable - async custom callbacks are routed here.
 _async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
 pre_call_rules: List[Callable] = []
 post_call_rules: List[Callable] = []
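The widened annotation is needed because litellm registers its built-in cache hook by name: later in this commit the literal string "cache" is appended to _async_success_callback, which List[Callable] would reject under a type checker. A minimal sketch of the resulting mixed list:

from typing import Callable, List, Union

_async_success_callback: List[Union[str, Callable]] = []

# Built-in hooks are registered by name, custom hooks as callables,
# so the element type must admit both.
_async_success_callback.append("cache")
_async_success_callback.append(lambda *args, **kwargs: None)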

litellm/caching.py

@@ -10,7 +10,7 @@
 import litellm
 import time, logging
 import json, traceback, ast
-from typing import Optional
+from typing import Optional, Literal

 def print_verbose(print_statement):
     try:
@@ -162,34 +162,36 @@ class DualCache(BaseCache):
         if self.redis_cache is not None:
             self.redis_cache.flush_cache()

-#### LiteLLM.Completion Cache ####
+#### LiteLLM.Completion / Embedding Cache ####
 class Cache:
     def __init__(
         self,
-        type="local",
-        host=None,
-        port=None,
-        password=None,
+        type: Optional[Literal["local", "redis"]] = "local",
+        host: Optional[str] = None,
+        port: Optional[str] = None,
+        password: Optional[str] = None,
+        supported_call_types: Optional[list[Literal["completion", "acompletion", "embedding", "aembedding"]]] = ["completion", "acompletion", "embedding", "aembedding"],
         **kwargs
     ):
         """
         Initializes the cache based on the given type.

         Args:
-            type (str, optional): The type of cache to initialize. Defaults to "local".
+            type (str, optional): The type of cache to initialize. Can be "local" or "redis". Defaults to "local".
             host (str, optional): The host address for the Redis cache. Required if type is "redis".
             port (int, optional): The port number for the Redis cache. Required if type is "redis".
             password (str, optional): The password for the Redis cache. Required if type is "redis".
+            supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
             **kwargs: Additional keyword arguments for redis.Redis() cache

         Raises:
             ValueError: If an invalid cache type is provided.

         Returns:
-            None
+            None. Cache is set as a litellm param
         """
         if type == "redis":
-            self.cache = RedisCache(host, port, password, **kwargs)
+            self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
         if type == "local":
             self.cache = InMemoryCache()
         if "cache" not in litellm.input_callback:
@@ -198,6 +200,7 @@ class Cache:
             litellm.success_callback.append("cache")
         if "cache" not in litellm._async_success_callback:
             litellm._async_success_callback.append("cache")
+        self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]

     def get_cache_key(self, *args, **kwargs):
         """

litellm/router.py

@@ -114,7 +114,7 @@ class Router:
         self.default_litellm_params.setdefault("max_retries", 0)

         ### CACHING ###
-        cache_type = "local" # default to an in-memory cache
+        cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
         redis_cache = None
         cache_config = {}
         if redis_url is not None or (redis_host is not None and redis_port is not None and redis_password is not None):
@@ -138,7 +138,7 @@ class Router:
         if cache_responses:
             if litellm.cache is None:
                 # the cache can be initialized on the proxy server. We should not overwrite it
-                litellm.cache = litellm.Cache(type=cache_type, **cache_config)
+                litellm.cache = litellm.Cache(type=cache_type, **cache_config) # type: ignore
             self.cache_responses = cache_responses
         self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
         ### ROUTING SETUP ###