forked from phoenix/litellm-mirror
feat(litellm_logging.py): support logging model price information to s3 logs
This commit is contained in:
parent
9c3124c5a7
commit
178139f18d
9 changed files with 97 additions and 26 deletions
|
@ -410,6 +410,36 @@ def get_replicate_completion_pricing(completion_response=None, total_time=0.0):
|
||||||
return a100_80gb_price_per_second_public * total_time / 1000
|
return a100_80gb_price_per_second_public * total_time / 1000
|
||||||
|
|
||||||
|
|
||||||
|
def _select_model_name_for_cost_calc(
|
||||||
|
model: Optional[str],
|
||||||
|
completion_response: Union[BaseModel, dict],
|
||||||
|
base_model: Optional[str] = None,
|
||||||
|
custom_pricing: Optional[bool] = None,
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
1. If custom pricing is true, return received model name
|
||||||
|
2. If base_model is set (e.g. for azure models), return that
|
||||||
|
3. If completion response has model set return that
|
||||||
|
4. If model is passed in return that
|
||||||
|
"""
|
||||||
|
args = locals()
|
||||||
|
if custom_pricing is True:
|
||||||
|
return model
|
||||||
|
|
||||||
|
if base_model is not None:
|
||||||
|
return base_model
|
||||||
|
|
||||||
|
return_model = model
|
||||||
|
if hasattr(completion_response, "_hidden_params"):
|
||||||
|
if (
|
||||||
|
completion_response._hidden_params.get("model", None) is not None
|
||||||
|
and len(completion_response._hidden_params["model"]) > 0
|
||||||
|
):
|
||||||
|
return_model = completion_response._hidden_params.get("model", model)
|
||||||
|
|
||||||
|
return return_model
|
||||||
|
|
||||||
|
|
||||||
def completion_cost(
|
def completion_cost(
|
||||||
completion_response=None,
|
completion_response=None,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
|
@ -511,15 +541,10 @@ def completion_cost(
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
f"completion_response response ms: {getattr(completion_response, '_response_ms', None)} "
|
f"completion_response response ms: {getattr(completion_response, '_response_ms', None)} "
|
||||||
)
|
)
|
||||||
model = model or completion_response.get(
|
model = _select_model_name_for_cost_calc(
|
||||||
"model", None
|
model=model, completion_response=completion_response
|
||||||
) # check if user passed an override for model, if it's none check completion_response['model']
|
)
|
||||||
if hasattr(completion_response, "_hidden_params"):
|
if hasattr(completion_response, "_hidden_params"):
|
||||||
if (
|
|
||||||
completion_response._hidden_params.get("model", None) is not None
|
|
||||||
and len(completion_response._hidden_params["model"]) > 0
|
|
||||||
):
|
|
||||||
model = completion_response._hidden_params.get("model", model)
|
|
||||||
custom_llm_provider = completion_response._hidden_params.get(
|
custom_llm_provider = completion_response._hidden_params.get(
|
||||||
"custom_llm_provider", custom_llm_provider or ""
|
"custom_llm_provider", custom_llm_provider or ""
|
||||||
)
|
)
|
||||||
|
|
|
@ -24,6 +24,7 @@ from litellm import (
|
||||||
verbose_logger,
|
verbose_logger,
|
||||||
)
|
)
|
||||||
from litellm.caching import DualCache, InMemoryCache, S3Cache
|
from litellm.caching import DualCache, InMemoryCache, S3Cache
|
||||||
|
from litellm.cost_calculator import _select_model_name_for_cost_calc
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm.litellm_core_utils.redact_messages import (
|
from litellm.litellm_core_utils.redact_messages import (
|
||||||
redact_message_input_output_from_logging,
|
redact_message_input_output_from_logging,
|
||||||
|
@ -37,6 +38,7 @@ from litellm.types.utils import (
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
StandardLoggingHiddenParams,
|
StandardLoggingHiddenParams,
|
||||||
StandardLoggingMetadata,
|
StandardLoggingMetadata,
|
||||||
|
StandardLoggingModelInformation,
|
||||||
StandardLoggingPayload,
|
StandardLoggingPayload,
|
||||||
TextCompletionResponse,
|
TextCompletionResponse,
|
||||||
TranscriptionResponse,
|
TranscriptionResponse,
|
||||||
|
@ -2294,6 +2296,38 @@ def get_standard_logging_object_payload(
|
||||||
|
|
||||||
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
|
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
|
||||||
|
|
||||||
|
## Get model cost information ##
|
||||||
|
base_model = _get_base_model_from_metadata(model_call_details=kwargs)
|
||||||
|
custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params)
|
||||||
|
model_cost_name = _select_model_name_for_cost_calc(
|
||||||
|
model=kwargs.get("model"),
|
||||||
|
completion_response=init_response_obj,
|
||||||
|
base_model=base_model,
|
||||||
|
custom_pricing=custom_pricing,
|
||||||
|
)
|
||||||
|
if model_cost_name is None:
|
||||||
|
model_cost_information = StandardLoggingModelInformation(
|
||||||
|
model_map_key="", model_map_value=None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
custom_llm_provider = kwargs.get("custom_llm_provider", None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_model_cost_information = litellm.get_model_info(
|
||||||
|
model=model_cost_name, custom_llm_provider=custom_llm_provider
|
||||||
|
)
|
||||||
|
model_cost_information = StandardLoggingModelInformation(
|
||||||
|
model_map_key=model_cost_name,
|
||||||
|
model_map_value=_model_cost_information,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
verbose_logger.warning(
|
||||||
|
"Model is not mapped in model cost map. Defaulting to None model_cost_information for standard_logging_payload"
|
||||||
|
)
|
||||||
|
model_cost_information = StandardLoggingModelInformation(
|
||||||
|
model_map_key=model_cost_name, model_map_value=None
|
||||||
|
)
|
||||||
|
|
||||||
payload: StandardLoggingPayload = StandardLoggingPayload(
|
payload: StandardLoggingPayload = StandardLoggingPayload(
|
||||||
id=str(id),
|
id=str(id),
|
||||||
call_type=call_type or "",
|
call_type=call_type or "",
|
||||||
|
@ -2320,6 +2354,7 @@ def get_standard_logging_object_payload(
|
||||||
),
|
),
|
||||||
model_parameters=kwargs.get("optional_params", None),
|
model_parameters=kwargs.get("optional_params", None),
|
||||||
hidden_params=clean_hidden_params,
|
hidden_params=clean_hidden_params,
|
||||||
|
model_map_information=model_cost_information,
|
||||||
)
|
)
|
||||||
|
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -1,9 +1,12 @@
|
||||||
# model_list:
|
model_list:
|
||||||
# - model_name: "gpt-4"
|
- model_name: "*"
|
||||||
# litellm_params:
|
litellm_params:
|
||||||
# model: "gpt-4"
|
model: "*"
|
||||||
# model_info:
|
|
||||||
# my_custom_key: "my_custom_value"
|
|
||||||
|
|
||||||
general_settings:
|
litellm_settings:
|
||||||
infer_model_from_keys: true
|
success_callback: ["s3"]
|
||||||
|
s3_callback_params:
|
||||||
|
s3_bucket_name: mytestbucketlitellm # AWS Bucket Name for S3
|
||||||
|
s3_region_name: us-west-2 # AWS Region Name for S3
|
||||||
|
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
|
||||||
|
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
|
||||||
|
|
|
@ -251,7 +251,7 @@ def test_cost_azure_gpt_35():
|
||||||
)
|
)
|
||||||
|
|
||||||
cost = litellm.completion_cost(
|
cost = litellm.completion_cost(
|
||||||
completion_response=resp, model="azure/gpt-35-turbo"
|
completion_response=resp, model="azure/chatgpt-v-2"
|
||||||
)
|
)
|
||||||
print("\n Calculated Cost for azure/gpt-3.5-turbo", cost)
|
print("\n Calculated Cost for azure/gpt-3.5-turbo", cost)
|
||||||
input_cost = model_cost["azure/gpt-35-turbo"]["input_cost_per_token"]
|
input_cost = model_cost["azure/gpt-35-turbo"]["input_cost_per_token"]
|
||||||
|
@ -262,9 +262,7 @@ def test_cost_azure_gpt_35():
|
||||||
print("\n Excpected cost", expected_cost)
|
print("\n Excpected cost", expected_cost)
|
||||||
assert cost == expected_cost
|
assert cost == expected_cost
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(
|
pytest.fail(f"Cost Calc failed for azure/gpt-3.5-turbo. {str(e)}")
|
||||||
f"Cost Calc failed for azure/gpt-3.5-turbo. Expected {expected_cost}, Calculated cost {cost}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# test_cost_azure_gpt_35()
|
# test_cost_azure_gpt_35()
|
||||||
|
|
|
@ -1171,7 +1171,8 @@ def test_turn_off_message_logging():
|
||||||
##### VALID JSON ######
|
##### VALID JSON ######
|
||||||
|
|
||||||
|
|
||||||
def test_standard_logging_payload():
|
@pytest.mark.parametrize("model", ["gpt-3.5-turbo", "azure/chatgpt-v-2"])
|
||||||
|
def test_standard_logging_payload(model):
|
||||||
"""
|
"""
|
||||||
Ensure valid standard_logging_payload is passed for logging calls to s3
|
Ensure valid standard_logging_payload is passed for logging calls to s3
|
||||||
|
|
||||||
|
@ -1187,9 +1188,9 @@ def test_standard_logging_payload():
|
||||||
customHandler, "log_success_event", new=MagicMock()
|
customHandler, "log_success_event", new=MagicMock()
|
||||||
) as mock_client:
|
) as mock_client:
|
||||||
_ = litellm.completion(
|
_ = litellm.completion(
|
||||||
model="gpt-3.5-turbo",
|
model=model,
|
||||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
mock_response="Going well!",
|
# mock_response="Going well!",
|
||||||
)
|
)
|
||||||
|
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
|
@ -1226,3 +1227,9 @@ def test_standard_logging_payload():
|
||||||
]
|
]
|
||||||
> 0
|
> 0
|
||||||
)
|
)
|
||||||
|
assert (
|
||||||
|
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"][
|
||||||
|
"model_map_information"
|
||||||
|
]["model_map_value"]
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
|
@ -1195,6 +1195,11 @@ class StandardLoggingHiddenParams(TypedDict):
|
||||||
additional_headers: Optional[dict]
|
additional_headers: Optional[dict]
|
||||||
|
|
||||||
|
|
||||||
|
class StandardLoggingModelInformation(TypedDict):
|
||||||
|
model_map_key: str
|
||||||
|
model_map_value: Optional[ModelInfo]
|
||||||
|
|
||||||
|
|
||||||
class StandardLoggingPayload(TypedDict):
|
class StandardLoggingPayload(TypedDict):
|
||||||
id: str
|
id: str
|
||||||
call_type: str
|
call_type: str
|
||||||
|
@ -1205,6 +1210,7 @@ class StandardLoggingPayload(TypedDict):
|
||||||
startTime: float
|
startTime: float
|
||||||
endTime: float
|
endTime: float
|
||||||
completionStartTime: float
|
completionStartTime: float
|
||||||
|
model_map_information: StandardLoggingModelInformation
|
||||||
model: str
|
model: str
|
||||||
model_id: Optional[str]
|
model_id: Optional[str]
|
||||||
model_group: Optional[str]
|
model_group: Optional[str]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue