diff --git a/docs/my-website/docs/caching/redis_cache.md b/docs/my-website/docs/caching/all_caches.md similarity index 80% rename from docs/my-website/docs/caching/redis_cache.md rename to docs/my-website/docs/caching/all_caches.md index b00a118c1..eb309f9b8 100644 --- a/docs/my-website/docs/caching/redis_cache.md +++ b/docs/my-website/docs/caching/all_caches.md @@ -1,7 +1,7 @@ import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; -# Caching - In-Memory, Redis, s3, Redis Semantic Cache +# Caching - In-Memory, Redis, s3, Redis Semantic Cache, Disk [**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/caching.py) @@ -11,7 +11,7 @@ Need to use Caching on LiteLLM Proxy Server? Doc here: [Caching Proxy Server](ht ::: -## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic Cache +## Initialize Cache - In Memory, Redis, s3 Bucket, Redis Semantic, Disk Cache @@ -159,7 +159,7 @@ litellm.cache = Cache() # Make completion calls response1 = completion( model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "Tell me a joke."}] + messages=[{"role": "user", "content": "Tell me a joke."}], caching=True ) response2 = completion( @@ -174,6 +174,43 @@ response2 = completion( + + +### Quick Start + +Install diskcache: + +```shell +pip install diskcache +``` + +Then you can use the disk cache as follows. + +```python +import litellm +from litellm import completion +from litellm.caching import Cache +litellm.cache = Cache(type="disk") + +# Make completion calls +response1 = completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Tell me a joke."}], + caching=True +) +response2 = completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Tell me a joke."}], + caching=True +) + +# response1 == response2, response 1 is cached + +``` + +If you run the code two times, response1 will use the cache from the first run that was stored in a cache file. 
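If you want to keep the cache file somewhere other than the default `.litellm_cache` directory, the new `disk_cache_dir` parameter can be passed when constructing the cache. A minimal sketch (the directory path below is just an example):

```python
import litellm
from litellm.caching import Cache

# by default the disk cache is written to ".litellm_cache";
# disk_cache_dir points it at a directory of your choosing
litellm.cache = Cache(type="disk", disk_cache_dir="/tmp/litellm_disk_cache")
```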
+ + @@ -191,13 +228,13 @@ Advanced Params ```python litellm.enable_cache( - type: Optional[Literal["local", "redis"]] = "local", + type: Optional[Literal["local", "redis", "s3", "disk"]] = "local", host: Optional[str] = None, port: Optional[str] = None, password: Optional[str] = None, supported_call_types: Optional[ - List[Literal["completion", "acompletion", "embedding", "aembedding"]] - ] = ["completion", "acompletion", "embedding", "aembedding"], + List[Literal["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"]] + ] = ["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"], **kwargs, ) ``` @@ -215,13 +252,13 @@ Update the Cache params ```python litellm.update_cache( - type: Optional[Literal["local", "redis"]] = "local", + type: Optional[Literal["local", "redis", "s3", "disk"]] = "local", host: Optional[str] = None, port: Optional[str] = None, password: Optional[str] = None, supported_call_types: Optional[ - List[Literal["completion", "acompletion", "embedding", "aembedding"]] - ] = ["completion", "acompletion", "embedding", "aembedding"], + List[Literal["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"]] + ] = ["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"], **kwargs, ) ``` @@ -276,22 +313,29 @@ cache.get_cache = get_cache ```python def __init__( self, - type: Optional[Literal["local", "redis", "s3"]] = "local", + type: Optional[Literal["local", "redis", "redis-semantic", "s3", "disk"]] = "local", supported_call_types: Optional[ - List[Literal["completion", "acompletion", "embedding", "aembedding"]] - ] = ["completion", "acompletion", "embedding", "aembedding"], # A list of litellm call types to cache for. Defaults to caching for all litellm call types. 
- + List[Literal["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"]] + ] = ["completion", "acompletion", "embedding", "aembedding", "atranscription", "transcription"], + ttl: Optional[float] = None, + default_in_memory_ttl: Optional[float] = None, + # redis cache params host: Optional[str] = None, port: Optional[str] = None, password: Optional[str] = None, - + namespace: Optional[str] = None, + default_in_redis_ttl: Optional[float] = None, + similarity_threshold: Optional[float] = None, + redis_semantic_cache_use_async=False, + redis_semantic_cache_embedding_model="text-embedding-ada-002", + redis_flush_size=None, # s3 Bucket, boto3 configuration s3_bucket_name: Optional[str] = None, s3_region_name: Optional[str] = None, s3_api_version: Optional[str] = None, - s3_path: Optional[str] = None, # if you wish to save to a spefic path + s3_path: Optional[str] = None, # if you wish to save to a specific path s3_use_ssl: Optional[bool] = True, s3_verify: Optional[Union[bool, str]] = None, s3_endpoint_url: Optional[str] = None, @@ -299,7 +343,11 @@ def __init__( s3_aws_secret_access_key: Optional[str] = None, s3_aws_session_token: Optional[str] = None, s3_config: Optional[Any] = None, - **kwargs, + + # disk cache params + disk_cache_dir=None, + + **kwargs ): ``` diff --git a/docs/my-website/docs/caching/local_caching.md b/docs/my-website/docs/caching/local_caching.md index d0e26e4bf..81c4edcb8 100644 --- a/docs/my-website/docs/caching/local_caching.md +++ b/docs/my-website/docs/caching/local_caching.md @@ -40,7 +40,7 @@ cache = Cache() cache.add_cache(cache_key="test-key", result="1234") -cache.get_cache(cache_key="test-key) +cache.get_cache(cache_key="test-key") ``` ## Caching with Streaming diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 2deca9258..62202cc7e 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -189,7 +189,7 @@ const sidebars = { `observability/telemetry`, ], }, - "caching/redis_cache", + "caching/all_caches", { type: "category", label: "Tutorials", diff --git a/litellm/caching.py b/litellm/caching.py index ccb62b882..8c9157e53 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -1441,7 +1441,7 @@ class DualCache(BaseCache): class Cache: def __init__( self, - type: Optional[Literal["local", "redis", "redis-semantic", "s3"]] = "local", + type: Optional[Literal["local", "redis", "redis-semantic", "s3", "disk"]] = "local", host: Optional[str] = None, port: Optional[str] = None, password: Optional[str] = None, @@ -1484,13 +1484,14 @@ class Cache: redis_semantic_cache_use_async=False, redis_semantic_cache_embedding_model="text-embedding-ada-002", redis_flush_size=None, + disk_cache_dir=None, **kwargs, ): """ Initializes the cache based on the given type. Args: - type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", or "s3". Defaults to "local". + type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "s3" or "disk". Defaults to "local". host (str, optional): The host address for the Redis cache. Required if type is "redis". port (int, optional): The port number for the Redis cache. Required if type is "redis". password (str, optional): The password for the Redis cache. Required if type is "redis". 
@@ -1536,6 +1537,8 @@ class Cache: s3_path=s3_path, **kwargs, ) + elif type == "disk": + self.cache = DiskCache(disk_cache_dir=disk_cache_dir) if "cache" not in litellm.input_callback: litellm.input_callback.append("cache") if "cache" not in litellm.success_callback: @@ -1907,8 +1910,86 @@ class Cache: await self.cache.disconnect() +class DiskCache(BaseCache): + def __init__(self, disk_cache_dir: Optional[str] = None): + import diskcache as dc + + # if users don't provider one, use the default litellm cache + if disk_cache_dir is None: + self.disk_cache = dc.Cache(".litellm_cache") + else: + self.disk_cache = dc.Cache(disk_cache_dir) + + def set_cache(self, key, value, **kwargs): + print_verbose("DiskCache: set_cache") + if "ttl" in kwargs: + self.disk_cache.set(key, value, expire=kwargs["ttl"]) + else: + self.disk_cache.set(key, value) + + async def async_set_cache(self, key, value, **kwargs): + self.set_cache(key=key, value=value, **kwargs) + + async def async_set_cache_pipeline(self, cache_list, ttl=None): + for cache_key, cache_value in cache_list: + if ttl is not None: + self.set_cache(key=cache_key, value=cache_value, ttl=ttl) + else: + self.set_cache(key=cache_key, value=cache_value) + + def get_cache(self, key, **kwargs): + original_cached_response = self.disk_cache.get(key) + if original_cached_response: + try: + cached_response = json.loads(original_cached_response) + except: + cached_response = original_cached_response + return cached_response + return None + + def batch_get_cache(self, keys: list, **kwargs): + return_val = [] + for k in keys: + val = self.get_cache(key=k, **kwargs) + return_val.append(val) + return return_val + + def increment_cache(self, key, value: int, **kwargs) -> int: + # get the value + init_value = self.get_cache(key=key) or 0 + value = init_value + value + self.set_cache(key, value, **kwargs) + return value + + async def async_get_cache(self, key, **kwargs): + return self.get_cache(key=key, **kwargs) + + async def async_batch_get_cache(self, keys: list, **kwargs): + return_val = [] + for k in keys: + val = self.get_cache(key=k, **kwargs) + return_val.append(val) + return return_val + + async def async_increment(self, key, value: int, **kwargs) -> int: + # get the value + init_value = await self.async_get_cache(key=key) or 0 + value = init_value + value + await self.async_set_cache(key, value, **kwargs) + return value + + def flush_cache(self): + self.disk_cache.clear() + + async def disconnect(self): + pass + + def delete_cache(self, key): + self.disk_cache.pop(key) + + def enable_cache( - type: Optional[Literal["local", "redis", "s3"]] = "local", + type: Optional[Literal["local", "redis", "s3", "disk"]] = "local", host: Optional[str] = None, port: Optional[str] = None, password: Optional[str] = None, @@ -1937,7 +2018,7 @@ def enable_cache( Enable cache with the specified configuration. Args: - type (Optional[Literal["local", "redis"]]): The type of cache to enable. Defaults to "local". + type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local". host (Optional[str]): The host address of the cache server. Defaults to None. port (Optional[str]): The port number of the cache server. Defaults to None. password (Optional[str]): The password for the cache server. Defaults to None. 
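The `DiskCache` class introduced above can also be exercised on its own, independent of `litellm.cache`. A rough sketch based on the methods shown in this diff (the key names and directory are placeholders):

```python
from litellm.caching import DiskCache

# standalone disk-backed cache; the directory is illustrative
cache = DiskCache(disk_cache_dir="/tmp/litellm_disk_cache")

cache.set_cache("greeting", "hello", ttl=60)  # ttl is forwarded to diskcache as `expire`
print(cache.get_cache("greeting"))            # -> "hello"

cache.increment_cache("hits", 1)              # missing keys start from 0
cache.flush_cache()                           # clears every entry in the cache directory
```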
@@ -1973,7 +2054,7 @@ def enable_cache( def update_cache( - type: Optional[Literal["local", "redis"]] = "local", + type: Optional[Literal["local", "redis", "s3", "disk"]] = "local", host: Optional[str] = None, port: Optional[str] = None, password: Optional[str] = None, @@ -2002,7 +2083,7 @@ def update_cache( Update the cache for LiteLLM. Args: - type (Optional[Literal["local", "redis"]]): The type of cache. Defaults to "local". + type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local". host (Optional[str]): The host of the cache. Defaults to None. port (Optional[str]): The port of the cache. Defaults to None. password (Optional[str]): The password for the cache. Defaults to None. diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index 903ce69c7..2f0f1dbfe 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -599,7 +599,10 @@ def test_redis_cache_completion(): ) print("test2 for Redis Caching - non streaming") response1 = completion( - model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20 + model="gpt-3.5-turbo", + messages=messages, + caching=True, + max_tokens=20, ) response2 = completion( model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20 @@ -653,7 +656,6 @@ def test_redis_cache_completion(): assert response1.created == response2.created assert response1.choices[0].message.content == response2.choices[0].message.content - # test_redis_cache_completion() @@ -875,6 +877,80 @@ async def test_redis_cache_acompletion_stream_bedrock(): print(e) raise e +def test_disk_cache_completion(): + litellm.set_verbose = False + + random_number = random.randint( + 1, 100000 + ) # add a random number to ensure it's always adding / reading from cache + messages = [ + {"role": "user", "content": f"write a one sentence poem about: {random_number}"} + ] + litellm.cache = Cache( + type="disk", + ) + + response1 = completion( + model="gpt-3.5-turbo", + messages=messages, + caching=True, + max_tokens=20, + mock_response="This number is so great!", + ) + # response2 is mocked to a different response from response1, + # but the completion from the cache should be used instead of the mock + # response since the input is the same as response1 + response2 = completion( + model="gpt-3.5-turbo", + messages=messages, + caching=True, + max_tokens=20, + mock_response="This number is awful!", + ) + # Since the parameters are not the same as response1, response3 should actually + # be the mock response + response3 = completion( + model="gpt-3.5-turbo", + messages=messages, + caching=True, + temperature=0.5, + mock_response="This number is awful!", + ) + + print("\nresponse 1", response1) + print("\nresponse 2", response2) + print("\nresponse 3", response3) + # print("\nresponse 4", response4) + litellm.cache = None + litellm.success_callback = [] + litellm._async_success_callback = [] + + # 1 & 2 should be exactly the same + # 1 & 3 should be different, since input params are diff + if ( + response1["choices"][0]["message"]["content"] + != response2["choices"][0]["message"]["content"] + ): # 1 and 2 should be the same + # 1&2 have the exact same input params. 
This MUST Be a CACHE HIT + print(f"response1: {response1}") + print(f"response2: {response2}") + pytest.fail(f"Error occurred:") + if ( + response1["choices"][0]["message"]["content"] + == response3["choices"][0]["message"]["content"] + ): + # if input params like max_tokens, temperature are diff it should NOT be a cache hit + print(f"response1: {response1}") + print(f"response3: {response3}") + pytest.fail( + f"Response 1 == response 3. Same model, diff params shoudl not cache Error" + f" occurred:" + ) + + assert response1.id == response2.id + assert response1.created == response2.created + assert response1.choices[0].message.content == response2.choices[0].message.content + @pytest.mark.skip(reason="AWS Suspended Account") @pytest.mark.asyncio diff --git a/pyproject.toml b/pyproject.toml index fe88b0699..49689da78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,6 @@ extra_proxy = [ "resend" ] - [tool.poetry.scripts] litellm = 'litellm:run_server'
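For completeness, a minimal sketch of how the new `disk` option in `litellm.enable_cache` could be used end to end, mirroring the quick start in the docs above (the model and prompt are placeholders, and `diskcache` must be installed):

```python
import litellm
from litellm import completion

# enable the disk-backed cache added in this change
litellm.enable_cache(type="disk")

messages = [{"role": "user", "content": "Tell me a joke."}]

# the second identical call should be served from the on-disk cache
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)

# reset, as the tests above do
litellm.cache = None
```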