(feat) control caching for embedding, completion
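
This change gates both the cache lookup and the cache write in the client wrapper on the wrapped function's name, so callers can control which call types (completion, embedding, and their async variants) get cached. A minimal usage sketch, assuming the Cache constructor accepts the supported_call_types list referenced in the diff below:

    import litellm
    from litellm.caching import Cache

    # cache completion calls only; embedding calls skip the cache because
    # "embedding" is not listed in supported_call_types (assumed constructor arg)
    litellm.cache = Cache(supported_call_types=["completion", "acompletion"])

    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )  # eligible for caching

    embeddings = litellm.embedding(
        model="text-embedding-ada-002",
        input=["hello"],
    )  # bypasses the cache under this configuration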

ishaan-jaff 2023-12-14 22:31:04 +05:30
parent 9ee16bc962
commit 17c3a1393b


@@ -1514,17 +1514,13 @@ def client(original_function):
if litellm._current_cost > litellm.max_budget:
raise BudgetExceededError(current_cost=litellm._current_cost, max_budget=litellm.max_budget)
# remove this after deprecating litellm.caching
if (litellm.caching or litellm.caching_with_models) and litellm.cache is None:
litellm.cache = Cache()
# [OPTIONAL] CHECK CACHE
print_verbose(f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}")
# if caching is false, don't run this
if (kwargs.get("caching", None) is None and litellm.cache is not None) or kwargs.get("caching", False) == True: # allow users to control returning cached responses from the completion function
# checking cache
if (litellm.cache != None or litellm.caching or litellm.caching_with_models):
print_verbose(f"INSIDE CHECKING CACHE")
if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types:
print_verbose(f"Checking Cache")
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs["preset_cache_key"] = preset_cache_key # for streaming calls, we need to pass the preset_cache_key
@@ -1563,7 +1559,7 @@ def client(original_function):
post_call_processing(original_response=result, model=model)
# [OPTIONAL] ADD TO CACHE
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types:
litellm.cache.add_cache(result, *args, **kwargs)
# LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
@@ -1641,7 +1637,7 @@ def client(original_function):
if (kwargs.get("caching", None) is None and litellm.cache is not None) or kwargs.get("caching", False) == True: # allow users to control returning cached responses from the completion function
# checking cache
print_verbose(f"INSIDE CHECKING CACHE")
if litellm.cache is not None:
if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types:
print_verbose(f"Checking Cache")
cached_result = litellm.cache.get_cache(*args, **kwargs)
if cached_result != None:
@@ -1681,7 +1677,7 @@ def client(original_function):
post_call_processing(original_response=result, model=model)
# [OPTIONAL] ADD TO CACHE
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types:
if isinstance(result, litellm.ModelResponse) or isinstance(result, litellm.EmbeddingResponse):
asyncio.create_task(litellm.cache._async_add_cache(result.json(), *args, **kwargs))
else:
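
In the async path above, the cache write is serialized with result.json() and scheduled with asyncio.create_task so the response is returned without waiting on the cache backend. A sketch of that write-behind shape, with hypothetical names (_async_add_cache, finish_call) standing in for the cache internals:

    import asyncio

    async def _async_add_cache(serialized_result):
        # hypothetical async cache write (e.g. a Redis SET); a stub here
        await asyncio.sleep(0)
        print("cached:", serialized_result)

    async def finish_call(result):
        # write-behind: schedule the cache write and return immediately,
        # mirroring asyncio.create_task(litellm.cache._async_add_cache(...))
        asyncio.create_task(_async_add_cache(result))
        return result

    async def main():
        result = await finish_call('{"object": "embedding", "data": []}')
        await asyncio.sleep(0.1)  # let the background write finish in this demo
        return result

    asyncio.run(main())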