fix(litellm_logging.py): use 1 cost calc function across response headers + logging integrations

Ensures consistent cost calculation when azure base models are used
This commit is contained in:
Krrish Dholakia 2024-08-01 10:26:59 -07:00
parent d02d3d9712
commit 10b571ca42
3 changed files with 40 additions and 24 deletions

View file

@@ -12,6 +12,8 @@ import traceback
import uuid import uuid
from typing import Any, Callable, Dict, List, Literal, Optional from typing import Any, Callable, Dict, List, Literal, Optional
from pydantic import BaseModel
import litellm import litellm
from litellm import ( from litellm import (
json_logs, json_logs,
@@ -503,6 +505,34 @@ class Logging:
) )
) )
def _response_cost_calculator(self, result: BaseModel):
    """
    Compute the cost of `result` from this logging object's state.

    Single shared entry point so response headers and logging
    integrations always agree on the calculated cost.
    """
    ## RESPONSE COST ##
    # Per-deployment custom pricing (e.g. input_cost_per_token set in
    # litellm_params) takes precedence over model-map pricing.
    has_custom_pricing = use_custom_pricing_for_model(
        litellm_params=self.litellm_params
    )
    # Resolve the underlying base model from metadata so deployments
    # (e.g. azure) are costed against the correct model entry.
    base_model = _get_base_model_from_metadata(
        model_call_details=self.model_call_details
    )
    provider = self.model_call_details.get("custom_llm_provider", None)
    cache_hit = self.model_call_details.get("cache_hit", False)
    return litellm.response_cost_calculator(
        response_object=result,
        model=self.model,
        cache_hit=cache_hit,
        custom_llm_provider=provider,
        base_model=base_model,
        call_type=self.call_type,
        optional_params=self.optional_params,
        custom_pricing=has_custom_pricing,
    )
def _success_handler_helper_fn( def _success_handler_helper_fn(
self, result=None, start_time=None, end_time=None, cache_hit=None self, result=None, start_time=None, end_time=None, cache_hit=None
): ):
@@ -537,20 +567,7 @@ class Logging:
litellm_params=self.litellm_params litellm_params=self.litellm_params
) )
self.model_call_details["response_cost"] = ( self.model_call_details["response_cost"] = (
litellm.response_cost_calculator( self._response_cost_calculator(result=result)
response_object=result,
model=self.model,
cache_hit=self.model_call_details.get("cache_hit", False),
custom_llm_provider=self.model_call_details.get(
"custom_llm_provider", None
),
base_model=_get_base_model_from_metadata(
model_call_details=self.model_call_details
),
call_type=self.call_type,
optional_params=self.optional_params,
custom_pricing=custom_pricing,
)
) )
## HIDDEN PARAMS ## ## HIDDEN PARAMS ##

View file

@@ -2,3 +2,10 @@ model_list:
- model_name: "*" - model_name: "*"
litellm_params: litellm_params:
model: "*" model: "*"
- model_name: "azure-gpt-4o-mini"
litellm_params:
model: azure/my-gpt-4o-mini
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE
model_info:
base_model: azure/global-standard/gpt-4o-mini

View file

@@ -1368,15 +1368,7 @@ def client(original_function):
optional_params=kwargs, optional_params=kwargs,
) )
result._hidden_params["response_cost"] = ( result._hidden_params["response_cost"] = (
litellm.response_cost_calculator( logging_obj._response_cost_calculator(result=result)
response_object=result,
model=getattr(logging_obj, "model", ""),
custom_llm_provider=getattr(
logging_obj, "custom_llm_provider", None
),
call_type=getattr(logging_obj, "call_type", "completion"),
optional_params=getattr(logging_obj, "optional_params", {}),
)
) )
if ( if (
isinstance(result, ModelResponse) isinstance(result, ModelResponse)