forked from phoenix/litellm-mirror
v0 of caching
This commit is contained in: parent 8a76c80039 · commit 8b43917792
4 changed files with 137 additions and 49 deletions
@@ -393,50 +393,6 @@ def client(original_function):
             # [Non-Blocking Error]
             pass
 
-    def get_prompt(*args, **kwargs):
-        # make this safe checks, it should not throw any exceptions
-        if len(args) > 1:
-            messages = args[1]
-            prompt = " ".join(message["content"] for message in messages)
-            return prompt
-        if "messages" in kwargs:
-            messages = kwargs["messages"]
-            prompt = " ".join(message["content"] for message in messages)
-            return prompt
-        return None
-
-    def check_cache(*args, **kwargs):
-        try:  # never block execution
-            prompt = get_prompt(*args, **kwargs)
-            if prompt != None:  # check if messages / prompt exists
-                if litellm.caching_with_models:
-                    # if caching with model names is enabled, key is prompt + model name
-                    if "model" in kwargs:
-                        cache_key = prompt + kwargs["model"]
-                        if cache_key in local_cache:
-                            return local_cache[cache_key]
-                else:  # caching only with prompts
-                    if prompt in local_cache:
-                        result = local_cache[prompt]
-                        return result
-            else:
-                return None
-            return None  # default to return None
-        except:
-            return None
-
-    def add_cache(result, *args, **kwargs):
-        try:  # never block execution
-            prompt = get_prompt(*args, **kwargs)
-            if litellm.caching_with_models:  # caching with model + prompt
-                if "model" in kwargs:
-                    cache_key = prompt + kwargs["model"]
-                    local_cache[cache_key] = result
-            else:  # caching based only on prompts
-                local_cache[prompt] = result
-        except:
-            pass
-
     def wrapper(*args, **kwargs):
         start_time = None
         result = None
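The block removed above was the v0 cache: a module-level `local_cache` dict keyed on the flattened prompt (all message `content` strings joined with spaces), optionally concatenated with the model name when `litellm.caching_with_models` is set, so identical prompts sent to different models do not collide. A minimal, self-contained sketch of that keying scheme (`SimpleCache` and its method names are illustrative stand-ins, not litellm API):

# Sketch of v0's keying scheme. `key_on_model` plays the role of
# `litellm.caching_with_models`; `store` stands in for `local_cache`.
class SimpleCache:
    def __init__(self, key_on_model: bool = False):
        self.key_on_model = key_on_model
        self.store = {}

    def _key(self, messages, model=None):
        # Flatten the chat messages into a single prompt string.
        prompt = " ".join(m["content"] for m in messages)
        # Optionally append the model name to the key.
        return prompt + model if (self.key_on_model and model) else prompt

    def get(self, messages, model=None):
        return self.store.get(self._key(messages, model))  # None on a miss

    def set(self, messages, result, model=None):
        self.store[self._key(messages, model)] = result

# Usage:
#   cache = SimpleCache(key_on_model=True)
#   cache.set([{"role": "user", "content": "hi"}], "cached!", model="gpt-3.5-turbo")
#   cache.get([{"role": "user", "content": "hi"}], model="gpt-3.5-turbo")  # -> "cached!"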
@@ -446,19 +402,20 @@ def client(original_function):
             kwargs["litellm_call_id"] = litellm_call_id
             start_time = datetime.datetime.now()
             # [OPTIONAL] CHECK CACHE
-            if (litellm.caching or litellm.caching_with_models) and (
-                cached_result := check_cache(*args, **kwargs)
+            if (litellm.caching or litellm.caching_with_models or litellm.cache != None) and (
+                cached_result := litellm.cache.check_cache(*args, **kwargs)
             ) is not None:
                 result = cached_result
                 return result
             # MODEL CALL
             result = original_function(*args, **kwargs)
             if "stream" in kwargs and kwargs["stream"] == True:
                 # TODO: Add to cache for streaming
                 return result
             end_time = datetime.datetime.now()
             # [OPTIONAL] ADD TO CACHE
-            if litellm.caching or litellm.caching_with_models:
-                add_cache(result, *args, **kwargs)
+            if litellm.caching or litellm.caching_with_models or litellm.cache != None:  # user init a cache object
+                litellm.cache.add_cache(result, *args, **kwargs)
             # LOG SUCCESS
             my_thread = threading.Thread(
                 target=handle_success, args=(args, kwargs, result, start_time, end_time)
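The replacement path assumes a user-initialized cache object on `litellm.cache` exposing `check_cache(*args, **kwargs)` and `add_cache(result, *args, **kwargs)`; the class itself lands in another of this commit's four changed files. A minimal sketch of the contract the wrapper relies on (this stand-in only illustrates the interface, not the shipped implementation):

# Sketch of the interface the rewritten wrapper calls. The wrapper passes
# the completion call's own *args/**kwargs through, so the cache derives
# its key from `messages` the same way the removed get_prompt() did.
class Cache:
    def __init__(self):
        self.store = {}

    def _prompt(self, *args, **kwargs):
        # messages arrive either as kwargs["messages"] or positionally.
        messages = kwargs.get("messages") or (args[1] if len(args) > 1 else None)
        if messages is None:
            return None
        return " ".join(m["content"] for m in messages)

    def check_cache(self, *args, **kwargs):
        prompt = self._prompt(*args, **kwargs)
        return self.store.get(prompt) if prompt else None  # None on a miss

    def add_cache(self, result, *args, **kwargs):
        prompt = self._prompt(*args, **kwargs)
        if prompt:
            self.store[prompt] = result

# Usage: assign it once, then identical calls short-circuit the model call.
#   import litellm
#   litellm.cache = Cache()

One thing visible in the diff itself: the combined condition dereferences `litellm.cache` whenever any of the flags is truthy, so enabling the legacy `litellm.caching` flag without assigning a cache object would appear to raise an `AttributeError` on `None`.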
@@ -1730,4 +1687,4 @@ def completion_with_fallbacks(**kwargs):
                 ) # cool down this selected model
                 # print(f"rate_limited_models {rate_limited_models}")
                 pass
-            return response
+    return response