(feat) use async_cache for acompletion/aembedding

ishaan-jaff 2023-12-14 16:04:45 +05:30
parent a8e12661c2
commit 008df34ddc
3 changed files with 9 additions and 15 deletions


@@ -12,18 +12,6 @@ import time, logging
import json, traceback, ast
from typing import Optional
def get_prompt(*args, **kwargs):
# make this safe checks, it should not throw any exceptions
if len(args) > 1:
messages = args[1]
prompt = " ".join(message["content"] for message in messages)
return prompt
if "messages" in kwargs:
messages = kwargs["messages"]
prompt = " ".join(message["content"] for message in messages)
return prompt
return None
def print_verbose(print_statement):
try:
if litellm.set_verbose:
@@ -309,4 +297,9 @@ class Cache:
result = result.model_dump_json()
self.cache.set_cache(cache_key, result, **kwargs)
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
traceback.print_exc()
pass
async def _async_add_cache(self, result, *args, **kwargs):
self.add_cache(result, *args, **kwargs)
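
The new _async_add_cache coroutine simply delegates to the synchronous add_cache, so the actual write still runs on the event loop once the scheduled task executes. Below is a minimal sketch of that delegation, plus an illustrative alternative (not what this commit does) that offloads a blocking backend write to a worker thread via asyncio.to_thread; ToyCache and its dict backend are stand-ins, not litellm's implementation.

import asyncio

class ToyCache:
    # dict-backed stand-in for the Cache class in this file
    def __init__(self):
        self._store = {}

    def add_cache(self, result, cache_key=None, **kwargs):
        # synchronous write, mirroring add_cache above
        self._store[cache_key] = result

    async def _async_add_cache(self, result, *args, **kwargs):
        # this commit's approach: just call the sync method
        self.add_cache(result, *args, **kwargs)

    async def _async_add_cache_threaded(self, result, *args, **kwargs):
        # illustrative alternative: keep a blocking backend write off the event loop
        await asyncio.to_thread(self.add_cache, result, *args, **kwargs)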


@@ -29,6 +29,7 @@ def generate_random_word(length=4):
messages = [{"role": "user", "content": "who is ishaan 5222"}]
def test_caching_v2(): # test in memory cache
try:
litellm.set_verbose=True
litellm.cache = Cache()
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
@@ -40,7 +41,7 @@ def test_caching_v2(): # test in memory cache
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
print(f"response1: {response1}")
print(f"response2: {response2}")
pytest.fail(f"Error occurred: {e}")
pytest.fail(f"Error occurred:")
except Exception as e:
print(f"error occurred: {traceback.format_exc()}")
pytest.fail(f"Error occurred: {e}")


@@ -1682,9 +1682,9 @@ def client(original_function):
# [OPTIONAL] ADD TO CACHE
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
if isinstance(result, litellm.ModelResponse) or isinstance(result, litellm.EmbeddingResponse):
litellm.cache.add_cache(result.json(), *args, **kwargs)
asyncio.create_task(litellm.cache._async_add_cache(result.json(), *args, **kwargs))
else:
litellm.cache.add_cache(result, *args, **kwargs)
asyncio.create_task(litellm.cache._async_add_cache(result, *args, **kwargs))
# LOG SUCCESS - handle streaming success logging in the _next_ object
print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}")
asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time))
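
Previously the async wrapper called add_cache inline; now it wraps the write in asyncio.create_task, so acompletion/aembedding can return the response without waiting for the cache write to finish. A self-contained sketch of that fire-and-forget pattern with illustrative names (slow_cache_write stands in for Cache._async_add_cache):

import asyncio

async def slow_cache_write(key, value, store):
    # stand-in for a cache write that takes a while (e.g. a remote backend)
    await asyncio.sleep(0.2)
    store[key] = value

async def handler(store):
    result = {"choices": [{"message": {"content": "hi"}}]}  # pretend LLM response
    # fire-and-forget: schedule the write and return without awaiting it
    task = asyncio.create_task(slow_cache_write("prompt-key", result, store))
    return result, task

async def main():
    store = {}
    result, task = await handler(store)
    print("cached when handler returns?", "prompt-key" in store)  # False
    await task  # only so the demo exits cleanly; the real wrapper never awaits it
    print("cached after the task runs?", "prompt-key" in store)   # True

asyncio.run(main())

One caveat of fire-and-forget tasks is that nothing awaits them: exceptions surface only through the event loop's exception handler, and the asyncio docs recommend keeping a reference to the task so it cannot be garbage collected before it runs.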