Fix anthropic prompt caching cost calc + trim logged message in db (#9838)

* fix(spend_tracking_utils.py): prevent logging entire mp4 files to db

Fixes https://github.com/BerriAI/litellm/issues/9732

* fix(anthropic/chat/transformation.py): fix double counting of cache creation input tokens (a token-count sketch follows this list)

Fixes https://github.com/BerriAI/litellm/issues/9812

* refactor(anthropic/chat/transformation.py): refactor streaming to use the same usage calculation block as non-streaming

reduces duplication and the chance of the two code paths drifting apart

* fix(bedrock/chat/converse_transformation.py): don't increment prompt tokens with cache_creation_input_tokens

* build: remove redisvl from requirements.txt (temporary)

* fix(spend_tracking_utils.py): handle circular references

* test: update code cov test

* test: update test
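
For the cache-accounting fix above, here is a minimal sketch of the arithmetic (plain Python, reusing the token counts from `test_calculate_usage` further down; this is not litellm's actual pricing code, just an illustration of why the old numbers inflated spend):

    # Usage block returned by Anthropic when a prompt cache entry is created
    # (same numbers as test_calculate_usage below).
    usage_object = {
        "input_tokens": 3,
        "cache_creation_input_tokens": 12304,
        "cache_read_input_tokens": 0,
        "output_tokens": 550,
    }

    # Before the fix: cache-creation tokens were folded into prompt_tokens,
    # so downstream cost calculation could charge them twice
    # (once as cache writes, once as regular input).
    old_prompt_tokens = (
        usage_object["input_tokens"]
        + usage_object["cache_creation_input_tokens"]
        + usage_object["cache_read_input_tokens"]
    )
    assert old_prompt_tokens == 12307

    # After the fix: prompt_tokens only counts regular input plus cache reads;
    # cache_creation_input_tokens is tracked separately on the Usage object.
    new_prompt_tokens = (
        usage_object["input_tokens"] + usage_object["cache_read_input_tokens"]
    )
    assert new_prompt_tokens == 3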
Author: Krish Dholakia
Date:   2025-04-09 21:26:43 -07:00 (committed by GitHub)
Parent: 00c5c23d97
Commit: 87733c8193
9 changed files with 216 additions and 63 deletions


@@ -4,7 +4,7 @@ Calling + translation logic for anthropic's `/v1/messages` endpoint
 import copy
 import json
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
 import httpx  # type: ignore
@@ -492,29 +492,10 @@ class ModelResponseIterator:
         return False

     def _handle_usage(self, anthropic_usage_chunk: Union[dict, UsageDelta]) -> Usage:
-        usage_block = Usage(
-            prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
-            completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
-            total_tokens=anthropic_usage_chunk.get("input_tokens", 0)
-            + anthropic_usage_chunk.get("output_tokens", 0),
+        return AnthropicConfig().calculate_usage(
+            usage_object=cast(dict, anthropic_usage_chunk), reasoning_content=None
         )
-        cache_creation_input_tokens = anthropic_usage_chunk.get(
-            "cache_creation_input_tokens"
-        )
-        if cache_creation_input_tokens is not None and isinstance(
-            cache_creation_input_tokens, int
-        ):
-            usage_block["cache_creation_input_tokens"] = cache_creation_input_tokens
-        cache_read_input_tokens = anthropic_usage_chunk.get("cache_read_input_tokens")
-        if cache_read_input_tokens is not None and isinstance(
-            cache_read_input_tokens, int
-        ):
-            usage_block["cache_read_input_tokens"] = cache_read_input_tokens
-        return usage_block

     def _content_block_delta_helper(
         self, chunk: dict
     ) -> Tuple[


@@ -682,6 +682,45 @@ class AnthropicConfig(BaseConfig):
                 reasoning_content += block["thinking"]
         return text_content, citations, thinking_blocks, reasoning_content, tool_calls

+    def calculate_usage(
+        self, usage_object: dict, reasoning_content: Optional[str]
+    ) -> Usage:
+        prompt_tokens = usage_object.get("input_tokens", 0)
+        completion_tokens = usage_object.get("output_tokens", 0)
+        _usage = usage_object
+        cache_creation_input_tokens: int = 0
+        cache_read_input_tokens: int = 0
+
+        if "cache_creation_input_tokens" in _usage:
+            cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
+        if "cache_read_input_tokens" in _usage:
+            cache_read_input_tokens = _usage["cache_read_input_tokens"]
+            prompt_tokens += cache_read_input_tokens
+
+        prompt_tokens_details = PromptTokensDetailsWrapper(
+            cached_tokens=cache_read_input_tokens
+        )
+        completion_token_details = (
+            CompletionTokensDetailsWrapper(
+                reasoning_tokens=token_counter(
+                    text=reasoning_content, count_response_tokens=True
+                )
+            )
+            if reasoning_content
+            else None
+        )
+        total_tokens = prompt_tokens + completion_tokens
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+            prompt_tokens_details=prompt_tokens_details,
+            cache_creation_input_tokens=cache_creation_input_tokens,
+            cache_read_input_tokens=cache_read_input_tokens,
+            completion_tokens_details=completion_token_details,
+        )
+        return usage
+
     def transform_response(
         self,
         model: str,
@@ -772,45 +811,14 @@ class AnthropicConfig(BaseConfig):
             )

         ## CALCULATING USAGE
-        prompt_tokens = completion_response["usage"]["input_tokens"]
-        completion_tokens = completion_response["usage"]["output_tokens"]
-        _usage = completion_response["usage"]
-        cache_creation_input_tokens: int = 0
-        cache_read_input_tokens: int = 0
+        usage = self.calculate_usage(
+            usage_object=completion_response["usage"],
+            reasoning_content=reasoning_content,
+        )
+        setattr(model_response, "usage", usage)  # type: ignore

         model_response.created = int(time.time())
         model_response.model = completion_response["model"]
-        if "cache_creation_input_tokens" in _usage:
-            cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
-            prompt_tokens += cache_creation_input_tokens
-        if "cache_read_input_tokens" in _usage:
-            cache_read_input_tokens = _usage["cache_read_input_tokens"]
-            prompt_tokens += cache_read_input_tokens
-        prompt_tokens_details = PromptTokensDetailsWrapper(
-            cached_tokens=cache_read_input_tokens
-        )
-        completion_token_details = (
-            CompletionTokensDetailsWrapper(
-                reasoning_tokens=token_counter(
-                    text=reasoning_content, count_response_tokens=True
-                )
-            )
-            if reasoning_content
-            else None
-        )
-        total_tokens = prompt_tokens + completion_tokens
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-            prompt_tokens_details=prompt_tokens_details,
-            cache_creation_input_tokens=cache_creation_input_tokens,
-            cache_read_input_tokens=cache_read_input_tokens,
-            completion_tokens_details=completion_token_details,
-        )
-        setattr(model_response, "usage", usage)  # type: ignore
         model_response._hidden_params = _hidden_params
         return model_response


@@ -653,8 +653,10 @@ class AmazonConverseConfig(BaseConfig):
            cache_read_input_tokens = usage["cacheReadInputTokens"]
            input_tokens += cache_read_input_tokens
        if "cacheWriteInputTokens" in usage:
+            """
+            Do not increment prompt_tokens with cacheWriteInputTokens
+            """
            cache_creation_input_tokens = usage["cacheWriteInputTokens"]
-            input_tokens += cache_creation_input_tokens
        prompt_tokens_details = PromptTokensDetailsWrapper(
            cached_tokens=cache_read_input_tokens
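
The Bedrock Converse change applies the same rule. A small sketch of the corrected mapping, using made-up token counts (the field names match the Converse `usage` block in the diff above; this is plain arithmetic, not litellm's transformation code):

    # Hypothetical Converse usage block from Bedrock.
    usage = {
        "inputTokens": 100,
        "cacheReadInputTokens": 200,
        "cacheWriteInputTokens": 400,
        "outputTokens": 50,
    }

    # prompt_tokens now includes only regular input plus cache reads;
    # cacheWriteInputTokens is surfaced as cache_creation_input_tokens instead.
    prompt_tokens = usage["inputTokens"] + usage["cacheReadInputTokens"]
    cache_creation_input_tokens = usage["cacheWriteInputTokens"]

    assert prompt_tokens == 300
    assert cache_creation_input_tokens == 400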


@@ -45,3 +45,6 @@ litellm_settings:
 files_settings:
   - custom_llm_provider: gemini
     api_key: os.environ/GEMINI_API_KEY
+
+general_settings:
+  store_prompts_in_spend_logs: true


@@ -360,6 +360,39 @@ def _get_messages_for_spend_logs_payload(
     return "{}"

+def _sanitize_request_body_for_spend_logs_payload(
+    request_body: dict,
+    visited: Optional[set] = None,
+) -> dict:
+    """
+    Recursively sanitize request body to prevent logging large base64 strings or other large values.
+    Truncates strings longer than 1000 characters and handles nested dictionaries.
+    """
+    MAX_STRING_LENGTH = 1000
+
+    if visited is None:
+        visited = set()
+
+    # Get the object's memory address to track visited objects
+    obj_id = id(request_body)
+    if obj_id in visited:
+        return {}
+    visited.add(obj_id)
+
+    def _sanitize_value(value: Any) -> Any:
+        if isinstance(value, dict):
+            return _sanitize_request_body_for_spend_logs_payload(value, visited)
+        elif isinstance(value, list):
+            return [_sanitize_value(item) for item in value]
+        elif isinstance(value, str):
+            if len(value) > MAX_STRING_LENGTH:
+                return f"{value[:MAX_STRING_LENGTH]}... (truncated {len(value) - MAX_STRING_LENGTH} chars)"
+            return value
+        return value
+
+    return {k: _sanitize_value(v) for k, v in request_body.items()}
+
 def _add_proxy_server_request_to_metadata(
     metadata: dict,
     litellm_params: dict,
@@ -373,6 +406,7 @@ def _add_proxy_server_request_to_metadata(
     )
     if _proxy_server_request is not None:
         _request_body = _proxy_server_request.get("body", {}) or {}
+        _request_body = _sanitize_request_body_for_spend_logs_payload(_request_body)
         _request_body_json_str = json.dumps(_request_body, default=str)
         metadata["proxy_server_request"] = _request_body_json_str
     return metadata
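
As a usage sketch for the new sanitizer, assume a request body shaped like the inline-video payload from issue #9732; the model name and message structure below are illustrative only, since the sanitizer walks any nested dict or list:

    from litellm.proxy.spend_tracking.spend_tracking_utils import (
        _sanitize_request_body_for_spend_logs_payload,
    )

    # Stand-in for a multi-megabyte base64-encoded mp4 in an inline request payload.
    fake_video_b64 = "AAAAIGZ0eXBpc29t" * 5000  # ~80k characters

    request_body = {
        "model": "gemini/gemini-2.0-flash",  # arbitrary example model
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this video"},
                    {"type": "file", "file": {"file_data": fake_video_b64}},
                ],
            }
        ],
    }

    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)

    # Only the first 1000 characters of the base64 blob survive, plus a
    # truncation note, so the spend-logs row stays small.
    file_data = sanitized["messages"][0]["content"][1]["file"]["file_data"]
    assert file_data.endswith("chars)")
    assert len(file_data) < 1100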


@@ -16,6 +16,8 @@ IGNORE_FUNCTIONS = [
     "_transform_prompt",
     "mask_dict",
     "_serialize",  # we now set a max depth for this
+    "_sanitize_request_body_for_spend_logs_payload",  # testing added for circular reference
+    "_sanitize_value",  # testing added for circular reference
 ]


@@ -33,3 +33,26 @@ def test_response_format_transformation_unit_test():
         "agent_doing": {"title": "Agent Doing", "type": "string"}
     }
     print(result)
+
+
+def test_calculate_usage():
+    """
+    Do not include cache_creation_input_tokens in the prompt_tokens
+
+    Fixes https://github.com/BerriAI/litellm/issues/9812
+    """
+    config = AnthropicConfig()
+
+    usage_object = {
+        "input_tokens": 3,
+        "cache_creation_input_tokens": 12304,
+        "cache_read_input_tokens": 0,
+        "output_tokens": 550,
+    }
+    usage = config.calculate_usage(usage_object=usage_object, reasoning_content=None)
+    assert usage.prompt_tokens == 3
+    assert usage.completion_tokens == 550
+    assert usage.total_tokens == 3 + 550
+    assert usage.prompt_tokens_details.cached_tokens == 0
+    assert usage._cache_creation_input_tokens == 12304
+    assert usage._cache_read_input_tokens == 0


@@ -30,9 +30,7 @@ def test_transform_usage():
     openai_usage = config._transform_usage(usage)
     assert (
         openai_usage.prompt_tokens
-        == usage["inputTokens"]
-        + usage["cacheWriteInputTokens"]
-        + usage["cacheReadInputTokens"]
+        == usage["inputTokens"] + usage["cacheReadInputTokens"]
     )
     assert openai_usage.completion_tokens == usage["outputTokens"]
     assert openai_usage.total_tokens == usage["totalTokens"]


@@ -0,0 +1,102 @@
+import asyncio
+import datetime
+import json
+import os
+import sys
+from datetime import timezone
+from typing import Any
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../../../..")
+)  # Adds the parent directory to the system path
+
+from unittest.mock import MagicMock, patch
+
+import litellm
+from litellm.proxy.spend_tracking.spend_tracking_utils import (
+    _sanitize_request_body_for_spend_logs_payload,
+)
+
+
+def test_sanitize_request_body_for_spend_logs_payload_basic():
+    request_body = {
+        "messages": [{"role": "user", "content": "Hello, how are you?"}],
+    }
+    assert _sanitize_request_body_for_spend_logs_payload(request_body) == request_body
+
+
+def test_sanitize_request_body_for_spend_logs_payload_long_string():
+    long_string = "a" * 2000  # Create a string longer than MAX_STRING_LENGTH
+    request_body = {"text": long_string, "normal_text": "short text"}
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert len(sanitized["text"]) == 1000 + len("... (truncated 1000 chars)")
+    assert sanitized["normal_text"] == "short text"
+
+
+def test_sanitize_request_body_for_spend_logs_payload_nested_dict():
+    request_body = {"outer": {"inner": {"text": "a" * 2000, "normal": "short"}}}
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert len(sanitized["outer"]["inner"]["text"]) == 1000 + len(
+        "... (truncated 1000 chars)"
+    )
+    assert sanitized["outer"]["inner"]["normal"] == "short"
+
+
+def test_sanitize_request_body_for_spend_logs_payload_nested_list():
+    request_body = {
+        "items": [{"text": "a" * 2000}, {"text": "short"}, [{"text": "a" * 2000}]]
+    }
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert len(sanitized["items"][0]["text"]) == 1000 + len(
+        "... (truncated 1000 chars)"
+    )
+    assert sanitized["items"][1]["text"] == "short"
+    assert len(sanitized["items"][2][0]["text"]) == 1000 + len(
+        "... (truncated 1000 chars)"
+    )
+
+
+def test_sanitize_request_body_for_spend_logs_payload_non_string_values():
+    request_body = {"number": 42, "boolean": True, "none": None, "float": 3.14}
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert sanitized == request_body
+
+
+def test_sanitize_request_body_for_spend_logs_payload_empty():
+    request_body: dict[str, Any] = {}
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert sanitized == request_body
+
+
+def test_sanitize_request_body_for_spend_logs_payload_mixed_types():
+    request_body = {
+        "text": "a" * 2000,
+        "number": 42,
+        "nested": {"list": ["short", "a" * 2000], "dict": {"key": "a" * 2000}},
+    }
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert len(sanitized["text"]) == 1000 + len("... (truncated 1000 chars)")
+    assert sanitized["number"] == 42
+    assert sanitized["nested"]["list"][0] == "short"
+    assert len(sanitized["nested"]["list"][1]) == 1000 + len(
+        "... (truncated 1000 chars)"
+    )
+    assert len(sanitized["nested"]["dict"]["key"]) == 1000 + len(
+        "... (truncated 1000 chars)"
+    )
+
+
+def test_sanitize_request_body_for_spend_logs_payload_circular_reference():
+    # Create a circular reference
+    a: dict[str, Any] = {}
+    b: dict[str, Any] = {"a": a}
+    a["b"] = b
+
+    # Test that it handles circular reference without infinite recursion
+    sanitized = _sanitize_request_body_for_spend_logs_payload(a)
+    assert sanitized == {
+        "b": {"a": {}}
+    }  # Should return empty dict for circular reference