feat(litellm_logging.py): add 'saved_cache_cost' to standard logging payload (s3)

This commit is contained in:
Krrish Dholakia 2024-08-21 16:58:07 -07:00
parent 3a7d9af01a
commit 8a05ce77e9
4 changed files with 64 additions and 3 deletions

View file

@ -524,6 +524,7 @@ class Logging:
TextCompletionResponse,
HttpxBinaryResponseContent,
],
cache_hit: Optional[bool] = None,
):
"""
Calculate response cost using result + logging object variables.
@ -535,10 +536,13 @@ class Logging:
litellm_params=self.litellm_params
)
if cache_hit is None:
cache_hit = self.model_call_details.get("cache_hit", False)
response_cost = litellm.response_cost_calculator(
response_object=result,
model=self.model,
cache_hit=self.model_call_details.get("cache_hit", False),
cache_hit=cache_hit,
custom_llm_provider=self.model_call_details.get(
"custom_llm_provider", None
),
@ -630,6 +634,7 @@ class Logging:
init_response_obj=result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
)
)
return start_time, end_time, result
@ -2181,6 +2186,7 @@ def get_standard_logging_object_payload(
init_response_obj: Any,
start_time: dt_object,
end_time: dt_object,
logging_obj: Logging,
) -> Optional[StandardLoggingPayload]:
try:
if kwargs is None:
@ -2277,11 +2283,17 @@ def get_standard_logging_object_payload(
cache_key = litellm.cache.get_cache_key(**kwargs)
else:
cache_key = None
saved_cache_cost: Optional[float] = None
if cache_hit is True:
import time
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
saved_cache_cost = logging_obj._response_cost_calculator(
result=init_response_obj, cache_hit=False
)
## Get model cost information ##
base_model = _get_base_model_from_metadata(model_call_details=kwargs)
custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params)
@ -2318,6 +2330,7 @@ def get_standard_logging_object_payload(
id=str(id),
call_type=call_type or "",
cache_hit=cache_hit,
saved_cache_cost=saved_cache_cost,
startTime=start_time_float,
endTime=end_time_float,
completionStartTime=completion_start_time_float,

View file

@ -4,5 +4,10 @@ model_list:
model: "*"
litellm_settings:
max_internal_user_budget: 0 # amount in USD
internal_user_budget_duration: "1mo" # reset every month
success_callback: ["s3"]
cache: true
s3_callback_params:
s3_bucket_name: mytestbucketlitellm # AWS Bucket Name for S3
s3_region_name: us-west-2 # AWS Region Name for S3
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3

View file

@ -1252,3 +1252,45 @@ def test_standard_logging_payload(model, turn_off_message_logging):
]
if turn_off_message_logging:
assert "redacted-by-litellm" == slobject["messages"][0]["content"]
def test_standard_logging_payload_cache_hit():
from litellm.types.utils import StandardLoggingPayload
# sync completion
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
litellm.cache = Cache()
_ = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
caching=True,
)
with patch.object(
customHandler, "log_success_event", new=MagicMock()
) as mock_client:
_ = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
caching=True,
)
time.sleep(2)
mock_client.assert_called_once()
assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
is not None
)
standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
"kwargs"
]["standard_logging_object"]
assert standard_logging_object["cache_hit"] is True
assert standard_logging_object["response_cost"] == 0
assert standard_logging_object["saved_cache_cost"] > 0

View file

@ -1218,6 +1218,7 @@ class StandardLoggingPayload(TypedDict):
metadata: StandardLoggingMetadata
cache_hit: Optional[bool]
cache_key: Optional[str]
saved_cache_cost: Optional[float]
request_tags: list
end_user: Optional[str]
requester_ip_address: Optional[str]