fix(caching.py): fixing pr issues

Krrish Dholakia 2023-10-31 18:32:31 -07:00
parent c0e6596395
commit 6ead8d8c18
2 changed files with 99 additions and 28 deletions


@@ -9,7 +9,7 @@
import litellm
import time
import json
import json, traceback
def get_prompt(*args, **kwargs):
@@ -37,6 +37,7 @@ class RedisCache(BaseCache):
def __init__(self, host, port, password):
import redis
# if users don't provide one, use the default litellm cache
print(f"HOST: {host}; PORT: {port}; PASSWORD: {password}")
self.redis_client = redis.Redis(host=host, port=port, password=password)
def set_cache(self, key, value, **kwargs):
@@ -45,7 +46,7 @@ class RedisCache(BaseCache):
self.redis_client.set(name=key, value=str(value), ex=ttl)
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
print("LiteLLM Caching: Got exception from REDIS: ", e)
print("LiteLLM Caching: set() - Got exception from REDIS : ", e)
def get_cache(self, key, **kwargs):
try:
@@ -59,7 +60,8 @@ class RedisCache(BaseCache):
return cached_response
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
print("LiteLLM Caching: Got exception from REDIS: ", e)
traceback.print_exc()
print("LiteLLM Caching: get() - Got exception from REDIS: ", e)
class HostedCache(BaseCache):
@@ -104,7 +106,13 @@ class InMemoryCache(BaseCache):
if time.time() > self.ttl_dict[key]:
self.cache_dict.pop(key, None)
return None
return self.cache_dict[key]
original_cached_response = self.cache_dict[key]
try:
cached_response = json.loads(original_cached_response)
except:
cached_response = original_cached_response
cached_response['cache'] = True # set cache-hit flag to True
return cached_response
return None
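
For readers scanning the hunk above: the new in-memory get path tries to JSON-decode the stored value and, when that yields a dict, tags it as a cache hit. A minimal standalone sketch of the same pattern (hypothetical helper name, not litellm API; it adds a dict guard that the hunk itself leaves implicit):

import json

def decode_cached_value(raw):
    # Try to turn the stored string back into a dict; fall back to the raw value.
    try:
        decoded = json.loads(raw)
    except (TypeError, ValueError):
        return raw
    if isinstance(decoded, dict):
        decoded["cache"] = True  # mark as served from cache
    return decoded

print(decode_cached_value('{"model": "claude-instant-1"}'))  # {'model': 'claude-instant-1', 'cache': True}
print(decode_cached_value("not json"))                       # returned unchanged, no flag added
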
@@ -196,7 +204,8 @@ class Cache:
# print(cached_result)
return self.generate_streaming_content(cached_result["choices"][0]['message']['content'])
return cached_result
except:
except Exception as e:
print(f"An exception occurred: {traceback.format_exc()}")
return None
def add_cache(self, result, *args, **kwargs):
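
As a rough end-to-end illustration of the code paths this file touches (a sketch, not part of the diff; it assumes an OPENAI_API_KEY in the environment and uses the caching=True flag exercised by the tests below):

import litellm
from litellm import completion
from litellm.caching import Cache

litellm.cache = Cache(type="local")  # backed by the InMemoryCache class above
messages = [{"role": "user", "content": "who is ishaan Github? "}]

response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)

# After this commit, a dict-shaped cached response may carry the new cache-hit marker.
if isinstance(response2, dict) and response2.get("cache") is True:
    print("response2 was served from cache")
litellm.cache = None  # reset global state, as the tests below do
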


@@ -13,7 +13,7 @@ import pytest
import litellm
from litellm import embedding, completion
from litellm.caching import Cache
litellm.set_verbose=True
# litellm.set_verbose=True
messages = [{"role": "user", "content": "who is ishaan Github? "}]
# comment
@@ -271,30 +271,13 @@ def test_embedding_caching_azure():
def test_redis_cache_completion():
litellm.set_verbose = True
messages = [{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}]
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
print("test2 for caching")
# patch this redis test
local_cache = {}
def set_cache(key, value):
local_cache[key] = value
def get_cache(key):
if key in local_cache:
return local_cache[key]
litellm.cache.cache.set_cache = set_cache
litellm.cache.cache.get_cache = get_cache
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response3 = completion(model="command-nightly", messages=messages, caching=True)
print(f"response1: {response1}")
print(f"response2: {response2}")
print(f"response3: {response3}")
litellm.cache = None
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
# if models are different, it should not return cached response
@@ -368,17 +351,96 @@ def test_hosted_cache():
def test_redis_cache_with_ttl():
cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
cache.add_cache(cache_key="test_key", result="test_value", ttl=1)
sample_model_response_object_str = """{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"role": "assistant",
"content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
}
}
],
"created": 1691429984.3852863,
"model": "claude-instant-1",
"usage": {
"prompt_tokens": 18,
"completion_tokens": 23,
"total_tokens": 41
}
}"""
sample_model_response_object = {
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"role": "assistant",
"content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
}
}
],
"created": 1691429984.3852863,
"model": "claude-instant-1",
"usage": {
"prompt_tokens": 18,
"completion_tokens": 23,
"total_tokens": 41
}
}
cache.add_cache(cache_key="test_key", result=sample_model_response_object_str, ttl=1)
cached_value = cache.get_cache(cache_key="test_key")
assert cached_value == "test_value"
print(f"cached-value: {cached_value}")
assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content']
time.sleep(2)
assert cache.get_cache(cache_key="test_key") is None
# test_redis_cache_with_ttl()
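
The 1-second expiry asserted above rides on the ex=ttl argument that RedisCache.set_cache passes to redis-py. A standalone sketch of that behaviour (assumes a reachable Redis instance configured via the same REDIS_* environment variables as the tests; the key and value names are made up):

import os, time
import redis

client = redis.Redis(
    host=os.environ["REDIS_HOST"],
    port=int(os.environ["REDIS_PORT"]),
    password=os.environ["REDIS_PASSWORD"],
)
client.set(name="ttl_demo_key", value="ttl_demo_value", ex=1)  # expire after 1 second
print(client.get("ttl_demo_key"))  # b"ttl_demo_value"
time.sleep(2)
print(client.get("ttl_demo_key"))  # None once the TTL has elapsed
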
def test_in_memory_cache_with_ttl():
cache = Cache(type="local")
cache.add_cache(cache_key="test_key", result="test_value", ttl=1)
sample_model_response_object_str = """{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"role": "assistant",
"content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
}
}
],
"created": 1691429984.3852863,
"model": "claude-instant-1",
"usage": {
"prompt_tokens": 18,
"completion_tokens": 23,
"total_tokens": 41
}
}"""
sample_model_response_object = {
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"role": "assistant",
"content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
}
}
],
"created": 1691429984.3852863,
"model": "claude-instant-1",
"usage": {
"prompt_tokens": 18,
"completion_tokens": 23,
"total_tokens": 41
}
}
cache.add_cache(cache_key="test_key", result=sample_model_response_object_str, ttl=1)
cached_value = cache.get_cache(cache_key="test_key")
assert cached_value == "test_value"
assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content']
time.sleep(2)
assert cache.get_cache(cache_key="test_key") is None
# test_in_memory_cache_with_ttl()
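
For reference, the in-memory expiry this last test exercises boils down to a per-key deadline map, as in the InMemoryCache hunk earlier. A minimal standalone sketch of that pattern (hypothetical class name, not the litellm implementation):

import time

class TTLStore:
    def __init__(self):
        self.cache_dict = {}
        self.ttl_dict = {}

    def set_cache(self, key, value, ttl=None):
        self.cache_dict[key] = value
        if ttl is not None:
            self.ttl_dict[key] = time.time() + ttl  # absolute expiry timestamp

    def get_cache(self, key):
        # Evict and report a miss if the key has a deadline that has passed.
        if key in self.ttl_dict and time.time() > self.ttl_dict[key]:
            self.cache_dict.pop(key, None)
            self.ttl_dict.pop(key, None)
            return None
        return self.cache_dict.get(key)

store = TTLStore()
store.set_cache("k", "v", ttl=1)
print(store.get_cache("k"))  # "v"
time.sleep(2)
print(store.get_cache("k"))  # None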