forked from phoenix/litellm-mirror
feat(litellm_logging.py): add 'saved_cache_cost' to standard logging payload (s3)
This commit is contained in:
parent
3a7d9af01a
commit
8a05ce77e9
4 changed files with 64 additions and 3 deletions
|
@ -524,6 +524,7 @@ class Logging:
|
|||
TextCompletionResponse,
|
||||
HttpxBinaryResponseContent,
|
||||
],
|
||||
cache_hit: Optional[bool] = None,
|
||||
):
|
||||
"""
|
||||
Calculate response cost using result + logging object variables.
|
||||
|
@ -535,10 +536,13 @@ class Logging:
|
|||
litellm_params=self.litellm_params
|
||||
)
|
||||
|
||||
if cache_hit is None:
|
||||
cache_hit = self.model_call_details.get("cache_hit", False)
|
||||
|
||||
response_cost = litellm.response_cost_calculator(
|
||||
response_object=result,
|
||||
model=self.model,
|
||||
cache_hit=self.model_call_details.get("cache_hit", False),
|
||||
cache_hit=cache_hit,
|
||||
custom_llm_provider=self.model_call_details.get(
|
||||
"custom_llm_provider", None
|
||||
),
|
||||
|
@ -630,6 +634,7 @@ class Logging:
|
|||
init_response_obj=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
)
|
||||
)
|
||||
return start_time, end_time, result
|
||||
|
@ -2181,6 +2186,7 @@ def get_standard_logging_object_payload(
|
|||
init_response_obj: Any,
|
||||
start_time: dt_object,
|
||||
end_time: dt_object,
|
||||
logging_obj: Logging,
|
||||
) -> Optional[StandardLoggingPayload]:
|
||||
try:
|
||||
if kwargs is None:
|
||||
|
@ -2277,11 +2283,17 @@ def get_standard_logging_object_payload(
|
|||
cache_key = litellm.cache.get_cache_key(**kwargs)
|
||||
else:
|
||||
cache_key = None
|
||||
|
||||
saved_cache_cost: Optional[float] = None
|
||||
if cache_hit is True:
|
||||
import time
|
||||
|
||||
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
|
||||
|
||||
saved_cache_cost = logging_obj._response_cost_calculator(
|
||||
result=init_response_obj, cache_hit=False
|
||||
)
|
||||
|
||||
## Get model cost information ##
|
||||
base_model = _get_base_model_from_metadata(model_call_details=kwargs)
|
||||
custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params)
|
||||
|
@ -2318,6 +2330,7 @@ def get_standard_logging_object_payload(
|
|||
id=str(id),
|
||||
call_type=call_type or "",
|
||||
cache_hit=cache_hit,
|
||||
saved_cache_cost=saved_cache_cost,
|
||||
startTime=start_time_float,
|
||||
endTime=end_time_float,
|
||||
completionStartTime=completion_start_time_float,
|
||||
|
|
|
@ -4,5 +4,10 @@ model_list:
|
|||
model: "*"
|
||||
|
||||
litellm_settings:
|
||||
max_internal_user_budget: 0 # amount in USD
|
||||
internal_user_budget_duration: "1mo" # reset every month
|
||||
success_callback: ["s3"]
|
||||
cache: true
|
||||
s3_callback_params:
|
||||
s3_bucket_name: mytestbucketlitellm # AWS Bucket Name for S3
|
||||
s3_region_name: us-west-2 # AWS Region Name for S3
|
||||
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
|
||||
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
|
||||
|
|
|
@ -1252,3 +1252,45 @@ def test_standard_logging_payload(model, turn_off_message_logging):
|
|||
]
|
||||
if turn_off_message_logging:
|
||||
assert "redacted-by-litellm" == slobject["messages"][0]["content"]
|
||||
|
||||
|
||||
def test_standard_logging_payload_cache_hit():
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
# sync completion
|
||||
customHandler = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
|
||||
litellm.cache = Cache()
|
||||
|
||||
_ = litellm.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
caching=True,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
customHandler, "log_success_event", new=MagicMock()
|
||||
) as mock_client:
|
||||
_ = litellm.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
caching=True,
|
||||
)
|
||||
|
||||
time.sleep(2)
|
||||
mock_client.assert_called_once()
|
||||
|
||||
assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
|
||||
assert (
|
||||
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
|
||||
is not None
|
||||
)
|
||||
|
||||
standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
|
||||
"kwargs"
|
||||
]["standard_logging_object"]
|
||||
|
||||
assert standard_logging_object["cache_hit"] is True
|
||||
assert standard_logging_object["response_cost"] == 0
|
||||
assert standard_logging_object["saved_cache_cost"] > 0
|
||||
|
|
|
@ -1218,6 +1218,7 @@ class StandardLoggingPayload(TypedDict):
|
|||
metadata: StandardLoggingMetadata
|
||||
cache_hit: Optional[bool]
|
||||
cache_key: Optional[str]
|
||||
saved_cache_cost: Optional[float]
|
||||
request_tags: list
|
||||
end_user: Optional[str]
|
||||
requester_ip_address: Optional[str]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue