Krrish Dholakia 2024-01-11 16:30:05 +05:30
parent bd5a14daf6
commit df9df7b040
3 changed files with 58 additions and 111 deletions

View file

@@ -81,7 +81,7 @@ class RedisCache(BaseCache):
def set_cache(self, key, value, **kwargs):
ttl = kwargs.get("ttl", None)
- print_verbose(f"Set Redis Cache: key: {key}\nValue {value}")
+ print_verbose(f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
try:
self.redis_client.set(name=key, value=str(value), ex=ttl)
except Exception as e:
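For context on the hunk above: redis-py's `set` accepts an `ex` argument that expires the key after that many seconds, so forwarding `ttl` (and now logging it) is all that is needed for per-key expiry. A minimal standalone sketch, assuming a local Redis instance (host/port are assumptions, not litellm configuration):

```python
import redis

client = redis.Redis(host="localhost", port=6379)
# Mirrors the call above: stringified value, expiry via ex=ttl.
client.set(name="cache-key", value='{"hello": "world"}', ex=60)
print(client.ttl("cache-key"))  # remaining time-to-live in seconds, e.g. 60
```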
@@ -171,7 +171,7 @@ class S3Cache(BaseCache):
CacheControl=cache_control,
ContentType="application/json",
ContentLanguage="en",
- ContentDisposition=f"inline; filename=\"{key}.json\""
+ ContentDisposition=f'inline; filename="{key}.json"',
)
else:
cache_control = "immutable, max-age=31536000, s-maxage=31536000"
@@ -183,7 +183,7 @@ class S3Cache(BaseCache):
CacheControl=cache_control,
ContentType="application/json",
ContentLanguage="en",
- ContentDisposition=f"inline; filename=\"{key}.json\""
+ ContentDisposition=f'inline; filename="{key}.json"',
)
except Exception as e:
# NON blocking - notify users S3 is throwing an exception
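The two S3 hunks above only reformat the `ContentDisposition` argument (single-quoted f-string plus a trailing comma); the header value sent with `put_object` is unchanged. A tiny illustration of the string it produces, with a made-up key:

```python
key = "cache-abc123"  # hypothetical cache key
content_disposition = f'inline; filename="{key}.json"'
print(content_disposition)  # inline; filename="cache-abc123.json"
```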
@@ -495,7 +495,6 @@ class Cache:
cached_result is not None
and isinstance(cached_result, dict)
and "timestamp" in cached_result
- and max_age is not None
):
timestamp = cached_result["timestamp"]
current_time = time.time()
@@ -504,7 +503,7 @@ class Cache:
response_age = current_time - timestamp
# Check if the cached response is older than the max-age
- if response_age > max_age:
+ if max_age is not None and response_age > max_age:
print_verbose(
f"Cached response for key {cache_key} is too old. Max-age: {max_age}s, Age: {response_age}s"
)
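The change above moves the `max_age is not None` guard from the outer condition into the age comparison: the timestamp branch now runs even when no max-age was requested, and the staleness rejection only applies when `max_age` is set. A simplified sketch of the resulting control flow (not litellm's exact code):

```python
import time
from typing import Optional


def cached_entry_is_fresh(cached_result, max_age: Optional[float]) -> bool:
    # Entries without a timestamp are served as-is.
    if not isinstance(cached_result, dict) or "timestamp" not in cached_result:
        return True
    response_age = time.time() - cached_result["timestamp"]
    # Only reject on age when a max-age was actually requested.
    if max_age is not None and response_age > max_age:
        return False
    return True


entry = {"timestamp": time.time() - 120, "response": "..."}
print(cached_entry_is_fresh(entry, max_age=60))    # False: older than 60s
print(cached_entry_is_fresh(entry, max_age=None))  # True: no max-age requested
```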
@@ -565,6 +564,9 @@ class Cache:
async def _async_add_cache(self, result, *args, **kwargs):
self.add_cache(result, *args, **kwargs)
+ async def _async_get_cache(self, *args, **kwargs):
+ return self.get_cache(*args, **kwargs)
def enable_cache(
type: Optional[Literal["local", "redis", "s3"]] = "local",
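`_async_get_cache` mirrors the existing `_async_add_cache`: both are thin async wrappers that delegate to the synchronous cache methods so the async call path has something it can await. A minimal sketch of the pattern with a stand-in class (hypothetical, not litellm's implementation):

```python
import asyncio


class SketchCache:
    def __init__(self):
        self._store = {}

    def add_cache(self, key, value):
        self._store[key] = value

    def get_cache(self, key):
        return self._store.get(key)

    # Thin async wrappers, mirroring the methods in the hunk above.
    async def _async_add_cache(self, key, value):
        self.add_cache(key, value)

    async def _async_get_cache(self, key):
        return self.get_cache(key)


async def main():
    cache = SketchCache()
    await cache._async_add_cache("k", "v")
    print(await cache._async_get_cache("k"))  # v


asyncio.run(main())
```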

View file

@@ -267,10 +267,10 @@ async def acompletion(
elif asyncio.iscoroutine(init_response):
response = await init_response
else:
- response = init_response
+ response = init_response # type: ignore
else:
# Call the synchronous function using run_in_executor
- response = await loop.run_in_executor(None, func_with_context)
+ response = await loop.run_in_executor(None, func_with_context) # type: ignore
# if kwargs.get("stream", False): # return an async generator
# return _async_streaming(
# response=response,
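For readers skimming the hunk: when the selected provider hands back a coroutine, `acompletion` awaits it directly; otherwise the synchronous call is pushed onto the default thread-pool executor so it does not block the event loop, and the new `# type: ignore` comments only quiet the type checker about the resulting `Any`. A self-contained sketch of that branching, with hypothetical stand-in providers:

```python
import asyncio
from functools import partial


def blocking_completion(prompt: str) -> str:
    # Stand-in for a synchronous provider call.
    return f"echo: {prompt}"


async def native_async_completion(prompt: str) -> str:
    # Stand-in for a provider with its own async client.
    return f"echo: {prompt}"


async def acompletion_sketch(prompt: str, provider_is_async: bool) -> str:
    loop = asyncio.get_running_loop()
    if provider_is_async:
        init_response = native_async_completion(prompt)
        if asyncio.iscoroutine(init_response):
            return await init_response
        return init_response  # type: ignore  # provider returned a plain value
    # Run the blocking call in the default thread-pool executor.
    func_with_context = partial(blocking_completion, prompt)
    return await loop.run_in_executor(None, func_with_context)


print(asyncio.run(acompletion_sketch("hi", provider_is_async=False)))
```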

View file

@@ -11,10 +11,10 @@ sys.path.insert(
) # Adds the parent directory to the system path
import pytest
import litellm
- from litellm import embedding, completion
+ from litellm import embedding, completion, aembedding
from litellm.caching import Cache
import random
- import hashlib
+ import hashlib, asyncio
# litellm.set_verbose=True
@@ -261,6 +261,51 @@ def test_embedding_caching_azure():
# test_embedding_caching_azure()
+ @pytest.mark.asyncio
+ async def test_embedding_caching_azure_individual_items():
+ """
+ Tests caching for individual items in an embedding list
+ Assert if the same embeddingresponse object is returned for the duplicate item in 2 embedding list calls
+ ```
+ embedding_1 = ["hey how's it going", "I'm doing well"]
+ embedding_val_1 = embedding(...)
+ embedding_2 = ["hey how's it going", "I'm fine"]
+ embedding_val_2 = embedding(...)
+ assert embedding_val_1[0]["id"] == embedding_val_2[0]["id"]
+ ```
+ """
+ litellm.cache = Cache()
+ common_msg = f"hey how's it going {uuid.uuid4()}"
+ embedding_1 = [common_msg, "I'm doing well"]
+ embedding_2 = [common_msg, "I'm fine"]
+ embedding_val_1 = await aembedding(
+ model="azure/azure-embedding-model", input=embedding_1, caching=True
+ )
+ embedding_val_2 = await aembedding(
+ model="azure/azure-embedding-model", input=embedding_2, caching=True
+ )
+ if (
+ embedding_val_2["data"][0]["embedding"]
+ != embedding_val_1["data"][0]["embedding"]
+ ):
+ print(f"embedding1: {embedding_val_1}")
+ print(f"embedding2: {embedding_val_2}")
+ pytest.fail("Error occurred: Embedding caching failed")
+ if (
+ embedding_val_2["data"][1]["embedding"]
+ == embedding_val_1["data"][1]["embedding"]
+ ):
+ print(f"embedding1: {embedding_val_1}")
+ print(f"embedding2: {embedding_val_2}")
+ pytest.fail("Error occurred: Embedding caching failed")
def test_redis_cache_completion():
litellm.set_verbose = False
@@ -401,14 +446,14 @@ def test_redis_cache_completion_stream():
"""
- test_redis_cache_completion_stream()
+ # test_redis_cache_completion_stream()
def test_redis_cache_acompletion_stream():
import asyncio
try:
- litellm.set_verbose = True
+ litellm.set_verbose = False
random_word = generate_random_word()
messages = [
{
@@ -436,7 +481,6 @@ def test_redis_cache_acompletion_stream():
stream=True,
)
async for chunk in response1:
- print(chunk)
response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content)
@@ -454,7 +498,6 @@ def test_redis_cache_acompletion_stream():
stream=True,
)
async for chunk in response2:
- print(chunk)
response_2_content += chunk.choices[0].delta.content or ""
print(response_2_content)
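The two streaming hunks above only drop the per-chunk `print(chunk)` noise; the accumulation itself is unchanged. For reference, the consumption pattern the test relies on, written as a small helper (hypothetical, not part of the test file):

```python
async def collect_stream_content(response) -> str:
    # Concatenate streamed deltas, treating a missing delta as "".
    content = ""
    async for chunk in response:
        content += chunk.choices[0].delta.content or ""
    return content
```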
@@ -916,101 +959,3 @@ def test_cache_context_managers():
# test_cache_context_managers()
# test_custom_redis_cache_params()
- # def test_redis_cache_with_ttl():
- # cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
- # sample_model_response_object_str = """{
- # "choices": [
- # {
- # "finish_reason": "stop",
- # "index": 0,
- # "message": {
- # "role": "assistant",
- # "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
- # }
- # }
- # ],
- # "created": 1691429984.3852863,
- # "model": "claude-instant-1",
- # "usage": {
- # "prompt_tokens": 18,
- # "completion_tokens": 23,
- # "total_tokens": 41
- # }
- # }"""
- # sample_model_response_object = {
- # "choices": [
- # {
- # "finish_reason": "stop",
- # "index": 0,
- # "message": {
- # "role": "assistant",
- # "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
- # }
- # }
- # ],
- # "created": 1691429984.3852863,
- # "model": "claude-instant-1",
- # "usage": {
- # "prompt_tokens": 18,
- # "completion_tokens": 23,
- # "total_tokens": 41
- # }
- # }
- # cache.add_cache(cache_key="test_key", result=sample_model_response_object_str, ttl=1)
- # cached_value = cache.get_cache(cache_key="test_key")
- # print(f"cached-value: {cached_value}")
- # assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content']
- # time.sleep(2)
- # assert cache.get_cache(cache_key="test_key") is None
- # # test_redis_cache_with_ttl()
- # def test_in_memory_cache_with_ttl():
- # cache = Cache(type="local")
- # sample_model_response_object_str = """{
- # "choices": [
- # {
- # "finish_reason": "stop",
- # "index": 0,
- # "message": {
- # "role": "assistant",
- # "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
- # }
- # }
- # ],
- # "created": 1691429984.3852863,
- # "model": "claude-instant-1",
- # "usage": {
- # "prompt_tokens": 18,
- # "completion_tokens": 23,
- # "total_tokens": 41
- # }
- # }"""
- # sample_model_response_object = {
- # "choices": [
- # {
- # "finish_reason": "stop",
- # "index": 0,
- # "message": {
- # "role": "assistant",
- # "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
- # }
- # }
- # ],
- # "created": 1691429984.3852863,
- # "model": "claude-instant-1",
- # "usage": {
- # "prompt_tokens": 18,
- # "completion_tokens": 23,
- # "total_tokens": 41
- # }
- # }
- # cache.add_cache(cache_key="test_key", result=sample_model_response_object_str, ttl=1)
- # cached_value = cache.get_cache(cache_key="test_key")
- # assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content']
- # time.sleep(2)
- # assert cache.get_cache(cache_key="test_key") is None
- # # test_in_memory_cache_with_ttl()
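The block removed above was a long commented-out draft of Redis and in-memory TTL tests. A compact version in the spirit of those drafts could look like this; it only covers the local cache, and the expiry behaviour it asserts is taken from the drafts rather than verified here:

```python
import time

from litellm.caching import Cache


def test_in_memory_cache_with_ttl_sketch():
    cache = Cache(type="local")
    cache.add_cache(cache_key="test_key", result='{"hello": "world"}', ttl=1)
    assert cache.get_cache(cache_key="test_key") is not None
    time.sleep(2)  # let the 1-second TTL lapse
    assert cache.get_cache(cache_key="test_key") is None
```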