diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 29c181ee02..c81b1b123d 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -582,9 +582,6 @@ class Logging:
                 or isinstance(result, HttpxBinaryResponseContent)  # tts
             ):
                 ## RESPONSE COST ##
-                custom_pricing = use_custom_pricing_for_model(
-                    litellm_params=self.litellm_params
-                )
                 self.model_call_details["response_cost"] = (
                     self._response_cost_calculator(result=result)
                 )
@@ -2159,6 +2156,9 @@ def get_custom_logger_compatible_class(
 def use_custom_pricing_for_model(litellm_params: Optional[dict]) -> bool:
     if litellm_params is None:
         return False
+    for k, v in litellm_params.items():
+        if k in SPECIAL_MODEL_INFO_PARAMS:
+            return True
     metadata: Optional[dict] = litellm_params.get("metadata", {})
     if metadata is None:
         return False
@@ -2167,6 +2167,7 @@ def use_custom_pricing_for_model(litellm_params: Optional[dict]) -> bool:
     for k, v in model_info.items():
         if k in SPECIAL_MODEL_INFO_PARAMS:
             return True
+    return False
diff --git a/litellm/tests/test_cost_calc.py b/litellm/tests/test_cost_calc.py
index 39d3c28fd7..ecead06794 100644
--- a/litellm/tests/test_cost_calc.py
+++ b/litellm/tests/test_cost_calc.py
@@ -55,14 +55,15 @@ router = Router(
     "model",
     [
         "openai/gpt-3.5-turbo",
-        "anthropic/claude-3-haiku-20240307",
-        "together_ai/meta-llama/Llama-2-7b-chat-hf",
+        # "anthropic/claude-3-haiku-20240307",
+        # "together_ai/meta-llama/Llama-2-7b-chat-hf",
     ],
 )
 def test_run(model: str):
     """
     Relevant issue - https://github.com/BerriAI/litellm/issues/4965
     """
+    # litellm.set_verbose = True
     prompt = "Hi"
     kwargs = dict(
         model=model,
@@ -97,9 +98,9 @@ def test_run(model: str):
     streaming_cost_calc = completion_cost(response) * 100
     print(f"Stream output : {output}")
 
-    if output == non_stream_output:
-        # assert cost is the same
-        assert streaming_cost_calc == non_stream_cost_calc
     print(f"Stream usage : {response.usage}")  # type: ignore
     print(f"Stream cost : {streaming_cost_calc} (response)")
     print("")
+    if output == non_stream_output:
+        # assert cost is the same
+        assert streaming_cost_calc == non_stream_cost_calc
diff --git a/litellm/utils.py b/litellm/utils.py
index 0aba7e5051..7edcabe863 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -837,7 +837,7 @@ def client(original_function):
                 and kwargs.get("atranscription", False) != True
             ):  # allow users to control returning cached responses from the completion function
                 # checking cache
-                print_verbose(f"INSIDE CHECKING CACHE")
+                print_verbose("INSIDE CHECKING CACHE")
                 if (
                     litellm.cache is not None
                     and str(original_function.__name__)
@@ -965,10 +965,10 @@ def client(original_function):
             # MODEL CALL
             result = original_function(*args, **kwargs)
             end_time = datetime.datetime.now()
-            if "stream" in kwargs and kwargs["stream"] == True:
+            if "stream" in kwargs and kwargs["stream"] is True:
                 if (
                     "complete_response" in kwargs
-                    and kwargs["complete_response"] == True
+                    and kwargs["complete_response"] is True
                 ):
                     chunks = []
                     for idx, chunk in enumerate(result):
@@ -978,15 +978,15 @@ def client(original_function):
                     )
                 else:
                     return result
-            elif "acompletion" in kwargs and kwargs["acompletion"] == True:
+            elif "acompletion" in kwargs and kwargs["acompletion"] is True:
                 return result
-            elif "aembedding" in kwargs and kwargs["aembedding"] == True:
+            elif "aembedding" in kwargs and kwargs["aembedding"] is True:
                 return result
-            elif "aimg_generation" in kwargs and kwargs["aimg_generation"] == True:
+            elif "aimg_generation" in kwargs and kwargs["aimg_generation"] is True:
                 return result
-            elif "atranscription" in kwargs and kwargs["atranscription"] == True:
+            elif "atranscription" in kwargs and kwargs["atranscription"] is True:
                 return result
-            elif "aspeech" in kwargs and kwargs["aspeech"] == True:
+            elif "aspeech" in kwargs and kwargs["aspeech"] is True:
                 return result
 
             ### POST-CALL RULES ###
@@ -1005,7 +1005,7 @@ def client(original_function):
                 litellm.cache.add_cache(result, *args, **kwargs)
 
             # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
-            verbose_logger.info(f"Wrapper: Completed Call, calling success_handler")
+            verbose_logger.info("Wrapper: Completed Call, calling success_handler")
             threading.Thread(
                 target=logging_obj.success_handler, args=(result, start_time, end_time)
             ).start()
@@ -1019,15 +1019,7 @@ def client(original_function):
                     optional_params=getattr(logging_obj, "optional_params", {}),
                 )
                 result._hidden_params["response_cost"] = (
-                    litellm.response_cost_calculator(
-                        response_object=result,
-                        model=getattr(logging_obj, "model", ""),
-                        custom_llm_provider=getattr(
-                            logging_obj, "custom_llm_provider", None
-                        ),
-                        call_type=getattr(logging_obj, "call_type", "completion"),
-                        optional_params=getattr(logging_obj, "optional_params", {}),
-                    )
+                    logging_obj._response_cost_calculator(result=result)
                 )
                 result._response_ms = (
                     end_time - start_time