v0 of caching

This commit is contained in:
ishaan-jaff 2023-08-28 13:11:54 -07:00
parent 753ab03d91
commit 6fa2e578d4
4 changed files with 137 additions and 49 deletions

View file

@ -393,50 +393,6 @@ def client(original_function):
# [Non-Blocking Error]
pass
def get_prompt(*args, **kwargs):
# make this safe checks, it should not throw any exceptions
if len(args) > 1:
messages = args[1]
prompt = " ".join(message["content"] for message in messages)
return prompt
if "messages" in kwargs:
messages = kwargs["messages"]
prompt = " ".join(message["content"] for message in messages)
return prompt
return None
def check_cache(*args, **kwargs):
try: # never block execution
prompt = get_prompt(*args, **kwargs)
if prompt != None: # check if messages / prompt exists
if litellm.caching_with_models:
# if caching with model names is enabled, key is prompt + model name
if "model" in kwargs:
cache_key = prompt + kwargs["model"]
if cache_key in local_cache:
return local_cache[cache_key]
else: # caching only with prompts
if prompt in local_cache:
result = local_cache[prompt]
return result
else:
return None
return None # default to return None
except:
return None
def add_cache(result, *args, **kwargs):
try: # never block execution
prompt = get_prompt(*args, **kwargs)
if litellm.caching_with_models: # caching with model + prompt
if "model" in kwargs:
cache_key = prompt + kwargs["model"]
local_cache[cache_key] = result
else: # caching based only on prompts
local_cache[prompt] = result
except:
pass
def wrapper(*args, **kwargs):
start_time = None
result = None
@ -446,19 +402,20 @@ def client(original_function):
kwargs["litellm_call_id"] = litellm_call_id
start_time = datetime.datetime.now()
# [OPTIONAL] CHECK CACHE
if (litellm.caching or litellm.caching_with_models) and (
cached_result := check_cache(*args, **kwargs)
if (litellm.caching or litellm.caching_with_models or litellm.cache != None) and (
cached_result := litellm.cache.check_cache(*args, **kwargs)
) is not None:
result = cached_result
return result
# MODEL CALL
result = original_function(*args, **kwargs)
if "stream" in kwargs and kwargs["stream"] == True:
# TODO: Add to cache for streaming
return result
end_time = datetime.datetime.now()
# [OPTIONAL] ADD TO CACHE
if litellm.caching or litellm.caching_with_models:
add_cache(result, *args, **kwargs)
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
litellm.cache.add_cache(result, *args, **kwargs)
# LOG SUCCESS
my_thread = threading.Thread(
target=handle_success, args=(args, kwargs, result, start_time, end_time)
@ -1730,4 +1687,4 @@ def completion_with_fallbacks(**kwargs):
) # cool down this selected model
# print(f"rate_limited_models {rate_limited_models}")
pass
return response
return response