From 0a73fa6d013a5863844014b41db103676a7de2a7 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 12 Nov 2024 17:45:12 -0800 Subject: [PATCH] fix undo changes cost tracking --- litellm/cost_calculator.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index 11cd58c5b..03a86fb13 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -171,7 +171,6 @@ def cost_per_token( # noqa: PLR0915 model_with_provider = model_with_provider_and_region else: _, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model) - model_without_prefix = model model_parts = model.split("/", 1) if len(model_parts) > 1: @@ -516,7 +515,6 @@ def _infer_call_type( def completion_cost( # noqa: PLR0915 completion_response=None, model: Optional[str] = None, - base_model: Optional[str] = None, prompt="", messages: List = [], completion="", @@ -530,7 +528,6 @@ def completion_cost( # noqa: PLR0915 quality=None, n=None, # number of images ### CUSTOM PRICING ### - custom_pricing: Optional[bool] = None, custom_cost_per_token: Optional[CostPerToken] = None, custom_cost_per_second: Optional[float] = None, optional_params: Optional[dict] = None, @@ -625,8 +622,6 @@ def completion_cost( # noqa: PLR0915 model = _select_model_name_for_cost_calc( model=model, completion_response=completion_response, - base_model=base_model, - custom_pricing=custom_pricing, ) hidden_params = getattr(completion_response, "_hidden_params", None) if hidden_params is not None: @@ -858,15 +853,24 @@ def response_cost_calculator( else: if isinstance(response_object, BaseModel): response_object._hidden_params["optional_params"] = optional_params - # base_model defaults to None if not set on model_info - response_cost = completion_cost( - completion_response=response_object, - call_type=call_type, - model=model, - base_model=base_model, - custom_llm_provider=custom_llm_provider, - custom_pricing=custom_pricing, - ) + if isinstance(response_object, ImageResponse): + response_cost = completion_cost( + completion_response=response_object, + model=model, + call_type=call_type, + custom_llm_provider=custom_llm_provider, + optional_params=optional_params, + ) + else: + if custom_pricing is True: # override defaults if custom pricing is set + base_model = model + # base_model defaults to None if not set on model_info + response_cost = completion_cost( + completion_response=response_object, + call_type=call_type, + model=base_model, + custom_llm_provider=custom_llm_provider, + ) return response_cost except Exception as e: raise e