mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix redis caching
parent 3726270d95
commit 8f7f9ca932
3 changed files with 37 additions and 4 deletions
```diff
@@ -1,5 +1,7 @@
 import litellm
 import time
+import json
+
 def get_prompt(*args, **kwargs):
     # make this safe checks, it should not throw any exceptions
     if len(args) > 1:
```
```diff
@@ -22,8 +24,14 @@ class RedisCache():
         self.redis_client.set(key, str(value))

     def get_cache(self, key):
         # TODO convert this to a ModelResponse object
-        return self.redis_client.get(key)
+        cached_response = self.redis_client.get(key)
+        if cached_response!=None:
+            # cached_response is in `b{} convert it to ModelResponse
+            cached_response = cached_response.decode("utf-8")  # Convert bytes to string
+            cached_response = json.loads(cached_response)  # Convert string to dictionary
+            return cached_response


 class InMemoryCache():
     def __init__(self):
```
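For context on the decode/`json.loads` step above: redis-py's `get()` returns raw `bytes` unless the client is created with `decode_responses=True`, so the cached value has to be decoded and parsed before it can be used as a dictionary. A minimal standalone sketch of the same round trip (plain `redis` client, placeholder key and value; this is not litellm's actual RedisCache wrapper):

```python
import json
import redis

# Hypothetical standalone example with a local Redis instance.
client = redis.Redis(host="localhost", port=6379)

# Store a dict as a JSON string so it can be parsed back losslessly.
client.set("example-key", json.dumps({"role": "assistant", "content": "hi"}))

raw = client.get("example-key")        # bytes, e.g. b'{"role": ...}', or None on a miss
if raw is not None:
    decoded = raw.decode("utf-8")      # bytes -> str
    cached = json.loads(decoded)       # str -> dict
    print(cached["content"])
```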
```diff
@@ -46,7 +54,7 @@ class InMemoryCache():
 class Cache():
     def __init__(self, type="local", host="", port="", password=""):
         if type == "redis":
-            self.cache = RedisCache(type, host, port, password)
+            self.cache = RedisCache(host, port, password)
         if type == "local":
             self.cache = InMemoryCache()
         if "cache" not in litellm.input_callback:
```
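The old call site also forwarded `type` into RedisCache as an extra positional argument. Assuming RedisCache only accepts `host`, `port`, and `password` (an assumption, since its `__init__` is not shown in this diff), the old call could not even construct the client; a rough sketch of that failure mode:

```python
# Hypothetical illustration of the call-site bug fixed above; the
# three-parameter signature is assumed, not copied from litellm.
class RedisCache:
    def __init__(self, host, port, password):
        self.host, self.port, self.password = host, port, password

# Old call site forwarded `type` as a fourth positional argument:
try:
    RedisCache("redis", "my-host", "6379", "secret")
except TypeError as err:
    print(err)  # __init__() takes 4 positional arguments but 5 were given

# New call site lines up with the parameters:
cache = RedisCache("my-host", "6379", "secret")
print(cache.host, cache.port)
```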
```diff
@@ -229,3 +229,28 @@ def test_caching_v2_stream():
 # test_caching_v2_stream()


+
+def test_redis_cache_completion():
+    messages = [{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}]
+    litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
+    print("test2 for caching")
+    response1 = completion(model="gpt-3.5-turbo", messages=messages)
+    response2 = completion(model="gpt-3.5-turbo", messages=messages)
+    response3 = completion(model="command-nightly", messages=messages)
+    print(f"response1: {response1}")
+    print(f"response2: {response2}")
+    print(f"response3: {response3}")
+    litellm.cache = None
+    if response3 == response2:
+        # if models are different, it should not return cached response
+        print(f"response2: {response2}")
+        print(f"response3: {response3}")
+        pytest.fail(f"Error occurred:")
+    if response1 != response2: # 1 and 2 should be the same
+        print(f"response1: {response1}")
+        print(f"response2: {response2}")
+        pytest.fail(f"Error occurred:")
+
+# test_redis_cache_completion()
```
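Running the new test outside CI requires the three Redis environment variables it reads to be set, plus provider API keys for the `completion` calls. A rough sketch of the same wiring as a standalone script; the `Cache` import path and the placeholder connection values are assumptions, since the diff view does not show file names:

```python
import os
import litellm
from litellm import completion
from litellm.caching import Cache  # import path assumed, not shown in this diff

# Placeholder connection details; point these at a reachable Redis instance.
os.environ.setdefault("REDIS_HOST", "localhost")
os.environ.setdefault("REDIS_PORT", "6379")
os.environ.setdefault("REDIS_PASSWORD", "")

litellm.cache = Cache(
    type="redis",
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
)

messages = [{"role": "user", "content": "hello"}]
r1 = completion(model="gpt-3.5-turbo", messages=messages)  # populates the cache
r2 = completion(model="gpt-3.5-turbo", messages=messages)  # expected to be served from Redis

litellm.cache = None  # reset so later calls are not cached
```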
```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.498"
+version = "0.1.499"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
```