fix(utils.py): ensure consistent cost calc b/w returned header and logged object

commit 8e9117f701
parent f51f7750c0
Author: Krrish Dholakia
Date:   2024-08-20 18:03:35 -07:00

3 changed files with 20 additions and 26 deletions

View file

@@ -582,9 +582,6 @@ class Logging:
                     or isinstance(result, HttpxBinaryResponseContent)  # tts
                 ):
                     ## RESPONSE COST ##
-                    custom_pricing = use_custom_pricing_for_model(
-                        litellm_params=self.litellm_params
-                    )
                     self.model_call_details["response_cost"] = (
                         self._response_cost_calculator(result=result)
                     )
@@ -2159,6 +2156,9 @@ def get_custom_logger_compatible_class(
 def use_custom_pricing_for_model(litellm_params: Optional[dict]) -> bool:
     if litellm_params is None:
         return False
+    for k, v in litellm_params.items():
+        if k in SPECIAL_MODEL_INFO_PARAMS:
+            return True
     metadata: Optional[dict] = litellm_params.get("metadata", {})
     if metadata is None:
         return False
@@ -2167,6 +2167,7 @@ def use_custom_pricing_for_model(litellm_params: Optional[dict]) -> bool:
     for k, v in model_info.items():
         if k in SPECIAL_MODEL_INFO_PARAMS:
             return True
     return False
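
With the new loop at the top of use_custom_pricing_for_model, custom pricing is now detected when cost-override keys are passed directly in litellm_params, not only when they are nested under metadata.model_info. A minimal sketch of the resulting check, using an illustrative SPECIAL_MODEL_INFO_PARAMS value (the real constant is defined elsewhere in litellm and may differ):

from typing import Optional

# Illustrative value; the actual list lives in litellm and may contain other keys.
SPECIAL_MODEL_INFO_PARAMS = [
    "input_cost_per_token",
    "output_cost_per_token",
    "input_cost_per_second",
    "output_cost_per_second",
]


def use_custom_pricing_for_model(litellm_params: Optional[dict]) -> bool:
    if litellm_params is None:
        return False
    # New: cost overrides passed at the top level of litellm_params count as custom pricing.
    for k, v in litellm_params.items():
        if k in SPECIAL_MODEL_INFO_PARAMS:
            return True
    metadata: Optional[dict] = litellm_params.get("metadata", {})
    if metadata is None:
        return False
    model_info: Optional[dict] = metadata.get("model_info", {})
    if model_info is None:
        return False
    # Existing behaviour: cost overrides nested under metadata.model_info.
    for k, v in model_info.items():
        if k in SPECIAL_MODEL_INFO_PARAMS:
            return True
    return False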

View file

@@ -55,14 +55,15 @@ router = Router(
     "model",
     [
         "openai/gpt-3.5-turbo",
-        "anthropic/claude-3-haiku-20240307",
-        "together_ai/meta-llama/Llama-2-7b-chat-hf",
+        # "anthropic/claude-3-haiku-20240307",
+        # "together_ai/meta-llama/Llama-2-7b-chat-hf",
     ],
 )
 def test_run(model: str):
     """
     Relevant issue - https://github.com/BerriAI/litellm/issues/4965
     """
+    # litellm.set_verbose = True
     prompt = "Hi"
     kwargs = dict(
         model=model,
@@ -97,9 +98,9 @@ def test_run(model: str):
     streaming_cost_calc = completion_cost(response) * 100
     print(f"Stream output : {output}")
-    if output == non_stream_output:
-        # assert cost is the same
-        assert streaming_cost_calc == non_stream_cost_calc
     print(f"Stream usage : {response.usage}")  # type: ignore
     print(f"Stream cost : {streaming_cost_calc} (response)")
     print("")
+    if output == non_stream_output:
+        # assert cost is the same
+        assert streaming_cost_calc == non_stream_cost_calc
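
Moving the assertion below the print statements means the stream usage and cost are still printed when the two costs diverge, which makes a failing run easier to debug. For reference, a hedged sketch of the comparison this test performs (the exact helpers used in the test body may differ; completion_cost and stream_chunk_builder are existing litellm utilities):

import litellm
from litellm import completion_cost

messages = [{"role": "user", "content": "Hi"}]

# non-streaming cost
non_stream = litellm.completion(model="openai/gpt-3.5-turbo", messages=messages)
non_stream_cost_calc = completion_cost(non_stream) * 100

# streaming cost: rebuild a full response from the chunks, then price it
chunks = list(
    litellm.completion(
        model="openai/gpt-3.5-turbo",
        messages=messages,
        stream=True,
        stream_options={"include_usage": True},
    )
)
response = litellm.stream_chunk_builder(chunks, messages=messages)
streaming_cost_calc = completion_cost(response) * 100

# costs are only comparable when both calls produced the same output text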

View file

@@ -837,7 +837,7 @@ def client(original_function):
                 and kwargs.get("atranscription", False) != True
             ):  # allow users to control returning cached responses from the completion function
                 # checking cache
-                print_verbose(f"INSIDE CHECKING CACHE")
+                print_verbose("INSIDE CHECKING CACHE")
                 if (
                     litellm.cache is not None
                     and str(original_function.__name__)
@@ -965,10 +965,10 @@ def client(original_function):
             # MODEL CALL
             result = original_function(*args, **kwargs)
             end_time = datetime.datetime.now()
-            if "stream" in kwargs and kwargs["stream"] == True:
+            if "stream" in kwargs and kwargs["stream"] is True:
                 if (
                     "complete_response" in kwargs
-                    and kwargs["complete_response"] == True
+                    and kwargs["complete_response"] is True
                 ):
                     chunks = []
                     for idx, chunk in enumerate(result):
@@ -978,15 +978,15 @@ def client(original_function):
                     )
                 else:
                     return result
-            elif "acompletion" in kwargs and kwargs["acompletion"] == True:
+            elif "acompletion" in kwargs and kwargs["acompletion"] is True:
                 return result
-            elif "aembedding" in kwargs and kwargs["aembedding"] == True:
+            elif "aembedding" in kwargs and kwargs["aembedding"] is True:
                 return result
-            elif "aimg_generation" in kwargs and kwargs["aimg_generation"] == True:
+            elif "aimg_generation" in kwargs and kwargs["aimg_generation"] is True:
                 return result
-            elif "atranscription" in kwargs and kwargs["atranscription"] == True:
+            elif "atranscription" in kwargs and kwargs["atranscription"] is True:
                 return result
-            elif "aspeech" in kwargs and kwargs["aspeech"] == True:
+            elif "aspeech" in kwargs and kwargs["aspeech"] is True:
                 return result

             ### POST-CALL RULES ###
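
The kwargs[...] == True checks above were tightened to is True. For real booleans the two are equivalent, but identity only matches the bool singleton, so merely truthy values no longer pass:

flag = 1
print(flag == True)  # True  - equality accepts truthy equivalents such as 1
print(flag is True)  # False - identity only matches the singleton True
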
@@ -1005,7 +1005,7 @@ def client(original_function):
                 litellm.cache.add_cache(result, *args, **kwargs)

             # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
-            verbose_logger.info(f"Wrapper: Completed Call, calling success_handler")
+            verbose_logger.info("Wrapper: Completed Call, calling success_handler")
             threading.Thread(
                 target=logging_obj.success_handler, args=(result, start_time, end_time)
             ).start()
@@ -1019,15 +1019,7 @@ def client(original_function):
                         optional_params=getattr(logging_obj, "optional_params", {}),
                     )
                     result._hidden_params["response_cost"] = (
-                        litellm.response_cost_calculator(
-                            response_object=result,
-                            model=getattr(logging_obj, "model", ""),
-                            custom_llm_provider=getattr(
-                                logging_obj, "custom_llm_provider", None
-                            ),
-                            call_type=getattr(logging_obj, "call_type", "completion"),
-                            optional_params=getattr(logging_obj, "optional_params", {}),
-                        )
+                        logging_obj._response_cost_calculator(result=result)
                     )
                     result._response_ms = (
                         end_time - start_time
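
This last hunk is the substance of the commit title: the cost written into the returned response's _hidden_params["response_cost"] (which the proxy can surface back to the caller, e.g. as a response-cost header) now comes from the same logging_obj._response_cost_calculator(result=result) call that produces the logged cost, instead of a separately parameterized litellm.response_cost_calculator(...) call, so the two values can no longer drift apart when custom pricing or provider-specific parameters are involved. A simplified sketch of the resulting invariant, with hypothetical names standing in for the real internals:

# Hypothetical simplification, not the actual litellm code paths: the point is that
# the caller-visible cost and the logged cost come from one calculator call on the
# same Logging object, so they always agree.

def finalize_response(result, logging_obj):
    cost = logging_obj._response_cost_calculator(result=result)
    # value handed back to the caller (and surfaced by the proxy)
    result._hidden_params["response_cost"] = cost
    # value recorded in model_call_details and picked up by success callbacks
    logging_obj.model_call_details["response_cost"] = cost
    return result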