fix(utils.py): support caching for embedding + log cache hits

Krrish Dholakia 2023-12-13 18:37:30 -08:00
parent 0f29cda8d9
commit 8d688b6217
5 changed files with 88 additions and 25 deletions

litellm/utils.py

@@ -574,8 +574,9 @@ class Logging:
self.litellm_call_id = litellm_call_id
self.function_id = function_id
self.streaming_chunks = [] # for generating complete stream response
self.model_call_details = {}
def update_environment_variables(self, model, user, optional_params, litellm_params):
def update_environment_variables(self, model, user, optional_params, litellm_params, **additional_params):
self.optional_params = optional_params
self.model = model
self.user = user
@@ -590,7 +591,8 @@ class Logging:
"start_time": self.start_time,
"stream": self.stream,
"user": user,
**self.optional_params
**self.optional_params,
**additional_params
}
def _pre_call(self, input, api_key, model=None, additional_args={}):
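
The widened signature lets call sites forward arbitrary request metadata into model_call_details in one step. For example, the cached-result path added later in this diff invokes it roughly like this (same call, reformatted across lines for readability; everything after litellm_params arrives via **additional_params):

    logging_obj.update_environment_variables(
        model=model,
        user=kwargs.get("user", None),
        optional_params={},
        litellm_params={"logger_fn": kwargs.get("logger_fn", None), "acompletion": True},
        # the remaining keywords are collected by **additional_params and merged into model_call_details
        input=kwargs.get("messages", ""),
        api_key=kwargs.get("api_key", None),
        original_response=str(cached_result),
        additional_args=None,
        stream=kwargs.get("stream", False),
    )
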
@@ -821,7 +823,7 @@ class Logging:
)
pass
def _success_handler_helper_fn(self, result=None, start_time=None, end_time=None):
def _success_handler_helper_fn(self, result=None, start_time=None, end_time=None, cache_hit=None):
try:
if start_time is None:
start_time = self.start_time
@@ -829,6 +831,7 @@ class Logging:
end_time = datetime.datetime.now()
self.model_call_details["log_event_type"] = "successful_api_call"
self.model_call_details["end_time"] = end_time
self.model_call_details["cache_hit"] = cache_hit
if litellm.max_budget and self.stream:
time_diff = (end_time - start_time).total_seconds()
@@ -836,10 +839,10 @@ class Logging:
litellm._current_cost += litellm.completion_cost(model=self.model, prompt="", completion=result["content"], total_time=float_diff)
return start_time, end_time, result
except:
pass
except Exception as e:
print_verbose(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}")
def success_handler(self, result=None, start_time=None, end_time=None, **kwargs):
def success_handler(self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs):
print_verbose(
f"Logging Details LiteLLM-Success Call"
)
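
Condensed, the cache_hit plumbing across the hunks above and below looks like this (trimmed sketch; the real methods also handle budgets, streaming assembly, and callback dispatch):

    def _success_handler_helper_fn(self, result=None, start_time=None, end_time=None, cache_hit=None):
        self.model_call_details["cache_hit"] = cache_hit  # now visible to every success callback
        return start_time, end_time, result

    def success_handler(self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs):
        start_time, end_time, result = self._success_handler_helper_fn(
            start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
        )
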
@@ -867,7 +870,7 @@ class Logging:
if complete_streaming_response:
self.model_call_details["complete_streaming_response"] = complete_streaming_response
start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit)
for callback in litellm.success_callback:
try:
if callback == "lite_debugger":
@@ -1063,7 +1066,7 @@ class Logging:
)
pass
async def async_success_handler(self, result=None, start_time=None, end_time=None, **kwargs):
async def async_success_handler(self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs):
"""
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
"""
@@ -1082,7 +1085,7 @@ class Logging:
self.streaming_chunks.append(result)
if complete_streaming_response:
self.model_call_details["complete_streaming_response"] = complete_streaming_response
start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit)
for callback in litellm._async_success_callback:
try:
if callback == "cache" and litellm.cache is not None:
@@ -1440,6 +1443,7 @@ def client(original_function):
model = args[0] if len(args) > 0 else kwargs["model"]
call_type = original_function.__name__
if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value:
messages = None
if len(args) > 1:
messages = args[1]
elif kwargs.get("messages", None):
@@ -1509,11 +1513,12 @@ def client(original_function):
if litellm._current_cost > litellm.max_budget:
raise BudgetExceededError(current_cost=litellm._current_cost, max_budget=litellm.max_budget)
# [OPTIONAL] CHECK CACHE
# remove this after deprecating litellm.caching
if (litellm.caching or litellm.caching_with_models) and litellm.cache is None:
litellm.cache = Cache()
# [OPTIONAL] CHECK CACHE
print_verbose(f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}")
# if caching is false, don't run this
if (kwargs.get("caching", None) is None and litellm.cache is not None) or kwargs.get("caching", False) == True: # allow users to control returning cached responses from the completion function
@@ -1563,11 +1568,6 @@ def client(original_function):
# LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
print_verbose(f"Wrapper: Completed Call, calling success_handler")
threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
# threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
my_thread = threading.Thread(
target=handle_success, args=(args, kwargs, result, start_time, end_time)
) # don't interrupt execution of main thread
my_thread.start()
# RETURN RESULT
result._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai
return result
@@ -1648,13 +1648,22 @@ def client(original_function):
call_type = original_function.__name__
if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict):
if kwargs.get("stream", False) == True:
return convert_to_streaming_response_async(
cached_result = convert_to_streaming_response_async(
response_object=cached_result,
)
else:
return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
else:
return cached_result
cached_result = convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
elif call_type == CallTypes.aembedding.value and isinstance(cached_result, dict):
cached_result = convert_to_model_response_object(response_object=cached_result, model_response_object=EmbeddingResponse(), response_type="embedding")
# LOG SUCCESS
cache_hit = True
end_time = datetime.datetime.now()
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(model=model, custom_llm_provider=kwargs.get('custom_llm_provider', None), api_base=kwargs.get('api_base', None), api_key=kwargs.get('api_key', None))
print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}")
logging_obj.update_environment_variables(model=model, user=kwargs.get('user', None), optional_params={}, litellm_params={"logger_fn": kwargs.get('logger_fn', None), "acompletion": True}, input=kwargs.get('messages', ""), api_key=kwargs.get('api_key', None), original_response=str(cached_result), additional_args=None, stream=kwargs.get('stream', False))
asyncio.create_task(logging_obj.async_success_handler(cached_result, start_time, end_time, cache_hit))
threading.Thread(target=logging_obj.success_handler, args=(cached_result, start_time, end_time, cache_hit)).start()
return cached_result
# MODEL CALL
result = await original_function(*args, **kwargs)
end_time = datetime.datetime.now()
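
With cache_hit recorded on model_call_details, downstream integrations can tell cached responses apart. A hedged usage sketch, assuming litellm's custom success-callback signature (kwargs, completion_response, start_time, end_time), where kwargs is the model_call_details dict:

    import litellm

    def track_cache_hits(kwargs, completion_response, start_time, end_time):
        # cache_hit is None for normal calls and True when the response came from litellm.cache
        if kwargs.get("cache_hit"):
            latency_ms = (end_time - start_time).total_seconds() * 1000
            print(f"cache hit for model={kwargs.get('model')} in {latency_ms:.1f}ms")

    litellm.success_callback = [track_cache_hits]
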
@@ -1672,7 +1681,10 @@ def client(original_function):
# [OPTIONAL] ADD TO CACHE
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
litellm.cache.add_cache(result, *args, **kwargs)
if isinstance(result, litellm.ModelResponse) or isinstance(result, litellm.EmbeddingResponse):
litellm.cache.add_cache(result.json(), *args, **kwargs)
else:
litellm.cache.add_cache(result, *args, **kwargs)
# LOG SUCCESS - handle streaming success logging in the _next_ object
print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}")
asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time))
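
Taken together, the changes make async embedding calls cacheable end to end: typed responses are stored via result.json() and rebuilt from the cached dict on the way out. A hedged usage sketch (model name is illustrative and provider credentials are assumed to be configured):

    import asyncio
    import litellm
    from litellm import Cache

    litellm.cache = Cache()  # defaults to an in-process cache

    async def main():
        # First call reaches the provider; the EmbeddingResponse is serialized into the cache.
        r1 = await litellm.aembedding(model="text-embedding-ada-002", input=["hello world"])
        # An identical second call is rebuilt from the cached dict and logged with cache_hit=True.
        r2 = await litellm.aembedding(model="text-embedding-ada-002", input=["hello world"])
        return r1, r2

    asyncio.run(main())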