forked from phoenix/litellm-mirror
fix(utils.py): ensure consistent cost calc b/w returned header and logged object
parent f51f7750c0
commit 8e9117f701

3 changed files with 20 additions and 26 deletions
@@ -582,9 +582,6 @@ class Logging:
             or isinstance(result, HttpxBinaryResponseContent)  # tts
         ):
             ## RESPONSE COST ##
-            custom_pricing = use_custom_pricing_for_model(
-                litellm_params=self.litellm_params
-            )
             self.model_call_details["response_cost"] = (
                 self._response_cost_calculator(result=result)
             )
@@ -2159,6 +2156,9 @@ def get_custom_logger_compatible_class(
 def use_custom_pricing_for_model(litellm_params: Optional[dict]) -> bool:
     if litellm_params is None:
         return False
+    for k, v in litellm_params.items():
+        if k in SPECIAL_MODEL_INFO_PARAMS:
+            return True
     metadata: Optional[dict] = litellm_params.get("metadata", {})
     if metadata is None:
         return False
@@ -2167,6 +2167,7 @@ def use_custom_pricing_for_model(litellm_params: Optional[dict]) -> bool:
     for k, v in model_info.items():
         if k in SPECIAL_MODEL_INFO_PARAMS:
             return True
+
     return False


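Note on the use_custom_pricing_for_model hunks above: the new loop means custom-pricing keys passed directly in litellm_params (not only via metadata["model_info"]) now flag the model as custom-priced, which is also why the standalone custom_pricing lookup in the Logging hunk becomes redundant. Below is a minimal stand-alone sketch of the post-commit control flow, not the library code; the contents of SPECIAL_MODEL_INFO_PARAMS and the model_info lookup between the two hunks are assumptions.

from typing import Optional

# Assumed stand-in for litellm's constant; the real key list may differ.
SPECIAL_MODEL_INFO_PARAMS = ["input_cost_per_token", "output_cost_per_token"]

def use_custom_pricing_for_model(litellm_params: Optional[dict]) -> bool:
    if litellm_params is None:
        return False
    # New in this commit: custom-pricing keys passed directly in litellm_params count.
    for k, v in litellm_params.items():
        if k in SPECIAL_MODEL_INFO_PARAMS:
            return True
    metadata: Optional[dict] = litellm_params.get("metadata", {})
    if metadata is None:
        return False
    # Assumed shape of the unshown lines between the two hunks.
    model_info: dict = metadata.get("model_info", {}) or {}
    for k, v in model_info.items():
        if k in SPECIAL_MODEL_INFO_PARAMS:
            return True
    return False

print(use_custom_pricing_for_model({"input_cost_per_token": 0.00002}))  # True
print(use_custom_pricing_for_model({"metadata": {"model_info": {}}}))   # False
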
@@ -55,14 +55,15 @@ router = Router(
     "model",
     [
         "openai/gpt-3.5-turbo",
-        "anthropic/claude-3-haiku-20240307",
-        "together_ai/meta-llama/Llama-2-7b-chat-hf",
+        # "anthropic/claude-3-haiku-20240307",
+        # "together_ai/meta-llama/Llama-2-7b-chat-hf",
     ],
 )
 def test_run(model: str):
     """
     Relevant issue - https://github.com/BerriAI/litellm/issues/4965
     """
+    # litellm.set_verbose = True
     prompt = "Hi"
     kwargs = dict(
         model=model,
@@ -97,9 +98,9 @@ def test_run(model: str):
     streaming_cost_calc = completion_cost(response) * 100
     print(f"Stream output : {output}")

-    if output == non_stream_output:
-        # assert cost is the same
-        assert streaming_cost_calc == non_stream_cost_calc
     print(f"Stream usage : {response.usage}")  # type: ignore
     print(f"Stream cost : {streaming_cost_calc} (response)")
     print("")
+    if output == non_stream_output:
+        # assert cost is the same
+        assert streaming_cost_calc == non_stream_cost_calc
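The test above drives the same prompt through a non-streamed and a streamed call and compares the two cost calculations; the diff simply moves the assertion after the debug prints and only runs it when the outputs match. Below is a hedged, simplified sketch of that pattern, not the test file itself: it assumes a configured OpenAI key and uses stream_chunk_builder to rebuild the streamed response, which the real test may not do.

import litellm
from litellm import completion_cost

messages = [{"role": "user", "content": "Hi"}]

# Non-streamed call and its cost (scaled by 100, as in the test).
non_stream = litellm.completion(model="openai/gpt-3.5-turbo", messages=messages, max_tokens=10)
non_stream_cost_calc = completion_cost(non_stream) * 100
non_stream_output = non_stream.choices[0].message.content

# Streamed call: collect the chunks, rebuild a full response, cost it the same way.
chunks = list(
    litellm.completion(
        model="openai/gpt-3.5-turbo", messages=messages, max_tokens=10, stream=True
    )
)
rebuilt = litellm.stream_chunk_builder(chunks, messages=messages)
streaming_cost_calc = completion_cost(rebuilt) * 100
stream_output = rebuilt.choices[0].message.content

# What the test asserts: when the generated text matches, the costs must match too.
if stream_output == non_stream_output:
    assert streaming_cost_calc == non_stream_cost_calc
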
@@ -837,7 +837,7 @@ def client(original_function):
                 and kwargs.get("atranscription", False) != True
             ):  # allow users to control returning cached responses from the completion function
                 # checking cache
-                print_verbose(f"INSIDE CHECKING CACHE")
+                print_verbose("INSIDE CHECKING CACHE")
                 if (
                     litellm.cache is not None
                     and str(original_function.__name__)
@@ -965,10 +965,10 @@ def client(original_function):
             # MODEL CALL
             result = original_function(*args, **kwargs)
             end_time = datetime.datetime.now()
-            if "stream" in kwargs and kwargs["stream"] == True:
+            if "stream" in kwargs and kwargs["stream"] is True:
                 if (
                     "complete_response" in kwargs
-                    and kwargs["complete_response"] == True
+                    and kwargs["complete_response"] is True
                 ):
                     chunks = []
                     for idx, chunk in enumerate(result):
@@ -978,15 +978,15 @@ def client(original_function):
                     )
                 else:
                     return result
-            elif "acompletion" in kwargs and kwargs["acompletion"] == True:
+            elif "acompletion" in kwargs and kwargs["acompletion"] is True:
                 return result
-            elif "aembedding" in kwargs and kwargs["aembedding"] == True:
+            elif "aembedding" in kwargs and kwargs["aembedding"] is True:
                 return result
-            elif "aimg_generation" in kwargs and kwargs["aimg_generation"] == True:
+            elif "aimg_generation" in kwargs and kwargs["aimg_generation"] is True:
                 return result
-            elif "atranscription" in kwargs and kwargs["atranscription"] == True:
+            elif "atranscription" in kwargs and kwargs["atranscription"] is True:
                 return result
-            elif "aspeech" in kwargs and kwargs["aspeech"] == True:
+            elif "aspeech" in kwargs and kwargs["aspeech"] is True:
                 return result

             ### POST-CALL RULES ###
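A note on the == True to is True changes in the hunks above: equality also matches truthy non-bool values (the int 1 compares equal to True), while identity matches only the bool True itself, so the new checks are slightly stricter as well as lint-clean. A short illustration:

flag = 1               # a truthy value that is not a bool
print(flag == True)    # True:  the old check would treat this as enabled
print(flag is True)    # False: the new check requires the actual bool True
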
@@ -1005,7 +1005,7 @@ def client(original_function):
                     litellm.cache.add_cache(result, *args, **kwargs)

             # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
-            verbose_logger.info(f"Wrapper: Completed Call, calling success_handler")
+            verbose_logger.info("Wrapper: Completed Call, calling success_handler")
             threading.Thread(
                 target=logging_obj.success_handler, args=(result, start_time, end_time)
             ).start()
@@ -1019,15 +1019,7 @@ def client(original_function):
                         optional_params=getattr(logging_obj, "optional_params", {}),
                     )
                     result._hidden_params["response_cost"] = (
-                        litellm.response_cost_calculator(
-                            response_object=result,
-                            model=getattr(logging_obj, "model", ""),
-                            custom_llm_provider=getattr(
-                                logging_obj, "custom_llm_provider", None
-                            ),
-                            call_type=getattr(logging_obj, "call_type", "completion"),
-                            optional_params=getattr(logging_obj, "optional_params", {}),
-                        )
+                        logging_obj._response_cost_calculator(result=result)
                     )
                     result._response_ms = (
                         end_time - start_time
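This last hunk is the heart of the commit: the cost attached to the returned response object (result._hidden_params["response_cost"], which is what the returned cost header is built from, per the commit message) now comes from the same logging_obj._response_cost_calculator(result=result) call that fills the logged object, instead of a separately parameterised litellm.response_cost_calculator call, so the two values cannot drift apart. A minimal illustrative sketch of that invariant; the names mirror the diff but the pricing logic is invented:

class LoggingSketch:
    """Stand-in for the Logging object in the diff; not litellm's implementation."""

    def __init__(self, cost_per_token: float):
        self.cost_per_token = cost_per_token
        self.model_call_details: dict = {}

    def _response_cost_calculator(self, result: dict) -> float:
        # Single source of truth for response cost.
        return result["usage"]["total_tokens"] * self.cost_per_token


def finish_call(result: dict, logging_obj: LoggingSketch) -> dict:
    # Logged object (what success_handler / spend tracking sees).
    logging_obj.model_call_details["response_cost"] = (
        logging_obj._response_cost_calculator(result=result)
    )
    # Returned object (what the caller and the cost header see): same calculator.
    result["_hidden_params"] = {
        "response_cost": logging_obj._response_cost_calculator(result=result)
    }
    return result


log = LoggingSketch(cost_per_token=0.5)
resp = finish_call({"usage": {"total_tokens": 42}}, log)
assert resp["_hidden_params"]["response_cost"] == log.model_call_details["response_cost"]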