Fix anthropic prompt caching cost calc + trim logged message in db (#9838)

* fix(spend_tracking_utils.py): prevent logging entire mp4 files to db

Fixes https://github.com/BerriAI/litellm/issues/9732

* fix(anthropic/chat/transformation.py): Fix double counting cache creation input tokens

Fixes https://github.com/BerriAI/litellm/issues/9812

* refactor(anthropic/chat/transformation.py): refactor streaming to use same usage calculation block as non-streaming

reduce errors

* fix(bedrock/chat/converse_transformation.py): don't increment prompt tokens with cache_creation_input_tokens

* build: remove redisvl from requirements.txt (temporary)

* fix(spend_tracking_utils.py): handle circular references

* test: update code cov test

* test: update test
This commit is contained in:
Krish Dholakia 2025-04-09 21:26:43 -07:00 committed by GitHub
parent 00c5c23d97
commit 87733c8193
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 216 additions and 63 deletions

View file

@ -4,7 +4,7 @@ Calling + translation logic for anthropic's `/v1/messages` endpoint
import copy
import json
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
import httpx # type: ignore
@ -492,29 +492,10 @@ class ModelResponseIterator:
return False
def _handle_usage(self, anthropic_usage_chunk: Union[dict, UsageDelta]) -> Usage:
usage_block = Usage(
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
total_tokens=anthropic_usage_chunk.get("input_tokens", 0)
+ anthropic_usage_chunk.get("output_tokens", 0),
return AnthropicConfig().calculate_usage(
usage_object=cast(dict, anthropic_usage_chunk), reasoning_content=None
)
cache_creation_input_tokens = anthropic_usage_chunk.get(
"cache_creation_input_tokens"
)
if cache_creation_input_tokens is not None and isinstance(
cache_creation_input_tokens, int
):
usage_block["cache_creation_input_tokens"] = cache_creation_input_tokens
cache_read_input_tokens = anthropic_usage_chunk.get("cache_read_input_tokens")
if cache_read_input_tokens is not None and isinstance(
cache_read_input_tokens, int
):
usage_block["cache_read_input_tokens"] = cache_read_input_tokens
return usage_block
def _content_block_delta_helper(
self, chunk: dict
) -> Tuple[

View file

@ -682,6 +682,45 @@ class AnthropicConfig(BaseConfig):
reasoning_content += block["thinking"]
return text_content, citations, thinking_blocks, reasoning_content, tool_calls
def calculate_usage(
self, usage_object: dict, reasoning_content: Optional[str]
) -> Usage:
prompt_tokens = usage_object.get("input_tokens", 0)
completion_tokens = usage_object.get("output_tokens", 0)
_usage = usage_object
cache_creation_input_tokens: int = 0
cache_read_input_tokens: int = 0
if "cache_creation_input_tokens" in _usage:
cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
if "cache_read_input_tokens" in _usage:
cache_read_input_tokens = _usage["cache_read_input_tokens"]
prompt_tokens += cache_read_input_tokens
prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=cache_read_input_tokens
)
completion_token_details = (
CompletionTokensDetailsWrapper(
reasoning_tokens=token_counter(
text=reasoning_content, count_response_tokens=True
)
)
if reasoning_content
else None
)
total_tokens = prompt_tokens + completion_tokens
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens_details=prompt_tokens_details,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
completion_tokens_details=completion_token_details,
)
return usage
def transform_response(
self,
model: str,
@ -772,45 +811,14 @@ class AnthropicConfig(BaseConfig):
)
## CALCULATING USAGE
prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"]
_usage = completion_response["usage"]
cache_creation_input_tokens: int = 0
cache_read_input_tokens: int = 0
usage = self.calculate_usage(
usage_object=completion_response["usage"],
reasoning_content=reasoning_content,
)
setattr(model_response, "usage", usage) # type: ignore
model_response.created = int(time.time())
model_response.model = completion_response["model"]
if "cache_creation_input_tokens" in _usage:
cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
prompt_tokens += cache_creation_input_tokens
if "cache_read_input_tokens" in _usage:
cache_read_input_tokens = _usage["cache_read_input_tokens"]
prompt_tokens += cache_read_input_tokens
prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=cache_read_input_tokens
)
completion_token_details = (
CompletionTokensDetailsWrapper(
reasoning_tokens=token_counter(
text=reasoning_content, count_response_tokens=True
)
)
if reasoning_content
else None
)
total_tokens = prompt_tokens + completion_tokens
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens_details=prompt_tokens_details,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
completion_tokens_details=completion_token_details,
)
setattr(model_response, "usage", usage) # type: ignore
model_response._hidden_params = _hidden_params
return model_response

View file

@ -653,8 +653,10 @@ class AmazonConverseConfig(BaseConfig):
cache_read_input_tokens = usage["cacheReadInputTokens"]
input_tokens += cache_read_input_tokens
if "cacheWriteInputTokens" in usage:
"""
Do not increment prompt_tokens with cacheWriteInputTokens
"""
cache_creation_input_tokens = usage["cacheWriteInputTokens"]
input_tokens += cache_creation_input_tokens
prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=cache_read_input_tokens

View file

@ -44,4 +44,7 @@ litellm_settings:
files_settings:
- custom_llm_provider: gemini
api_key: os.environ/GEMINI_API_KEY
api_key: os.environ/GEMINI_API_KEY
general_settings:
store_prompts_in_spend_logs: true

View file

@ -360,6 +360,39 @@ def _get_messages_for_spend_logs_payload(
return "{}"
def _sanitize_request_body_for_spend_logs_payload(
request_body: dict,
visited: Optional[set] = None,
) -> dict:
"""
Recursively sanitize request body to prevent logging large base64 strings or other large values.
Truncates strings longer than 1000 characters and handles nested dictionaries.
"""
MAX_STRING_LENGTH = 1000
if visited is None:
visited = set()
# Get the object's memory address to track visited objects
obj_id = id(request_body)
if obj_id in visited:
return {}
visited.add(obj_id)
def _sanitize_value(value: Any) -> Any:
if isinstance(value, dict):
return _sanitize_request_body_for_spend_logs_payload(value, visited)
elif isinstance(value, list):
return [_sanitize_value(item) for item in value]
elif isinstance(value, str):
if len(value) > MAX_STRING_LENGTH:
return f"{value[:MAX_STRING_LENGTH]}... (truncated {len(value) - MAX_STRING_LENGTH} chars)"
return value
return value
return {k: _sanitize_value(v) for k, v in request_body.items()}
def _add_proxy_server_request_to_metadata(
metadata: dict,
litellm_params: dict,
@ -373,6 +406,7 @@ def _add_proxy_server_request_to_metadata(
)
if _proxy_server_request is not None:
_request_body = _proxy_server_request.get("body", {}) or {}
_request_body = _sanitize_request_body_for_spend_logs_payload(_request_body)
_request_body_json_str = json.dumps(_request_body, default=str)
metadata["proxy_server_request"] = _request_body_json_str
return metadata

View file

@ -16,6 +16,8 @@ IGNORE_FUNCTIONS = [
"_transform_prompt",
"mask_dict",
"_serialize", # we now set a max depth for this
"_sanitize_request_body_for_spend_logs_payload", # testing added for circular reference
"_sanitize_value", # testing added for circular reference
]

View file

@ -33,3 +33,26 @@ def test_response_format_transformation_unit_test():
"agent_doing": {"title": "Agent Doing", "type": "string"}
}
print(result)
def test_calculate_usage():
"""
Do not include cache_creation_input_tokens in the prompt_tokens
Fixes https://github.com/BerriAI/litellm/issues/9812
"""
config = AnthropicConfig()
usage_object = {
"input_tokens": 3,
"cache_creation_input_tokens": 12304,
"cache_read_input_tokens": 0,
"output_tokens": 550,
}
usage = config.calculate_usage(usage_object=usage_object, reasoning_content=None)
assert usage.prompt_tokens == 3
assert usage.completion_tokens == 550
assert usage.total_tokens == 3 + 550
assert usage.prompt_tokens_details.cached_tokens == 0
assert usage._cache_creation_input_tokens == 12304
assert usage._cache_read_input_tokens == 0

View file

@ -30,9 +30,7 @@ def test_transform_usage():
openai_usage = config._transform_usage(usage)
assert (
openai_usage.prompt_tokens
== usage["inputTokens"]
+ usage["cacheWriteInputTokens"]
+ usage["cacheReadInputTokens"]
== usage["inputTokens"] + usage["cacheReadInputTokens"]
)
assert openai_usage.completion_tokens == usage["outputTokens"]
assert openai_usage.total_tokens == usage["totalTokens"]

View file

@ -0,0 +1,102 @@
import asyncio
import datetime
import json
import os
import sys
from datetime import timezone
from typing import Any
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../../..")
) # Adds the parent directory to the system path
from unittest.mock import MagicMock, patch
import litellm
from litellm.proxy.spend_tracking.spend_tracking_utils import (
_sanitize_request_body_for_spend_logs_payload,
)
def test_sanitize_request_body_for_spend_logs_payload_basic():
request_body = {
"messages": [{"role": "user", "content": "Hello, how are you?"}],
}
assert _sanitize_request_body_for_spend_logs_payload(request_body) == request_body
def test_sanitize_request_body_for_spend_logs_payload_long_string():
long_string = "a" * 2000 # Create a string longer than MAX_STRING_LENGTH
request_body = {"text": long_string, "normal_text": "short text"}
sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
assert len(sanitized["text"]) == 1000 + len("... (truncated 1000 chars)")
assert sanitized["normal_text"] == "short text"
def test_sanitize_request_body_for_spend_logs_payload_nested_dict():
request_body = {"outer": {"inner": {"text": "a" * 2000, "normal": "short"}}}
sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
assert len(sanitized["outer"]["inner"]["text"]) == 1000 + len(
"... (truncated 1000 chars)"
)
assert sanitized["outer"]["inner"]["normal"] == "short"
def test_sanitize_request_body_for_spend_logs_payload_nested_list():
request_body = {
"items": [{"text": "a" * 2000}, {"text": "short"}, [{"text": "a" * 2000}]]
}
sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
assert len(sanitized["items"][0]["text"]) == 1000 + len(
"... (truncated 1000 chars)"
)
assert sanitized["items"][1]["text"] == "short"
assert len(sanitized["items"][2][0]["text"]) == 1000 + len(
"... (truncated 1000 chars)"
)
def test_sanitize_request_body_for_spend_logs_payload_non_string_values():
request_body = {"number": 42, "boolean": True, "none": None, "float": 3.14}
sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
assert sanitized == request_body
def test_sanitize_request_body_for_spend_logs_payload_empty():
request_body: dict[str, Any] = {}
sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
assert sanitized == request_body
def test_sanitize_request_body_for_spend_logs_payload_mixed_types():
request_body = {
"text": "a" * 2000,
"number": 42,
"nested": {"list": ["short", "a" * 2000], "dict": {"key": "a" * 2000}},
}
sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
assert len(sanitized["text"]) == 1000 + len("... (truncated 1000 chars)")
assert sanitized["number"] == 42
assert sanitized["nested"]["list"][0] == "short"
assert len(sanitized["nested"]["list"][1]) == 1000 + len(
"... (truncated 1000 chars)"
)
assert len(sanitized["nested"]["dict"]["key"]) == 1000 + len(
"... (truncated 1000 chars)"
)
def test_sanitize_request_body_for_spend_logs_payload_circular_reference():
# Create a circular reference
a: dict[str, Any] = {}
b: dict[str, Any] = {"a": a}
a["b"] = b
# Test that it handles circular reference without infinite recursion
sanitized = _sanitize_request_body_for_spend_logs_payload(a)
assert sanitized == {
"b": {"a": {}}
} # Should return empty dict for circular reference