feat(litellm_logging.py): add 'saved_cache_cost' to standard logging payload (s3)

This commit is contained in:
Krrish Dholakia 2024-08-21 16:58:07 -07:00
parent 3a7d9af01a
commit 8a05ce77e9
4 changed files with 64 additions and 3 deletions

View file

@ -524,6 +524,7 @@ class Logging:
TextCompletionResponse, TextCompletionResponse,
HttpxBinaryResponseContent, HttpxBinaryResponseContent,
], ],
cache_hit: Optional[bool] = None,
): ):
""" """
Calculate response cost using result + logging object variables. Calculate response cost using result + logging object variables.
@ -535,10 +536,13 @@ class Logging:
litellm_params=self.litellm_params litellm_params=self.litellm_params
) )
if cache_hit is None:
cache_hit = self.model_call_details.get("cache_hit", False)
response_cost = litellm.response_cost_calculator( response_cost = litellm.response_cost_calculator(
response_object=result, response_object=result,
model=self.model, model=self.model,
cache_hit=self.model_call_details.get("cache_hit", False), cache_hit=cache_hit,
custom_llm_provider=self.model_call_details.get( custom_llm_provider=self.model_call_details.get(
"custom_llm_provider", None "custom_llm_provider", None
), ),
@ -630,6 +634,7 @@ class Logging:
init_response_obj=result, init_response_obj=result,
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
logging_obj=self,
) )
) )
return start_time, end_time, result return start_time, end_time, result
@ -2181,6 +2186,7 @@ def get_standard_logging_object_payload(
init_response_obj: Any, init_response_obj: Any,
start_time: dt_object, start_time: dt_object,
end_time: dt_object, end_time: dt_object,
logging_obj: Logging,
) -> Optional[StandardLoggingPayload]: ) -> Optional[StandardLoggingPayload]:
try: try:
if kwargs is None: if kwargs is None:
@ -2277,11 +2283,17 @@ def get_standard_logging_object_payload(
cache_key = litellm.cache.get_cache_key(**kwargs) cache_key = litellm.cache.get_cache_key(**kwargs)
else: else:
cache_key = None cache_key = None
saved_cache_cost: Optional[float] = None
if cache_hit is True: if cache_hit is True:
import time import time
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
saved_cache_cost = logging_obj._response_cost_calculator(
result=init_response_obj, cache_hit=False
)
## Get model cost information ## ## Get model cost information ##
base_model = _get_base_model_from_metadata(model_call_details=kwargs) base_model = _get_base_model_from_metadata(model_call_details=kwargs)
custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params) custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params)
@ -2318,6 +2330,7 @@ def get_standard_logging_object_payload(
id=str(id), id=str(id),
call_type=call_type or "", call_type=call_type or "",
cache_hit=cache_hit, cache_hit=cache_hit,
saved_cache_cost=saved_cache_cost,
startTime=start_time_float, startTime=start_time_float,
endTime=end_time_float, endTime=end_time_float,
completionStartTime=completion_start_time_float, completionStartTime=completion_start_time_float,

View file

@ -4,5 +4,10 @@ model_list:
model: "*" model: "*"
litellm_settings: litellm_settings:
max_internal_user_budget: 0 # amount in USD success_callback: ["s3"]
internal_user_budget_duration: "1mo" # reset every month cache: true
s3_callback_params:
s3_bucket_name: mytestbucketlitellm # AWS Bucket Name for S3
s3_region_name: us-west-2 # AWS Region Name for S3
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3

View file

@ -1252,3 +1252,45 @@ def test_standard_logging_payload(model, turn_off_message_logging):
] ]
if turn_off_message_logging: if turn_off_message_logging:
assert "redacted-by-litellm" == slobject["messages"][0]["content"] assert "redacted-by-litellm" == slobject["messages"][0]["content"]
def test_standard_logging_payload_cache_hit():
from litellm.types.utils import StandardLoggingPayload
# sync completion
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
litellm.cache = Cache()
_ = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
caching=True,
)
with patch.object(
customHandler, "log_success_event", new=MagicMock()
) as mock_client:
_ = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
caching=True,
)
time.sleep(2)
mock_client.assert_called_once()
assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
is not None
)
standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
"kwargs"
]["standard_logging_object"]
assert standard_logging_object["cache_hit"] is True
assert standard_logging_object["response_cost"] == 0
assert standard_logging_object["saved_cache_cost"] > 0

View file

@ -1218,6 +1218,7 @@ class StandardLoggingPayload(TypedDict):
metadata: StandardLoggingMetadata metadata: StandardLoggingMetadata
cache_hit: Optional[bool] cache_hit: Optional[bool]
cache_key: Optional[str] cache_key: Optional[str]
saved_cache_cost: Optional[float]
request_tags: list request_tags: list
end_user: Optional[str] end_user: Optional[str]
requester_ip_address: Optional[str] requester_ip_address: Optional[str]