mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
Fix anthropic prompt caching cost calc + trim logged message in db (#9838)
* fix(spend_tracking_utils.py): prevent logging entire mp4 files to db Fixes https://github.com/BerriAI/litellm/issues/9732 * fix(anthropic/chat/transformation.py): Fix double counting cache creation input tokens Fixes https://github.com/BerriAI/litellm/issues/9812 * refactor(anthropic/chat/transformation.py): refactor streaming to use same usage calculation block as non-streaming reduce errors * fix(bedrock/chat/converse_transformation.py): don't increment prompt tokens with cache_creation_input_tokens * build: remove redisvl from requirements.txt (temporary) * fix(spend_tracking_utils.py): handle circular references * test: update code cov test * test: update test
This commit is contained in:
parent
00c5c23d97
commit
87733c8193
9 changed files with 216 additions and 63 deletions
|
@ -4,7 +4,7 @@ Calling + translation logic for anthropic's `/v1/messages` endpoint
|
|||
|
||||
import copy
|
||||
import json
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx # type: ignore
|
||||
|
||||
|
@ -492,29 +492,10 @@ class ModelResponseIterator:
|
|||
return False
|
||||
|
||||
def _handle_usage(self, anthropic_usage_chunk: Union[dict, UsageDelta]) -> Usage:
|
||||
usage_block = Usage(
|
||||
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
|
||||
completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
|
||||
total_tokens=anthropic_usage_chunk.get("input_tokens", 0)
|
||||
+ anthropic_usage_chunk.get("output_tokens", 0),
|
||||
return AnthropicConfig().calculate_usage(
|
||||
usage_object=cast(dict, anthropic_usage_chunk), reasoning_content=None
|
||||
)
|
||||
|
||||
cache_creation_input_tokens = anthropic_usage_chunk.get(
|
||||
"cache_creation_input_tokens"
|
||||
)
|
||||
if cache_creation_input_tokens is not None and isinstance(
|
||||
cache_creation_input_tokens, int
|
||||
):
|
||||
usage_block["cache_creation_input_tokens"] = cache_creation_input_tokens
|
||||
|
||||
cache_read_input_tokens = anthropic_usage_chunk.get("cache_read_input_tokens")
|
||||
if cache_read_input_tokens is not None and isinstance(
|
||||
cache_read_input_tokens, int
|
||||
):
|
||||
usage_block["cache_read_input_tokens"] = cache_read_input_tokens
|
||||
|
||||
return usage_block
|
||||
|
||||
def _content_block_delta_helper(
|
||||
self, chunk: dict
|
||||
) -> Tuple[
|
||||
|
|
|
@ -682,6 +682,45 @@ class AnthropicConfig(BaseConfig):
|
|||
reasoning_content += block["thinking"]
|
||||
return text_content, citations, thinking_blocks, reasoning_content, tool_calls
|
||||
|
||||
def calculate_usage(
|
||||
self, usage_object: dict, reasoning_content: Optional[str]
|
||||
) -> Usage:
|
||||
prompt_tokens = usage_object.get("input_tokens", 0)
|
||||
completion_tokens = usage_object.get("output_tokens", 0)
|
||||
_usage = usage_object
|
||||
cache_creation_input_tokens: int = 0
|
||||
cache_read_input_tokens: int = 0
|
||||
|
||||
if "cache_creation_input_tokens" in _usage:
|
||||
cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
|
||||
if "cache_read_input_tokens" in _usage:
|
||||
cache_read_input_tokens = _usage["cache_read_input_tokens"]
|
||||
prompt_tokens += cache_read_input_tokens
|
||||
|
||||
prompt_tokens_details = PromptTokensDetailsWrapper(
|
||||
cached_tokens=cache_read_input_tokens
|
||||
)
|
||||
completion_token_details = (
|
||||
CompletionTokensDetailsWrapper(
|
||||
reasoning_tokens=token_counter(
|
||||
text=reasoning_content, count_response_tokens=True
|
||||
)
|
||||
)
|
||||
if reasoning_content
|
||||
else None
|
||||
)
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
prompt_tokens_details=prompt_tokens_details,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
completion_tokens_details=completion_token_details,
|
||||
)
|
||||
return usage
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -772,45 +811,14 @@ class AnthropicConfig(BaseConfig):
|
|||
)
|
||||
|
||||
## CALCULATING USAGE
|
||||
prompt_tokens = completion_response["usage"]["input_tokens"]
|
||||
completion_tokens = completion_response["usage"]["output_tokens"]
|
||||
_usage = completion_response["usage"]
|
||||
cache_creation_input_tokens: int = 0
|
||||
cache_read_input_tokens: int = 0
|
||||
usage = self.calculate_usage(
|
||||
usage_object=completion_response["usage"],
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
setattr(model_response, "usage", usage) # type: ignore
|
||||
|
||||
model_response.created = int(time.time())
|
||||
model_response.model = completion_response["model"]
|
||||
if "cache_creation_input_tokens" in _usage:
|
||||
cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
|
||||
prompt_tokens += cache_creation_input_tokens
|
||||
if "cache_read_input_tokens" in _usage:
|
||||
cache_read_input_tokens = _usage["cache_read_input_tokens"]
|
||||
prompt_tokens += cache_read_input_tokens
|
||||
|
||||
prompt_tokens_details = PromptTokensDetailsWrapper(
|
||||
cached_tokens=cache_read_input_tokens
|
||||
)
|
||||
completion_token_details = (
|
||||
CompletionTokensDetailsWrapper(
|
||||
reasoning_tokens=token_counter(
|
||||
text=reasoning_content, count_response_tokens=True
|
||||
)
|
||||
)
|
||||
if reasoning_content
|
||||
else None
|
||||
)
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
prompt_tokens_details=prompt_tokens_details,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
completion_tokens_details=completion_token_details,
|
||||
)
|
||||
|
||||
setattr(model_response, "usage", usage) # type: ignore
|
||||
|
||||
model_response._hidden_params = _hidden_params
|
||||
return model_response
|
||||
|
|
|
@ -653,8 +653,10 @@ class AmazonConverseConfig(BaseConfig):
|
|||
cache_read_input_tokens = usage["cacheReadInputTokens"]
|
||||
input_tokens += cache_read_input_tokens
|
||||
if "cacheWriteInputTokens" in usage:
|
||||
"""
|
||||
Do not increment prompt_tokens with cacheWriteInputTokens
|
||||
"""
|
||||
cache_creation_input_tokens = usage["cacheWriteInputTokens"]
|
||||
input_tokens += cache_creation_input_tokens
|
||||
|
||||
prompt_tokens_details = PromptTokensDetailsWrapper(
|
||||
cached_tokens=cache_read_input_tokens
|
||||
|
|
|
@ -44,4 +44,7 @@ litellm_settings:
|
|||
|
||||
files_settings:
|
||||
- custom_llm_provider: gemini
|
||||
api_key: os.environ/GEMINI_API_KEY
|
||||
api_key: os.environ/GEMINI_API_KEY
|
||||
|
||||
general_settings:
|
||||
store_prompts_in_spend_logs: true
|
|
@ -360,6 +360,39 @@ def _get_messages_for_spend_logs_payload(
|
|||
return "{}"
|
||||
|
||||
|
||||
def _sanitize_request_body_for_spend_logs_payload(
|
||||
request_body: dict,
|
||||
visited: Optional[set] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Recursively sanitize request body to prevent logging large base64 strings or other large values.
|
||||
Truncates strings longer than 1000 characters and handles nested dictionaries.
|
||||
"""
|
||||
MAX_STRING_LENGTH = 1000
|
||||
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
# Get the object's memory address to track visited objects
|
||||
obj_id = id(request_body)
|
||||
if obj_id in visited:
|
||||
return {}
|
||||
visited.add(obj_id)
|
||||
|
||||
def _sanitize_value(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return _sanitize_request_body_for_spend_logs_payload(value, visited)
|
||||
elif isinstance(value, list):
|
||||
return [_sanitize_value(item) for item in value]
|
||||
elif isinstance(value, str):
|
||||
if len(value) > MAX_STRING_LENGTH:
|
||||
return f"{value[:MAX_STRING_LENGTH]}... (truncated {len(value) - MAX_STRING_LENGTH} chars)"
|
||||
return value
|
||||
return value
|
||||
|
||||
return {k: _sanitize_value(v) for k, v in request_body.items()}
|
||||
|
||||
|
||||
def _add_proxy_server_request_to_metadata(
|
||||
metadata: dict,
|
||||
litellm_params: dict,
|
||||
|
@ -373,6 +406,7 @@ def _add_proxy_server_request_to_metadata(
|
|||
)
|
||||
if _proxy_server_request is not None:
|
||||
_request_body = _proxy_server_request.get("body", {}) or {}
|
||||
_request_body = _sanitize_request_body_for_spend_logs_payload(_request_body)
|
||||
_request_body_json_str = json.dumps(_request_body, default=str)
|
||||
metadata["proxy_server_request"] = _request_body_json_str
|
||||
return metadata
|
||||
|
|
|
@ -16,6 +16,8 @@ IGNORE_FUNCTIONS = [
|
|||
"_transform_prompt",
|
||||
"mask_dict",
|
||||
"_serialize", # we now set a max depth for this
|
||||
"_sanitize_request_body_for_spend_logs_payload", # testing added for circular reference
|
||||
"_sanitize_value", # testing added for circular reference
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -33,3 +33,26 @@ def test_response_format_transformation_unit_test():
|
|||
"agent_doing": {"title": "Agent Doing", "type": "string"}
|
||||
}
|
||||
print(result)
|
||||
|
||||
|
||||
def test_calculate_usage():
|
||||
"""
|
||||
Do not include cache_creation_input_tokens in the prompt_tokens
|
||||
|
||||
Fixes https://github.com/BerriAI/litellm/issues/9812
|
||||
"""
|
||||
config = AnthropicConfig()
|
||||
|
||||
usage_object = {
|
||||
"input_tokens": 3,
|
||||
"cache_creation_input_tokens": 12304,
|
||||
"cache_read_input_tokens": 0,
|
||||
"output_tokens": 550,
|
||||
}
|
||||
usage = config.calculate_usage(usage_object=usage_object, reasoning_content=None)
|
||||
assert usage.prompt_tokens == 3
|
||||
assert usage.completion_tokens == 550
|
||||
assert usage.total_tokens == 3 + 550
|
||||
assert usage.prompt_tokens_details.cached_tokens == 0
|
||||
assert usage._cache_creation_input_tokens == 12304
|
||||
assert usage._cache_read_input_tokens == 0
|
||||
|
|
|
@ -30,9 +30,7 @@ def test_transform_usage():
|
|||
openai_usage = config._transform_usage(usage)
|
||||
assert (
|
||||
openai_usage.prompt_tokens
|
||||
== usage["inputTokens"]
|
||||
+ usage["cacheWriteInputTokens"]
|
||||
+ usage["cacheReadInputTokens"]
|
||||
== usage["inputTokens"] + usage["cacheReadInputTokens"]
|
||||
)
|
||||
assert openai_usage.completion_tokens == usage["outputTokens"]
|
||||
assert openai_usage.total_tokens == usage["totalTokens"]
|
||||
|
|
102
tests/litellm/proxy/spend_tracking/test_spend_tracking_utils.py
Normal file
102
tests/litellm/proxy/spend_tracking/test_spend_tracking_utils.py
Normal file
|
@ -0,0 +1,102 @@
|
|||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import litellm
|
||||
from litellm.proxy.spend_tracking.spend_tracking_utils import (
|
||||
_sanitize_request_body_for_spend_logs_payload,
|
||||
)
|
||||
|
||||
|
||||
def test_sanitize_request_body_for_spend_logs_payload_basic():
|
||||
request_body = {
|
||||
"messages": [{"role": "user", "content": "Hello, how are you?"}],
|
||||
}
|
||||
assert _sanitize_request_body_for_spend_logs_payload(request_body) == request_body
|
||||
|
||||
|
||||
def test_sanitize_request_body_for_spend_logs_payload_long_string():
|
||||
long_string = "a" * 2000 # Create a string longer than MAX_STRING_LENGTH
|
||||
request_body = {"text": long_string, "normal_text": "short text"}
|
||||
sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
|
||||
assert len(sanitized["text"]) == 1000 + len("... (truncated 1000 chars)")
|
||||
assert sanitized["normal_text"] == "short text"
|
||||
|
||||
|
||||
def test_sanitize_request_body_for_spend_logs_payload_nested_dict():
|
||||
request_body = {"outer": {"inner": {"text": "a" * 2000, "normal": "short"}}}
|
||||
sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
|
||||
assert len(sanitized["outer"]["inner"]["text"]) == 1000 + len(
|
||||
"... (truncated 1000 chars)"
|
||||
)
|
||||
assert sanitized["outer"]["inner"]["normal"] == "short"
|
||||
|
||||
|
||||
def test_sanitize_request_body_for_spend_logs_payload_nested_list():
|
||||
request_body = {
|
||||
"items": [{"text": "a" * 2000}, {"text": "short"}, [{"text": "a" * 2000}]]
|
||||
}
|
||||
sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
|
||||
assert len(sanitized["items"][0]["text"]) == 1000 + len(
|
||||
"... (truncated 1000 chars)"
|
||||
)
|
||||
assert sanitized["items"][1]["text"] == "short"
|
||||
assert len(sanitized["items"][2][0]["text"]) == 1000 + len(
|
||||
"... (truncated 1000 chars)"
|
||||
)
|
||||
|
||||
|
||||
def test_sanitize_request_body_for_spend_logs_payload_non_string_values():
|
||||
request_body = {"number": 42, "boolean": True, "none": None, "float": 3.14}
|
||||
sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
|
||||
assert sanitized == request_body
|
||||
|
||||
|
||||
def test_sanitize_request_body_for_spend_logs_payload_empty():
|
||||
request_body: dict[str, Any] = {}
|
||||
sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
|
||||
assert sanitized == request_body
|
||||
|
||||
|
||||
def test_sanitize_request_body_for_spend_logs_payload_mixed_types():
|
||||
request_body = {
|
||||
"text": "a" * 2000,
|
||||
"number": 42,
|
||||
"nested": {"list": ["short", "a" * 2000], "dict": {"key": "a" * 2000}},
|
||||
}
|
||||
sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
|
||||
assert len(sanitized["text"]) == 1000 + len("... (truncated 1000 chars)")
|
||||
assert sanitized["number"] == 42
|
||||
assert sanitized["nested"]["list"][0] == "short"
|
||||
assert len(sanitized["nested"]["list"][1]) == 1000 + len(
|
||||
"... (truncated 1000 chars)"
|
||||
)
|
||||
assert len(sanitized["nested"]["dict"]["key"]) == 1000 + len(
|
||||
"... (truncated 1000 chars)"
|
||||
)
|
||||
|
||||
|
||||
def test_sanitize_request_body_for_spend_logs_payload_circular_reference():
|
||||
# Create a circular reference
|
||||
a: dict[str, Any] = {}
|
||||
b: dict[str, Any] = {"a": a}
|
||||
a["b"] = b
|
||||
|
||||
# Test that it handles circular reference without infinite recursion
|
||||
sanitized = _sanitize_request_body_for_spend_logs_payload(a)
|
||||
assert sanitized == {
|
||||
"b": {"a": {}}
|
||||
} # Should return empty dict for circular reference
|
Loading…
Add table
Add a link
Reference in a new issue