Mirror of https://github.com/BerriAI/litellm.git
Synced 2025-04-27 03:34:10 +00:00
fix(utils.py): support caching for embedding + log cache hits
This commit is contained in:
parent 72b9d4c5e8
commit 853508e8c0

5 changed files with 88 additions and 25 deletions
@@ -5,7 +5,7 @@ from datetime import datetime
 import pytest
 sys.path.insert(0, os.path.abspath('../..'))
 from typing import Optional, Literal, List, Union
-from litellm import completion, embedding
+from litellm import completion, embedding, Cache
 import litellm
 from litellm.integrations.custom_logger import CustomLogger
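The only import change here is Cache, the class that backs litellm's response cache. A minimal sketch of the wiring the new tests below rely on (it assumes the same REDIS_HOST / REDIS_PORT / REDIS_PASSWORD environment variables the tests read; Cache() with no arguments falls back to an in-process cache):

    # Sketch: enable litellm's global response cache, as the new tests do.
    import os
    import litellm
    from litellm import Cache

    litellm.cache = Cache(
        type="redis",
        host=os.environ["REDIS_HOST"],          # assumed to be set, as in the tests
        port=os.environ["REDIS_PORT"],
        password=os.environ["REDIS_PASSWORD"],
    )

Once litellm.cache is set, completion/embedding calls made with caching=True are checked against the cache before any provider API call goes out.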
@@ -14,6 +14,7 @@ from litellm.integrations.custom_logger import CustomLogger
 ## 2: Post-API-Call
 ## 3: On LiteLLM Call success
 ## 4: On LiteLLM Call failure
+## 5. Caching
 
 # Test models
 ## 1. OpenAI
@@ -32,7 +33,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
     def __init__(self):
         self.errors = []
         self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
 
     def log_pre_api_call(self, model, messages, kwargs):
         try:
             self.states.append("sync_pre_api_call")
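self.states is an ordered log of every callback that fires, so a test can assert on the exact callback sequence rather than just on counts. A hypothetical example of the assertion style this enables (handler name and sequence are illustrative):

    # Hypothetical: one uncached async completion should leave exactly this trail.
    assert handler.states == ["async_pre_api_call", "post_api_call", "async_success"]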
@@ -126,6 +127,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
             assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
             assert isinstance(kwargs['additional_args'], (dict, type(None)))
             assert isinstance(kwargs['log_event_type'], str)
+            assert isinstance(kwargs["cache_hit"], Optional[bool])
         except:
             print(f"Assertion Error: {traceback.format_exc()}")
             self.errors.append(traceback.format_exc())
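A portability note on the new assertion (the same check is added to the async handler in the next hunk): calling isinstance with a subscripted typing.Optional raises TypeError on older Python versions, so an equivalent check that works everywhere is:

    # Same runtime check without a subscripted generic:
    assert isinstance(kwargs["cache_hit"], (bool, type(None)))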
@@ -197,7 +199,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
             assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
             assert isinstance(kwargs['additional_args'], (dict, type(None)))
             assert isinstance(kwargs['log_event_type'], str)
-
+            assert isinstance(kwargs["cache_hit"], Optional[bool])
         except:
             print(f"Assertion Error: {traceback.format_exc()}")
             self.errors.append(traceback.format_exc())
@@ -577,4 +579,47 @@ async def test_async_embedding_bedrock():
     except Exception as e:
         pytest.fail(f"An exception occurred: {str(e)}")
 
 # asyncio.run(test_async_embedding_bedrock())
+
+# CACHING
+## Test Azure - completion, embedding
+@pytest.mark.asyncio
+async def test_async_completion_azure_caching():
+    customHandler_caching = CompletionCustomHandler()
+    litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
+    litellm.callbacks = [customHandler_caching]
+    unique_time = time.time()
+    response1 = await litellm.acompletion(model="azure/chatgpt-v-2",
+                                          messages=[{
+                                              "role": "user",
+                                              "content": f"Hi 👋 - i'm async azure {unique_time}"
+                                          }],
+                                          caching=True)
+    await asyncio.sleep(1)
+    print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
+    response2 = await litellm.acompletion(model="azure/chatgpt-v-2",
+                                          messages=[{
+                                              "role": "user",
+                                              "content": f"Hi 👋 - i'm async azure {unique_time}"
+                                          }],
+                                          caching=True)
+    await asyncio.sleep(1)  # success callbacks are done in parallel
+    print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}")
+    assert len(customHandler_caching.errors) == 0
+    assert len(customHandler_caching.states) == 4  # pre, post, success, success
+
+@pytest.mark.asyncio
+async def test_async_embedding_azure_caching():
+    customHandler_caching = CompletionCustomHandler()
+    litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
+    litellm.callbacks = [customHandler_caching]
+    unique_time = time.time()
+    response1 = await litellm.aembedding(model="azure/azure-embedding-model",
+                                         input=[f"good morning from litellm1 {unique_time}"],
+                                         caching=True)
+    response2 = await litellm.aembedding(model="azure/azure-embedding-model",
+                                         input=[f"good morning from litellm1 {unique_time}"],
+                                         caching=True)
+    await asyncio.sleep(1)  # success callbacks are done in parallel
+    assert len(customHandler_caching.errors) == 0
+    assert len(customHandler_caching.states) == 4  # pre, post, success, success
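The len(...) == 4 assertions capture what this commit is actually testing: the first request walks the full pre-api-call, post-api-call, success path, while the second identical request is answered from the cache and fires only a success callback, with kwargs["cache_hit"] populated for the logger. The asyncio.sleep(1) calls are needed because success callbacks run in parallel with the caller, so the test must yield before inspecting the state log. A minimal sketch of a logger built on the same hook (CacheHitLogger is illustrative; the method signature follows litellm's CustomLogger interface):

    import litellm
    from litellm.integrations.custom_logger import CustomLogger

    class CacheHitLogger(CustomLogger):
        # Fires for fresh and cache-served async calls alike.
        async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
            if kwargs.get("cache_hit"):
                print(f"cache hit for model={kwargs.get('model')}")
            else:
                print(f"fresh response from model={kwargs.get('model')}")

    litellm.callbacks = [CacheHitLogger()]

Since a cache hit never reaches the provider, skipping the pre/post API-call callbacks on the second request is the expected shape, and cache_hit is what lets a logger tell the two success events apart.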