fix(utils.py): support caching for embedding + log cache hits

n

n
This commit is contained in:
Krrish Dholakia 2023-12-13 18:37:30 -08:00
parent 72b9d4c5e8
commit 853508e8c0
5 changed files with 88 additions and 25 deletions

View file

@ -5,7 +5,7 @@ from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath('../..'))
from typing import Optional, Literal, List, Union
from litellm import completion, embedding
from litellm import completion, embedding, Cache
import litellm
from litellm.integrations.custom_logger import CustomLogger
@ -14,6 +14,7 @@ from litellm.integrations.custom_logger import CustomLogger
## 2: Post-API-Call
## 3: On LiteLLM Call success
## 4: On LiteLLM Call failure
## 5. Caching
# Test models
## 1. OpenAI
@ -32,7 +33,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
def __init__(self):
self.errors = []
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
def log_pre_api_call(self, model, messages, kwargs):
try:
self.states.append("sync_pre_api_call")
@ -126,6 +127,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert isinstance(kwargs["cache_hit"], Optional[bool])
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
@ -197,7 +199,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert isinstance(kwargs["cache_hit"], Optional[bool])
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
@ -577,4 +579,47 @@ async def test_async_embedding_bedrock():
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_embedding_bedrock())
# asyncio.run(test_async_embedding_bedrock())
# CACHING
## Test Azure - completion, embedding
@pytest.mark.asyncio
async def test_async_completion_azure_caching():
customHandler_caching = CompletionCustomHandler()
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
litellm.callbacks = [customHandler_caching]
unique_time = time.time()
response1 = await litellm.acompletion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": f"Hi 👋 - i'm async azure {unique_time}"
}],
caching=True)
await asyncio.sleep(1)
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
response2 = await litellm.acompletion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": f"Hi 👋 - i'm async azure {unique_time}"
}],
caching=True)
await asyncio.sleep(1) # success callbacks are done in parallel
print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}")
assert len(customHandler_caching.errors) == 0
assert len(customHandler_caching.states) == 4 # pre, post, success, success
@pytest.mark.asyncio
async def test_async_embedding_azure_caching():
customHandler_caching = CompletionCustomHandler()
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
litellm.callbacks = [customHandler_caching]
unique_time = time.time()
response1 = await litellm.aembedding(model="azure/azure-embedding-model",
input=[f"good morning from litellm1 {unique_time}"],
caching=True)
response2 = await litellm.aembedding(model="azure/azure-embedding-model",
input=[f"good morning from litellm1 {unique_time}"],
caching=True)
await asyncio.sleep(1) # success callbacks are done in parallel
assert len(customHandler_caching.errors) == 0
assert len(customHandler_caching.states) == 4 # pre, post, success, success