Ensure base_model cost tracking works across all endpoints (#7989)

* test(test_completion_cost.py): add sdk test to ensure base model is used for cost tracking

* test(test_completion_cost.py): add sdk test to ensure custom pricing works

* fix(main.py): add base model cost tracking support for embedding calls

Enables base model cost tracking for embedding calls when base model set as a litellm_param

* fix(litellm_logging.py): update logging object with litellm params - including base model, if given

ensures base model param is always tracked

* fix(main.py): fix linting errors
This commit is contained in:
Krish Dholakia 2025-01-24 21:05:26 -08:00 committed by GitHub
parent e01c9c1fc6
commit 5feb5355df
8 changed files with 272 additions and 122 deletions

View file

@ -3224,8 +3224,6 @@ def embedding( # noqa: PLR0915
**non_default_params,
)
if mock_response is not None:
return mock_embedding(model=model, mock_response=mock_response)
### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
if input_cost_per_token is not None and output_cost_per_token is not None:
litellm.register_model(
@ -3248,28 +3246,22 @@ def embedding( # noqa: PLR0915
}
}
)
litellm_params_dict = get_litellm_params(**kwargs)
logging: Logging = litellm_logging_obj # type: ignore
logging.update_environment_variables(
model=model,
user=user,
optional_params=optional_params,
litellm_params=litellm_params_dict,
custom_llm_provider=custom_llm_provider,
)
if mock_response is not None:
return mock_embedding(model=model, mock_response=mock_response)
try:
response: Optional[EmbeddingResponse] = None
logging: Logging = litellm_logging_obj # type: ignore
logging.update_environment_variables(
model=model,
user=user,
optional_params=optional_params,
litellm_params={
"timeout": timeout,
"azure": azure,
"litellm_call_id": litellm_call_id,
"logger_fn": logger_fn,
"proxy_server_request": proxy_server_request,
"model_info": model_info,
"metadata": metadata,
"aembedding": aembedding,
"preset_cache_key": None,
"stream_response": {},
"cooldown_time": cooldown_time,
},
custom_llm_provider=custom_llm_provider,
)
if azure is True or custom_llm_provider == "azure":
# azure configs
api_type = get_secret_str("AZURE_API_TYPE") or "azure"