Ensure base_model cost tracking works across all endpoints (#7989)

* test(test_completion_cost.py): add sdk test to ensure base model is used for cost tracking

* test(test_completion_cost.py): add sdk test to ensure custom pricing works
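For context, a minimal sketch of the kind of flow the custom-pricing test exercises (not the diff's actual test; the model name and prices are made up, and litellm.register_model / mock_response are used here as one plausible way to do this in the SDK):

    import litellm

    # Register custom per-token prices for a hypothetical model, then check
    # that cost calculation picks them up.
    litellm.register_model(
        {
            "my-custom-model": {
                "input_cost_per_token": 0.000001,
                "output_cost_per_token": 0.000002,
                "litellm_provider": "openai",
                "mode": "chat",
            }
        }
    )

    response = litellm.completion(
        model="my-custom-model",
        messages=[{"role": "user", "content": "hi"}],
        mock_response="hello",  # avoids a real network call
    )
    cost = litellm.completion_cost(completion_response=response)
    print(f"cost: {cost}")  # expected to reflect the registered custom prices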

* fix(main.py): add base model cost tracking support for embedding calls

Enables base model cost tracking for embedding calls when the base model is set as a litellm_param
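A rough sketch of the call shape this enables (the Azure deployment name is invented, and Azure credentials are assumed to come from environment variables; base_model is passed so cost is looked up against the base model's pricing rather than the deployment name):

    import litellm

    response = litellm.embedding(
        model="azure/my-embedding-deployment",  # hypothetical deployment name
        input=["hello world"],
        base_model="text-embedding-ada-002",    # litellm_param used for cost lookup
    )
    print(response._hidden_params.get("response_cost"))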

* fix(litellm_logging.py): update logging object with litellm params - including base model, if given

ensures the base model param is always tracked
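Conceptually, the logging fix amounts to merging caller-supplied litellm params (such as base_model) into the logging object's model_call_details so cost calculation can read them later; a simplified, hypothetical sketch of that behaviour (not the actual Logging implementation):

    def update_litellm_params(model_call_details: dict, litellm_params: dict) -> dict:
        # Merge caller-supplied litellm params (e.g. base_model) into the
        # details that cost tracking later reads from.
        existing = model_call_details.setdefault("litellm_params", {})
        for key, value in litellm_params.items():
            if value is not None:
                existing[key] = value
        return model_call_details

    # After the merge, cost calculation can read:
    #   model_call_details["litellm_params"]["base_model"]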

* fix(main.py): fix linting errors

Author: Krish Dholakia, 2025-01-24 21:05:26 -08:00 (committed via GitHub)
Parent: e01c9c1fc6
Commit: 5feb5355df
8 changed files with 272 additions and 122 deletions

@@ -1527,3 +1527,43 @@ def test_add_custom_logger_callback_to_specific_event_e2e(monkeypatch):
    assert len(litellm.success_callback) == curr_len_success_callback
    assert len(litellm.failure_callback) == curr_len_failure_callback


@pytest.mark.asyncio
async def test_wrapper_kwargs_passthrough():
    from litellm.utils import client
    from litellm.litellm_core_utils.litellm_logging import (
        Logging as LiteLLMLoggingObject,
    )

    # Create mock original function
    mock_original = AsyncMock()

    # Apply decorator
    @client
    async def test_function(**kwargs):
        return await mock_original(**kwargs)

    # Test kwargs
    test_kwargs = {"base_model": "gpt-4o-mini"}

    # Call decorated function
    await test_function(**test_kwargs)

    mock_original.assert_called_once()

    # get litellm logging object
    litellm_logging_obj: LiteLLMLoggingObject = mock_original.call_args.kwargs.get(
        "litellm_logging_obj"
    )
    assert litellm_logging_obj is not None
    print(
        f"litellm_logging_obj.model_call_details: {litellm_logging_obj.model_call_details}"
    )

    # get base model
    assert (
        litellm_logging_obj.model_call_details["litellm_params"]["base_model"]
        == "gpt-4o-mini"
    )
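
The new test can be run on its own with pytest's keyword filter (file path omitted here, since the hunk header doesn't show it):

    pytest -k test_wrapper_kwargs_passthrough -s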