forked from phoenix/litellm-mirror
fix(utils.py): support caching for embedding + log cache hits
parent 0f29cda8d9
commit 8d688b6217
5 changed files with 88 additions and 25 deletions
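In short: the client wrapper now caches embedding responses as well as completions, and a cache_hit flag is threaded through Logging's success handlers so callbacks can tell cached replies from real provider calls. A minimal usage sketch under stated assumptions (the model name is illustrative, a provider API key is still needed for the first, uncached call, and litellm.aembedding / Cache are used per their documented interfaces rather than anything introduced by this diff):

import asyncio
import litellm
from litellm.caching import Cache

litellm.cache = Cache()  # in-process cache; the same litellm.cache the wrapper below consults

async def main():
    # First call goes to the provider and is written to the cache (serialized via result.json() per this diff);
    # the identical second call should be served from the cache and logged with cache_hit=True.
    for _ in range(2):
        await litellm.aembedding(model="text-embedding-ada-002", input=["hello world"], caching=True)

asyncio.run(main())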
utils.py:

@@ -574,8 +574,9 @@ class Logging:
         self.litellm_call_id = litellm_call_id
         self.function_id = function_id
         self.streaming_chunks = [] # for generating complete stream response
         self.model_call_details = {}
 
-    def update_environment_variables(self, model, user, optional_params, litellm_params):
+    def update_environment_variables(self, model, user, optional_params, litellm_params, **additional_params):
         self.optional_params = optional_params
         self.model = model
         self.user = user
@@ -590,7 +591,8 @@ class Logging:
             "start_time": self.start_time,
             "stream": self.stream,
             "user": user,
-            **self.optional_params
+            **self.optional_params,
+            **additional_params
         }
 
     def _pre_call(self, input, api_key, model=None, additional_args={}):
@@ -821,7 +823,7 @@ class Logging:
                 )
                 pass
 
-    def _success_handler_helper_fn(self, result=None, start_time=None, end_time=None):
+    def _success_handler_helper_fn(self, result=None, start_time=None, end_time=None, cache_hit=None):
         try:
             if start_time is None:
                 start_time = self.start_time
@@ -829,6 +831,7 @@ class Logging:
                 end_time = datetime.datetime.now()
             self.model_call_details["log_event_type"] = "successful_api_call"
             self.model_call_details["end_time"] = end_time
+            self.model_call_details["cache_hit"] = cache_hit
 
             if litellm.max_budget and self.stream:
                 time_diff = (end_time - start_time).total_seconds()
@@ -836,10 +839,10 @@ class Logging:
                 litellm._current_cost += litellm.completion_cost(model=self.model, prompt="", completion=result["content"], total_time=float_diff)
 
             return start_time, end_time, result
-        except:
-            pass
+        except Exception as e:
+            print_verbose(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}")
 
-    def success_handler(self, result=None, start_time=None, end_time=None, **kwargs):
+    def success_handler(self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs):
         print_verbose(
             f"Logging Details LiteLLM-Success Call"
         )
@@ -867,7 +870,7 @@ class Logging:
         if complete_streaming_response:
             self.model_call_details["complete_streaming_response"] = complete_streaming_response
 
-        start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
+        start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit)
         for callback in litellm.success_callback:
             try:
                 if callback == "lite_debugger":
@@ -1063,7 +1066,7 @@ class Logging:
                 )
                 pass
 
-    async def async_success_handler(self, result=None, start_time=None, end_time=None, **kwargs):
+    async def async_success_handler(self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs):
         """
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
         """
@@ -1082,7 +1085,7 @@ class Logging:
                 self.streaming_chunks.append(result)
         if complete_streaming_response:
             self.model_call_details["complete_streaming_response"] = complete_streaming_response
-        start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
+        start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit)
         for callback in litellm._async_success_callback:
             try:
                 if callback == "cache" and litellm.cache is not None:
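Because the cache_hit value added above lands in model_call_details, a custom success callback can observe it. A hedged sketch (the callback body and message are illustrative; the four-argument signature follows litellm's custom-callback convention, and "cache_hit" is the key this commit writes):

import litellm

def track_cache_hits(kwargs, completion_response, start_time, end_time):
    # kwargs carries the logging object's model_call_details, which should now include "cache_hit"
    if kwargs.get("cache_hit"):
        latency_ms = (end_time - start_time).total_seconds() * 1000
        print(f"served from cache in {latency_ms:.1f} ms")

litellm.success_callback = [track_cache_hits]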
@@ -1440,6 +1443,7 @@ def client(original_function):
             model = args[0] if len(args) > 0 else kwargs["model"]
             call_type = original_function.__name__
             if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value:
+                messages = None
                 if len(args) > 1:
                     messages = args[1]
                 elif kwargs.get("messages", None):
@@ -1509,11 +1513,12 @@ def client(original_function):
             if litellm._current_cost > litellm.max_budget:
                 raise BudgetExceededError(current_cost=litellm._current_cost, max_budget=litellm.max_budget)
 
             # [OPTIONAL] CHECK CACHE
             # remove this after deprecating litellm.caching
             if (litellm.caching or litellm.caching_with_models) and litellm.cache is None:
                 litellm.cache = Cache()
 
             # [OPTIONAL] CHECK CACHE
             print_verbose(f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}")
             # if caching is false, don't run this
             if (kwargs.get("caching", None) is None and litellm.cache is not None) or kwargs.get("caching", False) == True: # allow users to control returning cached responses from the completion function
@@ -1563,11 +1568,6 @@ def client(original_function):
             # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
             print_verbose(f"Wrapper: Completed Call, calling success_handler")
             threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
-            # threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
-            my_thread = threading.Thread(
-                target=handle_success, args=(args, kwargs, result, start_time, end_time)
-            )  # don't interrupt execution of main thread
-            my_thread.start()
             # RETURN RESULT
             result._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai
             return result
@@ -1648,13 +1648,22 @@ def client(original_function):
                     call_type = original_function.__name__
                     if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict):
                         if kwargs.get("stream", False) == True:
-                            return convert_to_streaming_response_async(
+                            cached_result = convert_to_streaming_response_async(
                                 response_object=cached_result,
                             )
                         else:
-                            return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
-                    else:
-                        return cached_result
+                            cached_result = convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
+                    elif call_type == CallTypes.aembedding.value and isinstance(cached_result, dict):
+                        cached_result = convert_to_model_response_object(response_object=cached_result, model_response_object=EmbeddingResponse(), response_type="embedding")
+                    # LOG SUCCESS
+                    cache_hit = True
+                    end_time = datetime.datetime.now()
+                    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(model=model, custom_llm_provider=kwargs.get('custom_llm_provider', None), api_base=kwargs.get('api_base', None), api_key=kwargs.get('api_key', None))
+                    print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}")
+                    logging_obj.update_environment_variables(model=model, user=kwargs.get('user', None), optional_params={}, litellm_params={"logger_fn": kwargs.get('logger_fn', None), "acompletion": True}, input=kwargs.get('messages', ""), api_key=kwargs.get('api_key', None), original_response=str(cached_result), additional_args=None, stream=kwargs.get('stream', False))
+                    asyncio.create_task(logging_obj.async_success_handler(cached_result, start_time, end_time, cache_hit))
+                    threading.Thread(target=logging_obj.success_handler, args=(cached_result, start_time, end_time, cache_hit)).start()
+                    return cached_result
                 # MODEL CALL
                 result = await original_function(*args, **kwargs)
                 end_time = datetime.datetime.now()
@@ -1672,7 +1681,10 @@ def client(original_function):
 
             # [OPTIONAL] ADD TO CACHE
             if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
-                litellm.cache.add_cache(result, *args, **kwargs)
+                if isinstance(result, litellm.ModelResponse) or isinstance(result, litellm.EmbeddingResponse):
+                    litellm.cache.add_cache(result.json(), *args, **kwargs)
+                else:
+                    litellm.cache.add_cache(result, *args, **kwargs)
             # LOG SUCCESS - handle streaming success logging in the _next_ object
             print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}")
             asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time))
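Worth noting from the last hunk: responses are now written to the cache in serialized form (result.json()) rather than as live response objects, which is what lets the cache-hit branch rebuild a typed ModelResponse or EmbeddingResponse. A rough sketch of that rebuild step, assuming convert_to_model_response_object and EmbeddingResponse keep the signatures shown in the diff above:

import litellm
from litellm.utils import convert_to_model_response_object

def rebuild_embedding_from_cache(cached: dict) -> litellm.EmbeddingResponse:
    # mirror the aembedding cache-hit branch: turn the cached dict back into a typed response
    return convert_to_model_response_object(
        response_object=cached,
        model_response_object=litellm.EmbeddingResponse(),
        response_type="embedding",
    )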