fix(caching.py): support s-maxage param for cache controls

Krrish Dholakia 2024-01-04 11:41:23 +05:30
parent 4946b1ef6d
commit b0827a87b2
3 changed files with 13 additions and 7 deletions

Changed file: proxy caching docs

@@ -161,7 +161,7 @@ litellm_settings:
 The proxy support 3 cache-controls:
 - `ttl`: Will cache the response for the user-defined amount of time (in seconds).
-- `s-max-age`: Will only accept cached responses that are within user-defined range (in seconds).
+- `s-maxage`: Will only accept cached responses that are within user-defined range (in seconds).
 - `no-cache`: Will not return a cached response, but instead call the actual endpoint.
 [Let us know if you need more](https://github.com/BerriAI/litellm/issues/1218)
@@ -237,7 +237,7 @@ chat_completion = client.chat.completions.create(
     ],
     model="gpt-3.5-turbo",
     cache={
-        "s-max-age": 600 # only get responses cached within last 10 minutes
+        "s-maxage": 600 # only get responses cached within last 10 minutes
     }
 )
 ```
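The same controls can also be exercised straight from the Python SDK, which is how the updated tests in this commit drive them. A minimal sketch, assuming a default in-memory `Cache()`; the model and prompt are placeholders, and `no-cache` is omitted since only the two controls touched by this commit are shown:

```python
# Minimal sketch of the cache controls via the litellm SDK (not the proxy).
# Assumes a default in-memory cache; model and prompt are placeholders.
import litellm
from litellm import completion
from litellm.caching import Cache

litellm.cache = Cache()

messages = [{"role": "user", "content": "what does litellm do?"}]

# ttl: store this response in the cache for 60 seconds
response1 = completion(model="gpt-3.5-turbo", messages=messages, cache={"ttl": 60})

# s-maxage: only accept a cached response that is at most 60 seconds old
response2 = completion(model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 60})
```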

Changed file: caching.py

@@ -11,6 +11,7 @@ import litellm
 import time, logging
 import json, traceback, ast, hashlib
 from typing import Optional, Literal, List, Union, Any
+from openai._models import BaseModel as OpenAIObject


 def print_verbose(print_statement):
@@ -472,7 +473,10 @@ class Cache:
             else:
                 cache_key = self.get_cache_key(*args, **kwargs)
             if cache_key is not None:
-                max_age = kwargs.get("cache", {}).get("s-max-age", float("inf"))
+                cache_control_args = kwargs.get("cache", {})
+                max_age = cache_control_args.get(
+                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
+                )
                 cached_result = self.cache.get_cache(cache_key)
                 # Check if a timestamp was stored with the cached response
                 if (
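The cache-read change above keeps the older `s-max-age` spelling working while adding `s-maxage`. A standalone sketch of that lookup order, using plain dict logic; `resolve_max_age` is a hypothetical helper, not library code:

```python
# Prefer "s-max-age" if the caller passed it, then fall back to "s-maxage",
# and finally to infinity (i.e. any cached entry is acceptable).
def resolve_max_age(cache_control_args: dict) -> float:
    return cache_control_args.get(
        "s-max-age", cache_control_args.get("s-maxage", float("inf"))
    )


assert resolve_max_age({"s-max-age": 600}) == 600  # old spelling still honored
assert resolve_max_age({"s-maxage": 600}) == 600   # new spelling from this commit
assert resolve_max_age({}) == float("inf")         # no control -> no age limit
```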
@@ -529,7 +533,7 @@ class Cache:
             else:
                 cache_key = self.get_cache_key(*args, **kwargs)
             if cache_key is not None:
-                if isinstance(result, litellm.ModelResponse):
+                if isinstance(result, OpenAIObject):
                     result = result.model_dump_json()
                 ## Get Cache-Controls ##
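The second hunk widens the serialization check on the cache-write path: instead of only `litellm.ModelResponse`, any OpenAI-style pydantic response object (e.g. embedding responses) gets dumped to JSON before it is stored. A rough sketch of that shape; `serialize_for_cache` is a hypothetical helper, not the library's API:

```python
from openai._models import BaseModel as OpenAIObject


def serialize_for_cache(result):
    # Any pydantic-based OpenAI response object is stored as a JSON string;
    # everything else passes through unchanged.
    if isinstance(result, OpenAIObject):
        return result.model_dump_json()
    return result
```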

Changed file: caching tests

@@ -91,7 +91,7 @@ def test_caching_with_cache_controls():
             model="gpt-3.5-turbo", messages=messages, cache={"ttl": 0}
         )
         response2 = completion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-max-age": 10}
+            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 10}
         )
         print(f"response1: {response1}")
         print(f"response2: {response2}")
@@ -105,7 +105,7 @@ def test_caching_with_cache_controls():
             model="gpt-3.5-turbo", messages=messages, cache={"ttl": 5}
         )
         response2 = completion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-max-age": 5}
+            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 5}
         )
         print(f"response1: {response1}")
         print(f"response2: {response2}")
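The test changes above only switch to the new spelling; the semantics stay the same: `s-maxage` bounds how old a cached entry may be before the lookup ignores it. A hedged sketch of that behaviour (the sleep, model, and prompt are illustrative, and the tests' actual assertions are not shown in these hunks):

```python
import time

import litellm
from litellm import completion
from litellm.caching import Cache

litellm.cache = Cache()
messages = [{"role": "user", "content": f"cache-control demo {time.time()}"}]

# Store a response in the cache, then let it age a little.
first = completion(model="gpt-3.5-turbo", messages=messages, cache={"ttl": 30})
time.sleep(2)

# Accept cached entries up to 10s old -> expected to be served from cache.
hit = completion(model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 10})

# Accept only entries up to 1s old -> the ~2s-old entry is expected to be skipped.
miss = completion(model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 1})
```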
@@ -167,6 +167,8 @@ small text
 def test_embedding_caching():
     import time

+    # litellm.set_verbose = True
     litellm.cache = Cache()
     text_to_embed = [embedding_large_text]
     start_time = time.time()
@@ -182,7 +184,7 @@ def test_embedding_caching():
         model="text-embedding-ada-002", input=text_to_embed, caching=True
     )
     end_time = time.time()
-    print(f"embedding2: {embedding2}")
+    # print(f"embedding2: {embedding2}")
     print(f"Embedding 2 response time: {end_time - start_time} seconds")
     litellm.cache = None