From 9ee16bc9623e18b5be7362050f9df28717815764 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Thu, 14 Dec 2023 22:27:14 +0530 Subject: [PATCH] (feat) caching - add supported call types --- litellm/__init__.py | 2 +- litellm/caching.py | 21 ++++++++++++--------- litellm/router.py | 4 ++-- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index d5c29e7a7..7c6864eac 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -10,7 +10,7 @@ success_callback: List[Union[str, Callable]] = [] failure_callback: List[Union[str, Callable]] = [] callbacks: List[Callable] = [] _async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here. -_async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here. +_async_success_callback: List[Union[str, Callable]] = [] # internal variable - async custom callbacks are routed here. _async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here. 
pre_call_rules: List[Callable] = [] post_call_rules: List[Callable] = [] diff --git a/litellm/caching.py b/litellm/caching.py index 2ff4f0c82..f4787c7f1 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -10,7 +10,7 @@ import litellm import time, logging import json, traceback, ast -from typing import Optional +from typing import Optional, Literal def print_verbose(print_statement): try: @@ -162,34 +162,36 @@ class DualCache(BaseCache): if self.redis_cache is not None: self.redis_cache.flush_cache() -#### LiteLLM.Completion Cache #### +#### LiteLLM.Completion / Embedding Cache #### class Cache: def __init__( self, - type="local", - host=None, - port=None, - password=None, + type: Optional[Literal["local", "redis"]] = "local", + host: Optional[str] = None, + port: Optional[int] = None, + password: Optional[str] = None, + supported_call_types: Optional[list[Literal["completion", "acompletion", "embedding", "aembedding"]]] = ["completion", "acompletion", "embedding", "aembedding"], **kwargs ): """ Initializes the cache based on the given type. Args: - type (str, optional): The type of cache to initialize. Defaults to "local". + type (str, optional): The type of cache to initialize. Can be "local" or "redis". Defaults to "local". host (str, optional): The host address for the Redis cache. Required if type is "redis". port (int, optional): The port number for the Redis cache. Required if type is "redis". password (str, optional): The password for the Redis cache. Required if type is "redis". + supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types. **kwargs: Additional keyword arguments for redis.Redis() cache Raises: ValueError: If an invalid cache type is provided. Returns: - None + None. 
Cache is set as a litellm param """ if type == "redis": - self.cache = RedisCache(host, port, password, **kwargs) + self.cache: BaseCache = RedisCache(host, port, password, **kwargs) if type == "local": self.cache = InMemoryCache() if "cache" not in litellm.input_callback: @@ -198,6 +200,7 @@ class Cache: litellm.success_callback.append("cache") if "cache" not in litellm._async_success_callback: litellm._async_success_callback.append("cache") + self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"] def get_cache_key(self, *args, **kwargs): """ diff --git a/litellm/router.py b/litellm/router.py index 64b88e7e4..6a4d04815 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -114,7 +114,7 @@ class Router: self.default_litellm_params.setdefault("max_retries", 0) ### CACHING ### - cache_type = "local" # default to an in-memory cache + cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache redis_cache = None cache_config = {} if redis_url is not None or (redis_host is not None and redis_port is not None and redis_password is not None): @@ -138,7 +138,7 @@ class Router: if cache_responses: if litellm.cache is None: # the cache can be initialized on the proxy server. We should not overwrite it - litellm.cache = litellm.Cache(type=cache_type, **cache_config) + litellm.cache = litellm.Cache(type=cache_type, **cache_config) # type: ignore self.cache_responses = cache_responses self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc. ### ROUTING SETUP ###