mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Ensure base_model cost tracking works across all endpoints (#7989)
* test(test_completion_cost.py): add sdk test to ensure base model is used for cost tracking * test(test_completion_cost.py): add sdk test to ensure custom pricing works * fix(main.py): add base model cost tracking support for embedding calls Enables base model cost tracking for embedding calls when base model set as a litellm_param * fix(litellm_logging.py): update logging object with litellm params - including base model, if given ensures base model param is always tracked * fix(main.py): fix linting errors
This commit is contained in:
parent
e01c9c1fc6
commit
5feb5355df
8 changed files with 272 additions and 122 deletions
|
@ -1527,3 +1527,43 @@ def test_add_custom_logger_callback_to_specific_event_e2e(monkeypatch):
|
|||
|
||||
assert len(litellm.success_callback) == curr_len_success_callback
|
||||
assert len(litellm.failure_callback) == curr_len_failure_callback
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrapper_kwargs_passthrough():
|
||||
from litellm.utils import client
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
Logging as LiteLLMLoggingObject,
|
||||
)
|
||||
|
||||
# Create mock original function
|
||||
mock_original = AsyncMock()
|
||||
|
||||
# Apply decorator
|
||||
@client
|
||||
async def test_function(**kwargs):
|
||||
return await mock_original(**kwargs)
|
||||
|
||||
# Test kwargs
|
||||
test_kwargs = {"base_model": "gpt-4o-mini"}
|
||||
|
||||
# Call decorated function
|
||||
await test_function(**test_kwargs)
|
||||
|
||||
mock_original.assert_called_once()
|
||||
|
||||
# get litellm logging object
|
||||
litellm_logging_obj: LiteLLMLoggingObject = mock_original.call_args.kwargs.get(
|
||||
"litellm_logging_obj"
|
||||
)
|
||||
assert litellm_logging_obj is not None
|
||||
|
||||
print(
|
||||
f"litellm_logging_obj.model_call_details: {litellm_logging_obj.model_call_details}"
|
||||
)
|
||||
|
||||
# get base model
|
||||
assert (
|
||||
litellm_logging_obj.model_call_details["litellm_params"]["base_model"]
|
||||
== "gpt-4o-mini"
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue