diff --git a/litellm/caching.py b/litellm/caching.py
index 6c2ec0356..7126f2e83 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -220,6 +220,7 @@ class Cache:
         if "cache" not in litellm._async_success_callback:
             litellm._async_success_callback.append("cache")
         self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
+        self.type = type
 
     def get_cache_key(self, *args, **kwargs):
         """
@@ -374,3 +375,62 @@ class Cache:
 
     async def _async_add_cache(self, result, *args, **kwargs):
         self.add_cache(result, *args, **kwargs)
+
+
+def enable_cache(
+    type: Optional[Literal["local", "redis"]] = "local",
+    host: Optional[str] = None,
+    port: Optional[str] = None,
+    password: Optional[str] = None,
+    supported_call_types: Optional[
+        List[Literal["completion", "acompletion", "embedding", "aembedding"]]
+    ] = ["completion", "acompletion", "embedding", "aembedding"],
+    **kwargs,
+):
+    """
+    Enable caching and attach a Cache instance to litellm.cache.
+
+    :param type: The cache backend to use ("local" or "redis")
+    :return: None
+    """
+    print_verbose("LiteLLM: Enabling Cache")
+    if "cache" not in litellm.input_callback:
+        litellm.input_callback.append("cache")
+    if "cache" not in litellm.success_callback:
+        litellm.success_callback.append("cache")
+    if "cache" not in litellm._async_success_callback:
+        litellm._async_success_callback.append("cache")
+
+    if litellm.cache is None:
+        litellm.cache = Cache(
+            type=type,
+            host=host,
+            port=port,
+            password=password,
+            supported_call_types=supported_call_types,
+            **kwargs,
+        )
+    print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
+    print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
+
+
+def disable_cache():
+    """
+    Disable caching and detach the cache from litellm.cache.
+
+    :return: None
+    """
+    from contextlib import suppress
+
+    print_verbose("LiteLLM: Disabling Cache")
+    # Remove each callback independently so a missing entry in one
+    # list does not prevent removal from the others.
+    with suppress(ValueError):
+        litellm.input_callback.remove("cache")
+    with suppress(ValueError):
+        litellm.success_callback.remove("cache")
+    with suppress(ValueError):
+        litellm._async_success_callback.remove("cache")
+
+    litellm.cache = None
+    print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")
diff --git a/litellm/main.py b/litellm/main.py
index 3b13e717a..50f39e549 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -71,6 +71,7 @@ from .llms.prompt_templates.factory import (
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, List, Optional, Dict, Union, Mapping
+from .caching import enable_cache, disable_cache
 
 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (
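
A minimal usage sketch of the new helpers (not part of the diff; the model name and the assumption that a provider API key is configured are illustrative):

```python
import litellm
from litellm.caching import enable_cache, disable_cache

# Attach an in-process cache: repeated identical requests are served
# from the cache instead of hitting the provider again.
enable_cache(type="local")
# For a cache shared across processes, a Redis backend could be used instead:
# enable_cache(type="redis", host="localhost", port="6379", password="...")

# Assumes a provider key (e.g. OPENAI_API_KEY) is set in the environment.
messages = [{"role": "user", "content": "What is the capital of France?"}]
first = litellm.completion(model="gpt-3.5-turbo", messages=messages)
second = litellm.completion(model="gpt-3.5-turbo", messages=messages)  # cache hit

# Detach the cache and deregister the "cache" callbacks.
disable_cache()
```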