feat(litellm_logging.py): support logging model price information to s3 logs

This commit is contained in:
Krrish Dholakia 2024-08-16 16:21:34 -07:00
parent 9c3124c5a7
commit 178139f18d
9 changed files with 97 additions and 26 deletions

View file

@ -410,6 +410,36 @@ def get_replicate_completion_pricing(completion_response=None, total_time=0.0):
return a100_80gb_price_per_second_public * total_time / 1000 return a100_80gb_price_per_second_public * total_time / 1000
def _select_model_name_for_cost_calc(
model: Optional[str],
completion_response: Union[BaseModel, dict],
base_model: Optional[str] = None,
custom_pricing: Optional[bool] = None,
) -> Optional[str]:
"""
1. If custom pricing is true, return received model name
2. If base_model is set (e.g. for azure models), return that
3. If completion response has model set return that
4. If model is passed in return that
"""
args = locals()
if custom_pricing is True:
return model
if base_model is not None:
return base_model
return_model = model
if hasattr(completion_response, "_hidden_params"):
if (
completion_response._hidden_params.get("model", None) is not None
and len(completion_response._hidden_params["model"]) > 0
):
return_model = completion_response._hidden_params.get("model", model)
return return_model
def completion_cost( def completion_cost(
completion_response=None, completion_response=None,
model: Optional[str] = None, model: Optional[str] = None,
@ -511,15 +541,10 @@ def completion_cost(
verbose_logger.debug( verbose_logger.debug(
f"completion_response response ms: {getattr(completion_response, '_response_ms', None)} " f"completion_response response ms: {getattr(completion_response, '_response_ms', None)} "
) )
model = model or completion_response.get( model = _select_model_name_for_cost_calc(
"model", None model=model, completion_response=completion_response
) # check if user passed an override for model, if it's none check completion_response['model'] )
if hasattr(completion_response, "_hidden_params"): if hasattr(completion_response, "_hidden_params"):
if (
completion_response._hidden_params.get("model", None) is not None
and len(completion_response._hidden_params["model"]) > 0
):
model = completion_response._hidden_params.get("model", model)
custom_llm_provider = completion_response._hidden_params.get( custom_llm_provider = completion_response._hidden_params.get(
"custom_llm_provider", custom_llm_provider or "" "custom_llm_provider", custom_llm_provider or ""
) )

View file

@ -24,6 +24,7 @@ from litellm import (
verbose_logger, verbose_logger,
) )
from litellm.caching import DualCache, InMemoryCache, S3Cache from litellm.caching import DualCache, InMemoryCache, S3Cache
from litellm.cost_calculator import _select_model_name_for_cost_calc
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.redact_messages import ( from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_logging, redact_message_input_output_from_logging,
@ -37,6 +38,7 @@ from litellm.types.utils import (
ModelResponse, ModelResponse,
StandardLoggingHiddenParams, StandardLoggingHiddenParams,
StandardLoggingMetadata, StandardLoggingMetadata,
StandardLoggingModelInformation,
StandardLoggingPayload, StandardLoggingPayload,
TextCompletionResponse, TextCompletionResponse,
TranscriptionResponse, TranscriptionResponse,
@ -2294,6 +2296,38 @@ def get_standard_logging_object_payload(
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
## Get model cost information ##
base_model = _get_base_model_from_metadata(model_call_details=kwargs)
custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params)
model_cost_name = _select_model_name_for_cost_calc(
model=kwargs.get("model"),
completion_response=init_response_obj,
base_model=base_model,
custom_pricing=custom_pricing,
)
if model_cost_name is None:
model_cost_information = StandardLoggingModelInformation(
model_map_key="", model_map_value=None
)
else:
custom_llm_provider = kwargs.get("custom_llm_provider", None)
try:
_model_cost_information = litellm.get_model_info(
model=model_cost_name, custom_llm_provider=custom_llm_provider
)
model_cost_information = StandardLoggingModelInformation(
model_map_key=model_cost_name,
model_map_value=_model_cost_information,
)
except Exception:
verbose_logger.warning(
"Model is not mapped in model cost map. Defaulting to None model_cost_information for standard_logging_payload"
)
model_cost_information = StandardLoggingModelInformation(
model_map_key=model_cost_name, model_map_value=None
)
payload: StandardLoggingPayload = StandardLoggingPayload( payload: StandardLoggingPayload = StandardLoggingPayload(
id=str(id), id=str(id),
call_type=call_type or "", call_type=call_type or "",
@ -2320,6 +2354,7 @@ def get_standard_logging_object_payload(
), ),
model_parameters=kwargs.get("optional_params", None), model_parameters=kwargs.get("optional_params", None),
hidden_params=clean_hidden_params, hidden_params=clean_hidden_params,
model_map_information=model_cost_information,
) )
verbose_logger.debug( verbose_logger.debug(

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1,9 +1,12 @@
# model_list: model_list:
# - model_name: "gpt-4" - model_name: "*"
# litellm_params: litellm_params:
# model: "gpt-4" model: "*"
# model_info:
# my_custom_key: "my_custom_value"
general_settings: litellm_settings:
infer_model_from_keys: true success_callback: ["s3"]
s3_callback_params:
s3_bucket_name: mytestbucketlitellm # AWS Bucket Name for S3
s3_region_name: us-west-2 # AWS Region Name for S3
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # use os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3

View file

@ -251,7 +251,7 @@ def test_cost_azure_gpt_35():
) )
cost = litellm.completion_cost( cost = litellm.completion_cost(
completion_response=resp, model="azure/gpt-35-turbo" completion_response=resp, model="azure/chatgpt-v-2"
) )
print("\n Calculated Cost for azure/gpt-3.5-turbo", cost) print("\n Calculated Cost for azure/gpt-3.5-turbo", cost)
input_cost = model_cost["azure/gpt-35-turbo"]["input_cost_per_token"] input_cost = model_cost["azure/gpt-35-turbo"]["input_cost_per_token"]
@ -262,9 +262,7 @@ def test_cost_azure_gpt_35():
print("\n Excpected cost", expected_cost) print("\n Excpected cost", expected_cost)
assert cost == expected_cost assert cost == expected_cost
except Exception as e: except Exception as e:
pytest.fail( pytest.fail(f"Cost Calc failed for azure/gpt-3.5-turbo. {str(e)}")
f"Cost Calc failed for azure/gpt-3.5-turbo. Expected {expected_cost}, Calculated cost {cost}"
)
# test_cost_azure_gpt_35() # test_cost_azure_gpt_35()

View file

@ -1171,7 +1171,8 @@ def test_turn_off_message_logging():
##### VALID JSON ###### ##### VALID JSON ######
def test_standard_logging_payload(): @pytest.mark.parametrize("model", ["gpt-3.5-turbo", "azure/chatgpt-v-2"])
def test_standard_logging_payload(model):
""" """
Ensure valid standard_logging_payload is passed for logging calls to s3 Ensure valid standard_logging_payload is passed for logging calls to s3
@ -1187,9 +1188,9 @@ def test_standard_logging_payload():
customHandler, "log_success_event", new=MagicMock() customHandler, "log_success_event", new=MagicMock()
) as mock_client: ) as mock_client:
_ = litellm.completion( _ = litellm.completion(
model="gpt-3.5-turbo", model=model,
messages=[{"role": "user", "content": "Hey, how's it going?"}], messages=[{"role": "user", "content": "Hey, how's it going?"}],
mock_response="Going well!", # mock_response="Going well!",
) )
time.sleep(2) time.sleep(2)
@ -1226,3 +1227,9 @@ def test_standard_logging_payload():
] ]
> 0 > 0
) )
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"][
"model_map_information"
]["model_map_value"]
is not None
)

View file

@ -1195,6 +1195,11 @@ class StandardLoggingHiddenParams(TypedDict):
additional_headers: Optional[dict] additional_headers: Optional[dict]
class StandardLoggingModelInformation(TypedDict):
    """Model cost-map lookup result attached to StandardLoggingPayload."""

    model_map_key: str  # model name used for the cost-map lookup ("" when no name could be resolved)
    model_map_value: Optional[ModelInfo]  # cost-map entry, or None when the model is not in the cost map
class StandardLoggingPayload(TypedDict): class StandardLoggingPayload(TypedDict):
id: str id: str
call_type: str call_type: str
@ -1205,6 +1210,7 @@ class StandardLoggingPayload(TypedDict):
startTime: float startTime: float
endTime: float endTime: float
completionStartTime: float completionStartTime: float
model_map_information: StandardLoggingModelInformation
model: str model: str
model_id: Optional[str] model_id: Optional[str]
model_group: Optional[str] model_group: Optional[str]