feat(litellm_logging.py): support cost tracking for tts calls

This commit is contained in:
Krrish Dholakia 2024-07-05 22:09:08 -07:00
parent 407639cc7d
commit 6e43cdcb17
4 changed files with 58 additions and 33 deletions

View file

@ -24,6 +24,8 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_logging,
)
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
from litellm.types.utils import (
CallTypes,
EmbeddingResponse,
@ -521,33 +523,36 @@ class Logging:
self.model_call_details["cache_hit"] = cache_hit
## if model in model cost map - log the response cost
## else set cost to None
verbose_logger.debug(f"Model={self.model};")
if (
result is not None
and (
result is not None and self.stream is not True
): # handle streaming separately
if (
isinstance(result, ModelResponse)
or isinstance(result, EmbeddingResponse)
or isinstance(result, ImageResponse)
or isinstance(result, TranscriptionResponse)
or isinstance(result, TextCompletionResponse)
)
and self.stream != True
): # handle streaming separately
self.model_call_details["response_cost"] = (
litellm.response_cost_calculator(
response_object=result,
model=self.model,
cache_hit=self.model_call_details.get("cache_hit", False),
custom_llm_provider=self.model_call_details.get(
"custom_llm_provider", None
),
base_model=_get_base_model_from_metadata(
model_call_details=self.model_call_details
),
call_type=self.call_type,
optional_params=self.optional_params,
or isinstance(result, HttpxBinaryResponseContent) # tts
):
custom_pricing = use_custom_pricing_for_model(
litellm_params=self.litellm_params
)
self.model_call_details["response_cost"] = (
litellm.response_cost_calculator(
response_object=result,
model=self.model,
cache_hit=self.model_call_details.get("cache_hit", False),
custom_llm_provider=self.model_call_details.get(
"custom_llm_provider", None
),
base_model=_get_base_model_from_metadata(
model_call_details=self.model_call_details
),
call_type=self.call_type,
optional_params=self.optional_params,
custom_pricing=custom_pricing,
)
)
)
else: # streaming chunks + image gen.
self.model_call_details["response_cost"] = None
@ -2003,3 +2008,14 @@ def get_custom_logger_compatible_class(
if isinstance(callback, _PROXY_DynamicRateLimitHandler):
return callback # type: ignore
return None
def use_custom_pricing_for_model(litellm_params: dict) -> bool:
    """Return True when the call's model_info carries a special pricing key.

    Looks up ``litellm_params["metadata"]["model_info"]`` and reports whether
    any of its keys appears in ``SPECIAL_MODEL_INFO_PARAMS``. The result is
    forwarded by the caller as the ``custom_pricing`` flag of the response
    cost calculation.

    Args:
        litellm_params: Per-call params dict. ``metadata`` and ``model_info``
            may each be missing or explicitly ``None``.

    Returns:
        True if at least one ``model_info`` key is in
        ``SPECIAL_MODEL_INFO_PARAMS``; False otherwise, including when
        ``litellm_params``, ``metadata`` or ``model_info`` is empty or None.
    """
    if not litellm_params:
        return False
    # `.get(key, {})` only substitutes the default when the key is absent; a
    # key that is present with a None value would make the chained `.get`
    # raise AttributeError, so normalise None to an empty dict explicitly.
    metadata = litellm_params.get("metadata") or {}
    model_info = metadata.get("model_info") or {}
    return any(key in SPECIAL_MODEL_INFO_PARAMS for key in model_info)