fix(litellm_logging.py): fix price information logging to s3

This commit is contained in:
Krrish Dholakia 2024-08-16 16:42:38 -07:00
parent 178139f18d
commit a92dcdd2d6
3 changed files with 7 additions and 4 deletions

View file

@ -422,14 +422,13 @@ def _select_model_name_for_cost_calc(
3. If completion response has model set return that 3. If completion response has model set return that
4. If model is passed in return that 4. If model is passed in return that
""" """
args = locals()
if custom_pricing is True: if custom_pricing is True:
return model return model
if base_model is not None: if base_model is not None:
return base_model return base_model
return_model = model return_model = model or completion_response.get("model", "") # type: ignore
if hasattr(completion_response, "_hidden_params"): if hasattr(completion_response, "_hidden_params"):
if ( if (
completion_response._hidden_params.get("model", None) is not None completion_response._hidden_params.get("model", None) is not None

View file

@ -2300,7 +2300,7 @@ def get_standard_logging_object_payload(
base_model = _get_base_model_from_metadata(model_call_details=kwargs) base_model = _get_base_model_from_metadata(model_call_details=kwargs)
custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params) custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params)
model_cost_name = _select_model_name_for_cost_calc( model_cost_name = _select_model_name_for_cost_calc(
model=kwargs.get("model"), model=None,
completion_response=init_response_obj, completion_response=init_response_obj,
base_model=base_model, base_model=base_model,
custom_pricing=custom_pricing, custom_pricing=custom_pricing,

View file

@ -1205,7 +1205,11 @@ def test_standard_logging_payload(model):
is not None is not None
) )
print(mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]) print(
"Standard Logging Object - {}".format(
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
)
)
keys_list = list(StandardLoggingPayload.__annotations__.keys()) keys_list = list(StandardLoggingPayload.__annotations__.keys())