diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py
index e1c532a88..2747d33e9 100644
--- a/litellm/tests/test_custom_logger.py
+++ b/litellm/tests/test_custom_logger.py
@@ -31,6 +31,7 @@ class MyCustomHandler(CustomLogger):
         self.sync_stream_collected_response = None  # type: ignore
         self.user = None  # type: ignore
         self.data_sent_to_api: dict = {}
+        self.response_cost = 0
 
     def log_pre_api_call(self, model, messages, kwargs):
         print(f"Pre-API Call")
@@ -47,6 +48,8 @@ class MyCustomHandler(CustomLogger):
         self.success = True
         if kwargs.get("stream") == True:
             self.sync_stream_collected_response = response_obj
+        print(f"response cost in log_success_event: {kwargs.get('response_cost')}")
+        self.response_cost = kwargs.get("response_cost", 0)
 
     def log_failure_event(self, kwargs, response_obj, start_time, end_time):
         print(f"On Failure")
@@ -64,6 +67,10 @@ class MyCustomHandler(CustomLogger):
             self.stream_collected_response = response_obj
         self.async_completion_kwargs = kwargs
         self.user = kwargs.get("user", None)
+        print(
+            f"response cost in log_async_success_event: {kwargs.get('response_cost')}"
+        )
+        self.response_cost = kwargs.get("response_cost", 0)
 
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         print(f"On Async Failure")
@@ -400,6 +407,50 @@ async def test_async_custom_handler_embedding_optional_param_bedrock():
     assert "user" not in customHandler_optional_params.data_sent_to_api
 
 
+@pytest.mark.asyncio
+async def test_cost_tracking_with_caching():
+    """
+    Important Test - This tests that cost is 0 for cached responses
+    """
+    from litellm import Cache
+
+    litellm.set_verbose = False
+    litellm.cache = Cache(
+        type="redis",
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+        password=os.environ["REDIS_PASSWORD"],
+    )
+    customHandler_optional_params = MyCustomHandler()
+    litellm.callbacks = [customHandler_optional_params]
+    messages = [
+        {
+            "role": "user",
+            "content": f"write a one sentence poem about: {time.time()}",
+        }
+    ]
+    response1 = await litellm.acompletion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        max_tokens=40,
+        temperature=0.2,
+        caching=True,
+    )
+    await asyncio.sleep(1)  # success callback is async
+    response_cost = customHandler_optional_params.response_cost
+    assert response_cost > 0
+    response2 = await litellm.acompletion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        max_tokens=40,
+        temperature=0.2,
+        caching=True,
+    )
+    await asyncio.sleep(1)  # success callback is async
+    response_cost_2 = customHandler_optional_params.response_cost
+    assert response_cost_2 == 0
+
+
 def test_redis_cache_completion_stream():
     from litellm import Cache
 
diff --git a/litellm/utils.py b/litellm/utils.py
index ffd132e6f..f265a0190 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1060,9 +1060,14 @@ class Logging:
                 and self.stream != True
             ):  # handle streaming separately
                 try:
-                    self.model_call_details["response_cost"] = litellm.completion_cost(
-                        completion_response=result,
-                    )
+                    if self.model_call_details.get("cache_hit", False) == True:
+                        self.model_call_details["response_cost"] = 0.0
+                    else:
+                        self.model_call_details[
+                            "response_cost"
+                        ] = litellm.completion_cost(
+                            completion_response=result,
+                        )
                     verbose_logger.debug(
                         f"Model={self.model}; cost={self.model_call_details['response_cost']}"
                     )
@@ -1096,7 +1101,7 @@ class Logging:
     def success_handler(
         self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
     ):
-        verbose_logger.debug(f"Logging Details LiteLLM-Success Call")
+        verbose_logger.debug(f"Logging Details LiteLLM-Success Call: {cache_hit}")
         start_time, end_time, result = self._success_handler_helper_fn(
             start_time=start_time,
             end_time=end_time,
@@ -1134,9 +1139,14 @@ class Logging:
                         "complete_streaming_response"
                     ] = complete_streaming_response
                     try:
-                        self.model_call_details["response_cost"] = litellm.completion_cost(
-                            completion_response=complete_streaming_response,
-                        )
+                        if self.model_call_details.get("cache_hit", False) == True:
+                            self.model_call_details["response_cost"] = 0.0
+                        else:
+                            self.model_call_details[
+                                "response_cost"
+                            ] = litellm.completion_cost(
+                                completion_response=complete_streaming_response,
+                            )
                         verbose_logger.debug(
                             f"Model={self.model}; cost={self.model_call_details['response_cost']}"
                         )
@@ -1158,6 +1168,7 @@ class Logging:
                 callbacks.append(callback)
         else:
             callbacks = litellm.success_callback
+
         for callback in callbacks:
             try:
                 if callback == "lite_debugger":
@@ -1342,7 +1353,7 @@ class Logging:
                         end_time=end_time,
                         print_verbose=print_verbose,
                     )
-                elif (
+                if (
                     isinstance(callback, CustomLogger)
                     and self.model_call_details.get("litellm_params", {}).get(
                         "acompletion", False
@@ -1353,9 +1364,6 @@ class Logging:
                     )
                     == False
                 ):  # custom logger class
-                    verbose_logger.info(
-                        f"success callbacks: Running SYNC Custom Logger Class"
-                    )
                     if self.stream and complete_streaming_response is None:
                         callback.log_stream_event(
                             kwargs=self.model_call_details,
@@ -1377,7 +1385,7 @@ class Logging:
                         start_time=start_time,
                         end_time=end_time,
                     )
-                elif (
+                if (
                     callable(callback) == True
                     and self.model_call_details.get("litellm_params", {}).get(
                         "acompletion", False
@@ -1452,9 +1460,12 @@ class Logging:
                     "complete_streaming_response"
                 ] = complete_streaming_response
                 try:
-                    self.model_call_details["response_cost"] = litellm.completion_cost(
-                        completion_response=complete_streaming_response,
-                    )
+                    if self.model_call_details.get("cache_hit", False) == True:
+                        self.model_call_details["response_cost"] = 0.0
+                    else:
+                        self.model_call_details["response_cost"] = litellm.completion_cost(
+                            completion_response=complete_streaming_response,
+                        )
                     verbose_logger.debug(
                         f"Model={self.model}; cost={self.model_call_details['response_cost']}"
                     )
@@ -2217,7 +2228,7 @@ def client(original_function):
                    if call_type == CallTypes.completion.value and isinstance(
                        cached_result, dict
                    ):
-                        return convert_to_model_response_object(
+                        cached_result = convert_to_model_response_object(
                            response_object=cached_result,
                            model_response_object=ModelResponse(),
                            stream=kwargs.get("stream", False),
@@ -2225,12 +2236,60 @@ def client(original_function):
                    elif call_type == CallTypes.embedding.value and isinstance(
                        cached_result, dict
                    ):
-                        return convert_to_model_response_object(
+                        cached_result = convert_to_model_response_object(
                            response_object=cached_result,
                            response_type="embedding",
                        )
-                    else:
-                        return cached_result
+
+                    # LOG SUCCESS
+                    cache_hit = True
+                    end_time = datetime.datetime.now()
+                    (
+                        model,
+                        custom_llm_provider,
+                        dynamic_api_key,
+                        api_base,
+                    ) = litellm.get_llm_provider(
+                        model=model,
+                        custom_llm_provider=kwargs.get(
+                            "custom_llm_provider", None
+                        ),
+                        api_base=kwargs.get("api_base", None),
+                        api_key=kwargs.get("api_key", None),
+                    )
+                    print_verbose(
+                        f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
+                    )
+                    logging_obj.update_environment_variables(
+                        model=model,
+                        user=kwargs.get("user", None),
+                        optional_params={},
+                        litellm_params={
+                            "logger_fn": kwargs.get("logger_fn", None),
+                            "acompletion": False,
+                            "metadata": kwargs.get("metadata", {}),
+                            "model_info": kwargs.get("model_info", {}),
+                            "proxy_server_request": kwargs.get(
+                                "proxy_server_request", None
+                            ),
+                            "preset_cache_key": kwargs.get(
+                                "preset_cache_key", None
+                            ),
+                            "stream_response": kwargs.get(
+                                "stream_response", {}
+                            ),
+                        },
+                        input=kwargs.get("messages", ""),
+                        api_key=kwargs.get("api_key", None),
+                        original_response=str(cached_result),
+                        additional_args=None,
+                        stream=kwargs.get("stream", False),
+                    )
+                    threading.Thread(
+                        target=logging_obj.success_handler,
+                        args=(cached_result, start_time, end_time, cache_hit),
+                    ).start()
+                    return cached_result
 
                # CHECK MAX TOKENS
                if (