Ensure base_model cost tracking works across all endpoints (#7989)

* test(test_completion_cost.py): add sdk test to ensure base model is used for cost tracking

* test(test_completion_cost.py): add sdk test to ensure custom pricing works (both checks are sketched below)
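Roughly the shape of those SDK-level checks (a hedged sketch, not the actual tests from test_completion_cost.py; the model names and mock_response values are placeholders, while `base_model`, the per-token pricing params, and `response._hidden_params["response_cost"]` are existing LiteLLM surface):

```python
import litellm


def test_base_model_used_for_cost_tracking():
    # The deployment alias isn't in the cost map; base_model should be
    # what the spend calculation prices against.
    response = litellm.completion(
        model="azure/my-deployment",  # placeholder deployment alias
        base_model="azure/gpt-4o",  # placeholder base model
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        mock_response="Good!",
    )
    assert response._hidden_params["response_cost"] > 0


def test_custom_pricing_used_for_cost_tracking():
    # Explicit per-token pricing should win over the model cost map.
    response = litellm.completion(
        model="my-custom-model",
        input_cost_per_token=0.00001,
        output_cost_per_token=0.00002,
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        mock_response="Good!",
    )
    assert response._hidden_params["response_cost"] > 0
```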

* fix(main.py): add base model cost tracking support for embedding calls

Enables base model cost tracking for embedding calls when the base model is set as a litellm_param, as in the sketch below.
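A sketch of the call shape this enables (the deployment and base model names are placeholders):

```python
import litellm

# base_model travels through get_litellm_params, so the call is priced as
# text-embedding-ada-002 even though the request targets a custom Azure
# deployment name that isn't in the model cost map.
response = litellm.embedding(
    model="azure/my-embedding-deployment",
    input=["good morning from litellm"],
    base_model="azure/text-embedding-ada-002",
)
print(response._hidden_params.get("response_cost"))
```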

* fix(litellm_logging.py): update logging object with litellm params - including base model, if given

Ensures the base model param is always tracked; the callback sketch below shows where it surfaces.
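A sketch of a custom callback that observes the tracked param (`CustomLogger` and `kwargs["litellm_params"]` are existing LiteLLM surface; the class and print are illustrative):

```python
import litellm
from litellm.integrations.custom_logger import CustomLogger


class BaseModelSpy(CustomLogger):
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        # After this fix, the logging object's litellm_params carries
        # base_model whenever the caller supplied one.
        litellm_params = kwargs.get("litellm_params") or {}
        print("base_model:", litellm_params.get("base_model"))


litellm.callbacks = [BaseModelSpy()]
```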

* fix(main.py): fix linting errors
Krish Dholakia, 2025-01-24 21:05:26 -08:00, committed by GitHub
parent e01c9c1fc6 · commit 5feb5355df
8 changed files with 272 additions and 122 deletions


@@ -71,6 +71,10 @@ from litellm.litellm_core_utils.exception_mapping_utils import (
     exception_type,
     get_error_message,
 )
+from litellm.litellm_core_utils.get_litellm_params import (
+    _get_base_model_from_litellm_call_metadata,
+    get_litellm_params,
+)
 from litellm.litellm_core_utils.get_llm_provider_logic import (
     _is_non_openai_azure_model,
     get_llm_provider,
@@ -2094,88 +2098,6 @@ def register_model(model_cost: Union[str, dict]):  # noqa: PLR0915
     return model_cost


-def get_litellm_params(
-    api_key: Optional[str] = None,
-    force_timeout=600,
-    azure=False,
-    logger_fn=None,
-    verbose=False,
-    hugging_face=False,
-    replicate=False,
-    together_ai=False,
-    custom_llm_provider: Optional[str] = None,
-    api_base: Optional[str] = None,
-    litellm_call_id=None,
-    model_alias_map=None,
-    completion_call_id=None,
-    metadata: Optional[dict] = None,
-    model_info=None,
-    proxy_server_request=None,
-    acompletion=None,
-    preset_cache_key=None,
-    no_log=None,
-    input_cost_per_second=None,
-    input_cost_per_token=None,
-    output_cost_per_token=None,
-    output_cost_per_second=None,
-    cooldown_time=None,
-    text_completion=None,
-    azure_ad_token_provider=None,
-    user_continue_message=None,
-    base_model: Optional[str] = None,
-    litellm_trace_id: Optional[str] = None,
-    hf_model_name: Optional[str] = None,
-    custom_prompt_dict: Optional[dict] = None,
-    litellm_metadata: Optional[dict] = None,
-    disable_add_transform_inline_image_block: Optional[bool] = None,
-    drop_params: Optional[bool] = None,
-    prompt_id: Optional[str] = None,
-    prompt_variables: Optional[dict] = None,
-    async_call: Optional[bool] = None,
-    ssl_verify: Optional[bool] = None,
-    **kwargs,
-) -> dict:
-    litellm_params = {
-        "acompletion": acompletion,
-        "api_key": api_key,
-        "force_timeout": force_timeout,
-        "logger_fn": logger_fn,
-        "verbose": verbose,
-        "custom_llm_provider": custom_llm_provider,
-        "api_base": api_base,
-        "litellm_call_id": litellm_call_id,
-        "model_alias_map": model_alias_map,
-        "completion_call_id": completion_call_id,
-        "metadata": metadata,
-        "model_info": model_info,
-        "proxy_server_request": proxy_server_request,
-        "preset_cache_key": preset_cache_key,
-        "no-log": no_log,
-        "stream_response": {},  # litellm_call_id: ModelResponse Dict
-        "input_cost_per_token": input_cost_per_token,
-        "input_cost_per_second": input_cost_per_second,
-        "output_cost_per_token": output_cost_per_token,
-        "output_cost_per_second": output_cost_per_second,
-        "cooldown_time": cooldown_time,
-        "text_completion": text_completion,
-        "azure_ad_token_provider": azure_ad_token_provider,
-        "user_continue_message": user_continue_message,
-        "base_model": base_model
-        or _get_base_model_from_litellm_call_metadata(metadata=metadata),
-        "litellm_trace_id": litellm_trace_id,
-        "hf_model_name": hf_model_name,
-        "custom_prompt_dict": custom_prompt_dict,
-        "litellm_metadata": litellm_metadata,
-        "disable_add_transform_inline_image_block": disable_add_transform_inline_image_block,
-        "drop_params": drop_params,
-        "prompt_id": prompt_id,
-        "prompt_variables": prompt_variables,
-        "async_call": async_call,
-        "ssl_verify": ssl_verify,
-    }
-    return litellm_params
-
-
 def _should_drop_param(k, additional_drop_params) -> bool:
     if (
         additional_drop_params is not None
@@ -5666,22 +5588,6 @@ def get_logging_id(start_time, response_obj):
     return None


-def _get_base_model_from_litellm_call_metadata(
-    metadata: Optional[dict],
-) -> Optional[str]:
-    if metadata is None:
-        return None
-
-    if metadata is not None:
-        model_info = metadata.get("model_info", {})
-
-        if model_info is not None:
-            base_model = model_info.get("base_model", None)
-            if base_model is not None:
-                return base_model
-    return None
-
-
 def _get_base_model_from_metadata(model_call_details=None):
     if model_call_details is None:
         return None
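Net effect: get_litellm_params and _get_base_model_from_litellm_call_metadata now live together in litellm_core_utils/get_litellm_params.py, and the base model resolves from the explicit litellm_param first, then from proxy metadata. A condensed, standalone restatement of that resolution order (illustrative, not the library code):

```python
from typing import Optional


def resolve_base_model(
    base_model: Optional[str], metadata: Optional[dict]
) -> Optional[str]:
    # Mirrors `base_model or _get_base_model_from_litellm_call_metadata(...)`:
    # an explicit base_model wins; otherwise fall back to the proxy's
    # metadata["model_info"]["base_model"], if present.
    if base_model:
        return base_model
    if metadata is not None:
        model_info = metadata.get("model_info") or {}
        return model_info.get("base_model")
    return None


assert resolve_base_model("azure/gpt-4o", None) == "azure/gpt-4o"
assert resolve_base_model(None, {"model_info": {"base_model": "gpt-4o"}}) == "gpt-4o"
assert resolve_base_model(None, None) is None
```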