fix redis caching

ishaan-jaff 2023-08-28 22:10:15 -07:00
parent 3726270d95
commit 8f7f9ca932
3 changed files with 37 additions and 4 deletions


@@ -1,5 +1,7 @@
import litellm
import time
import json
def get_prompt(*args, **kwargs):
    # make this check safe; it should not throw any exceptions
if len(args) > 1:
@@ -23,7 +25,13 @@ class RedisCache():
def get_cache(self, key):
# TODO convert this to a ModelResponse object
return self.redis_client.get(key)
cached_response = self.redis_client.get(key)
        if cached_response is not None:
            # cached_response is bytes (b'{...}'); decode and parse it back into a dict
cached_response = cached_response.decode("utf-8") # Convert bytes to string
cached_response = json.loads(cached_response) # Convert string to dictionary
return cached_response
class InMemoryCache():
def __init__(self):
@@ -46,7 +54,7 @@ class InMemoryCache():
class Cache():
def __init__(self, type="local", host="", port="", password=""):
if type == "redis":
self.cache = RedisCache(type, host, port, password)
self.cache = RedisCache(host, port, password)
if type == "local":
self.cache = InMemoryCache()
if "cache" not in litellm.input_callback:


@@ -229,3 +229,28 @@ def test_caching_v2_stream():
# test_caching_v2_stream()
def test_redis_cache_completion():
messages = [{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}]
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
print("test2 for caching")
response1 = completion(model="gpt-3.5-turbo", messages=messages)
response2 = completion(model="gpt-3.5-turbo", messages=messages)
response3 = completion(model="command-nightly", messages=messages)
print(f"response1: {response1}")
print(f"response2: {response2}")
print(f"response3: {response3}")
litellm.cache = None
if response3 == response2:
# if models are different, it should not return cached response
print(f"response2: {response2}")
print(f"response3: {response3}")
pytest.fail(f"Error occurred:")
if response1 != response2: # 1 and 2 should be the same
print(f"response1: {response1}")
print(f"response2: {response2}")
pytest.fail(f"Error occurred:")
# test_redis_cache_completion()
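
The same behaviour can be exercised without a redis server by pointing litellm at the in-memory backend that the Cache class also supports. A minimal sketch, assuming Cache is importable from litellm.caching and that valid API keys are set (the prompt text is a placeholder):

import litellm
from litellm import completion
from litellm.caching import Cache  # assumption: Cache lives in litellm.caching

litellm.cache = Cache(type="local")  # in-memory cache, no redis needed
messages = [{"role": "user", "content": "hello, is this response cached?"}]
response1 = completion(model="gpt-3.5-turbo", messages=messages)
response2 = completion(model="gpt-3.5-turbo", messages=messages)  # expected cache hit
assert response1 == response2
litellm.cache = None  # reset so later calls are unaffected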


@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "0.1.498"
version = "0.1.499"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"