From a92dcdd2d6b6748a15292230c0a2ff6ed8480c9a Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 16 Aug 2024 16:42:38 -0700 Subject: [PATCH] fix(litellm_logging.py): fix price information logging to s3 --- litellm/cost_calculator.py | 3 +-- litellm/litellm_core_utils/litellm_logging.py | 2 +- litellm/tests/test_custom_callback_input.py | 6 +++++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index 2a7d7ba86..4b5ac51db 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -422,14 +422,13 @@ def _select_model_name_for_cost_calc( 3. If completion response has model set return that 4. If model is passed in return that """ - args = locals() if custom_pricing is True: return model if base_model is not None: return base_model - return_model = model + return_model = model or completion_response.get("model", "") # type: ignore if hasattr(completion_response, "_hidden_params"): if ( completion_response._hidden_params.get("model", None) is not None diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index b442b5e30..3e7f61f72 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -2300,7 +2300,7 @@ def get_standard_logging_object_payload( base_model = _get_base_model_from_metadata(model_call_details=kwargs) custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params) model_cost_name = _select_model_name_for_cost_calc( - model=kwargs.get("model"), + model=None, completion_response=init_response_obj, base_model=base_model, custom_pricing=custom_pricing, diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py index 374d42ade..ffec5ac7d 100644 --- a/litellm/tests/test_custom_callback_input.py +++ b/litellm/tests/test_custom_callback_input.py @@ -1205,7 +1205,11 @@ def test_standard_logging_payload(model): is not None ) - print(mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]) + print( + "Standard Logging Object - {}".format( + mock_client.call_args.kwargs["kwargs"]["standard_logging_object"] + ) + ) keys_list = list(StandardLoggingPayload.__annotations__.keys())