Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
Fix anthropic prompt caching cost calc + trim logged message in db (#9838)
* fix(spend_tracking_utils.py): prevent logging entire mp4 files to db. Fixes https://github.com/BerriAI/litellm/issues/9732
* fix(anthropic/chat/transformation.py): fix double counting of cache creation input tokens. Fixes https://github.com/BerriAI/litellm/issues/9812
* refactor(anthropic/chat/transformation.py): refactor streaming to use the same usage calculation block as non-streaming, to reduce errors
* fix(bedrock/chat/converse_transformation.py): don't increment prompt tokens with cache_creation_input_tokens
* build: remove redisvl from requirements.txt (temporary)
* fix(spend_tracking_utils.py): handle circular references
* test: update code cov test
* test: update test
parent 00c5c23d97
commit 87733c8193
9 changed files with 216 additions and 63 deletions
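To make the Anthropic cost-calculation fix concrete, here is a minimal sketch (not part of the diff; the token counts are borrowed from the new test_calculate_usage test further down) of how the reported prompt tokens change:

# Anthropic-style usage payload, values taken from the test below.
usage = {
    "input_tokens": 3,
    "cache_creation_input_tokens": 12304,
    "cache_read_input_tokens": 0,
    "output_tokens": 550,
}

# Before this commit: cache-creation tokens were folded into prompt_tokens,
# which double counted them when cost was later computed per token type.
old_prompt_tokens = (
    usage["input_tokens"]
    + usage["cache_creation_input_tokens"]
    + usage["cache_read_input_tokens"]
)  # 12307

# After this commit: only cache-read tokens are added to prompt_tokens;
# cache-creation tokens are carried separately on the Usage object.
new_prompt_tokens = usage["input_tokens"] + usage["cache_read_input_tokens"]  # 3
print(old_prompt_tokens, new_prompt_tokens)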
@@ -4,7 +4,7 @@ Calling + translation logic for anthropic's `/v1/messages` endpoint
 import copy
 import json
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

 import httpx  # type: ignore
@@ -492,29 +492,10 @@ class ModelResponseIterator:
         return False

     def _handle_usage(self, anthropic_usage_chunk: Union[dict, UsageDelta]) -> Usage:
-        usage_block = Usage(
-            prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
-            completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
-            total_tokens=anthropic_usage_chunk.get("input_tokens", 0)
-            + anthropic_usage_chunk.get("output_tokens", 0),
-        )
-
-        cache_creation_input_tokens = anthropic_usage_chunk.get(
-            "cache_creation_input_tokens"
-        )
-        if cache_creation_input_tokens is not None and isinstance(
-            cache_creation_input_tokens, int
-        ):
-            usage_block["cache_creation_input_tokens"] = cache_creation_input_tokens
-
-        cache_read_input_tokens = anthropic_usage_chunk.get("cache_read_input_tokens")
-        if cache_read_input_tokens is not None and isinstance(
-            cache_read_input_tokens, int
-        ):
-            usage_block["cache_read_input_tokens"] = cache_read_input_tokens
-
-        return usage_block
+        return AnthropicConfig().calculate_usage(
+            usage_object=cast(dict, anthropic_usage_chunk), reasoning_content=None
+        )

     def _content_block_delta_helper(
         self, chunk: dict
     ) -> Tuple[
@@ -682,6 +682,45 @@ class AnthropicConfig(BaseConfig):
                 reasoning_content += block["thinking"]
         return text_content, citations, thinking_blocks, reasoning_content, tool_calls

+    def calculate_usage(
+        self, usage_object: dict, reasoning_content: Optional[str]
+    ) -> Usage:
+        prompt_tokens = usage_object.get("input_tokens", 0)
+        completion_tokens = usage_object.get("output_tokens", 0)
+        _usage = usage_object
+        cache_creation_input_tokens: int = 0
+        cache_read_input_tokens: int = 0
+
+        if "cache_creation_input_tokens" in _usage:
+            cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
+        if "cache_read_input_tokens" in _usage:
+            cache_read_input_tokens = _usage["cache_read_input_tokens"]
+            prompt_tokens += cache_read_input_tokens
+
+        prompt_tokens_details = PromptTokensDetailsWrapper(
+            cached_tokens=cache_read_input_tokens
+        )
+        completion_token_details = (
+            CompletionTokensDetailsWrapper(
+                reasoning_tokens=token_counter(
+                    text=reasoning_content, count_response_tokens=True
+                )
+            )
+            if reasoning_content
+            else None
+        )
+        total_tokens = prompt_tokens + completion_tokens
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+            prompt_tokens_details=prompt_tokens_details,
+            cache_creation_input_tokens=cache_creation_input_tokens,
+            cache_read_input_tokens=cache_read_input_tokens,
+            completion_tokens_details=completion_token_details,
+        )
+        return usage
+
     def transform_response(
         self,
         model: str,
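A minimal sketch of how the new helper is meant to be called; the payload values are invented for illustration, and the import path is assumed from the commit message. Both the streaming _handle_usage path and transform_response now route through this one method.

from litellm.llms.anthropic.chat.transformation import AnthropicConfig  # path assumed from the commit message

config = AnthropicConfig()
usage = config.calculate_usage(
    usage_object={
        "input_tokens": 10,
        "cache_read_input_tokens": 500,
        "cache_creation_input_tokens": 0,
        "output_tokens": 40,
    },
    reasoning_content=None,
)
# Expected, per the logic above: prompt_tokens == 510 (10 + 500 cached reads),
# prompt_tokens_details.cached_tokens == 500, completion_tokens == 40,
# and cache_creation_input_tokens tracked separately rather than added in.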
@@ -772,45 +811,14 @@ class AnthropicConfig(BaseConfig):
         )

         ## CALCULATING USAGE
-        prompt_tokens = completion_response["usage"]["input_tokens"]
-        completion_tokens = completion_response["usage"]["output_tokens"]
-        _usage = completion_response["usage"]
-        cache_creation_input_tokens: int = 0
-        cache_read_input_tokens: int = 0
+        usage = self.calculate_usage(
+            usage_object=completion_response["usage"],
+            reasoning_content=reasoning_content,
+        )
+        setattr(model_response, "usage", usage)  # type: ignore

         model_response.created = int(time.time())
         model_response.model = completion_response["model"]
-        if "cache_creation_input_tokens" in _usage:
-            cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
-            prompt_tokens += cache_creation_input_tokens
-        if "cache_read_input_tokens" in _usage:
-            cache_read_input_tokens = _usage["cache_read_input_tokens"]
-            prompt_tokens += cache_read_input_tokens
-
-        prompt_tokens_details = PromptTokensDetailsWrapper(
-            cached_tokens=cache_read_input_tokens
-        )
-        completion_token_details = (
-            CompletionTokensDetailsWrapper(
-                reasoning_tokens=token_counter(
-                    text=reasoning_content, count_response_tokens=True
-                )
-            )
-            if reasoning_content
-            else None
-        )
-        total_tokens = prompt_tokens + completion_tokens
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-            prompt_tokens_details=prompt_tokens_details,
-            cache_creation_input_tokens=cache_creation_input_tokens,
-            cache_read_input_tokens=cache_read_input_tokens,
-            completion_tokens_details=completion_token_details,
-        )
-
-        setattr(model_response, "usage", usage)  # type: ignore

         model_response._hidden_params = _hidden_params
         return model_response
@@ -653,8 +653,10 @@ class AmazonConverseConfig(BaseConfig):
             cache_read_input_tokens = usage["cacheReadInputTokens"]
             input_tokens += cache_read_input_tokens
         if "cacheWriteInputTokens" in usage:
+            """
+            Do not increment prompt_tokens with cacheWriteInputTokens
+            """
             cache_creation_input_tokens = usage["cacheWriteInputTokens"]
-            input_tokens += cache_creation_input_tokens

         prompt_tokens_details = PromptTokensDetailsWrapper(
             cached_tokens=cache_read_input_tokens
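The Bedrock Converse change is the same idea. A small sketch (token counts invented) of the accounting the updated test_transform_usage further down now asserts:

# Hypothetical Converse usage block; key names follow the Bedrock Converse API.
usage = {
    "inputTokens": 12,
    "outputTokens": 56,
    "cacheReadInputTokens": 5,
    "cacheWriteInputTokens": 10,
}

# prompt_tokens now includes cache reads but no longer cache writes:
prompt_tokens = usage["inputTokens"] + usage["cacheReadInputTokens"]  # 17
# cacheWriteInputTokens is still reported, as cache_creation_input_tokens on the Usage object.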
@@ -45,3 +45,6 @@ litellm_settings:
 files_settings:
   - custom_llm_provider: gemini
     api_key: os.environ/GEMINI_API_KEY
+
+general_settings:
+  store_prompts_in_spend_logs: true
@@ -360,6 +360,39 @@ def _get_messages_for_spend_logs_payload(
     return "{}"


+def _sanitize_request_body_for_spend_logs_payload(
+    request_body: dict,
+    visited: Optional[set] = None,
+) -> dict:
+    """
+    Recursively sanitize request body to prevent logging large base64 strings or other large values.
+    Truncates strings longer than 1000 characters and handles nested dictionaries.
+    """
+    MAX_STRING_LENGTH = 1000
+
+    if visited is None:
+        visited = set()
+
+    # Get the object's memory address to track visited objects
+    obj_id = id(request_body)
+    if obj_id in visited:
+        return {}
+    visited.add(obj_id)
+
+    def _sanitize_value(value: Any) -> Any:
+        if isinstance(value, dict):
+            return _sanitize_request_body_for_spend_logs_payload(value, visited)
+        elif isinstance(value, list):
+            return [_sanitize_value(item) for item in value]
+        elif isinstance(value, str):
+            if len(value) > MAX_STRING_LENGTH:
+                return f"{value[:MAX_STRING_LENGTH]}... (truncated {len(value) - MAX_STRING_LENGTH} chars)"
+            return value
+        return value
+
+    return {k: _sanitize_value(v) for k, v in request_body.items()}
+
+
 def _add_proxy_server_request_to_metadata(
     metadata: dict,
     litellm_params: dict,
@@ -373,6 +406,7 @@ def _add_proxy_server_request_to_metadata(
     )
     if _proxy_server_request is not None:
         _request_body = _proxy_server_request.get("body", {}) or {}
+        _request_body = _sanitize_request_body_for_spend_logs_payload(_request_body)
         _request_body_json_str = json.dumps(_request_body, default=str)
         metadata["proxy_server_request"] = _request_body_json_str
     return metadata
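To illustrate why the sanitizer addresses the linked mp4 issue, a minimal sketch (the payload and model name are fabricated) of what now happens to an oversized request body before it is written to the spend logs:

fake_video_b64 = "AAAA" * 100_000  # ~400k characters, standing in for a base64-encoded mp4

request_body = {
    "model": "claude-3-5-sonnet",  # hypothetical
    "messages": [{"role": "user", "content": fake_video_b64}],
}

sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
content = sanitized["messages"][0]["content"]
# Only the first 1000 characters survive, plus a truncation marker, so the
# logged proxy_server_request stays small instead of storing the whole file.
assert content.endswith("... (truncated 399000 chars)")
assert len(content) == 1000 + len("... (truncated 399000 chars)")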
@@ -16,6 +16,8 @@ IGNORE_FUNCTIONS = [
     "_transform_prompt",
     "mask_dict",
     "_serialize",  # we now set a max depth for this
+    "_sanitize_request_body_for_spend_logs_payload",  # testing added for circular reference
+    "_sanitize_value",  # testing added for circular reference
 ]

@@ -33,3 +33,26 @@ def test_response_format_transformation_unit_test():
         "agent_doing": {"title": "Agent Doing", "type": "string"}
     }
     print(result)
+
+
+def test_calculate_usage():
+    """
+    Do not include cache_creation_input_tokens in the prompt_tokens
+
+    Fixes https://github.com/BerriAI/litellm/issues/9812
+    """
+    config = AnthropicConfig()
+
+    usage_object = {
+        "input_tokens": 3,
+        "cache_creation_input_tokens": 12304,
+        "cache_read_input_tokens": 0,
+        "output_tokens": 550,
+    }
+    usage = config.calculate_usage(usage_object=usage_object, reasoning_content=None)
+    assert usage.prompt_tokens == 3
+    assert usage.completion_tokens == 550
+    assert usage.total_tokens == 3 + 550
+    assert usage.prompt_tokens_details.cached_tokens == 0
+    assert usage._cache_creation_input_tokens == 12304
+    assert usage._cache_read_input_tokens == 0
@@ -30,9 +30,7 @@ def test_transform_usage():
     openai_usage = config._transform_usage(usage)
     assert (
         openai_usage.prompt_tokens
-        == usage["inputTokens"]
-        + usage["cacheWriteInputTokens"]
-        + usage["cacheReadInputTokens"]
+        == usage["inputTokens"] + usage["cacheReadInputTokens"]
     )
    assert openai_usage.completion_tokens == usage["outputTokens"]
    assert openai_usage.total_tokens == usage["totalTokens"]
tests/litellm/proxy/spend_tracking/test_spend_tracking_utils.py (new file, 102 lines)
@@ -0,0 +1,102 @@
+import asyncio
+import datetime
+import json
+import os
+import sys
+from datetime import timezone
+from typing import Any
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../../../..")
+)  # Adds the parent directory to the system path
+
+from unittest.mock import MagicMock, patch
+
+import litellm
+from litellm.proxy.spend_tracking.spend_tracking_utils import (
+    _sanitize_request_body_for_spend_logs_payload,
+)
+
+
+def test_sanitize_request_body_for_spend_logs_payload_basic():
+    request_body = {
+        "messages": [{"role": "user", "content": "Hello, how are you?"}],
+    }
+    assert _sanitize_request_body_for_spend_logs_payload(request_body) == request_body
+
+
+def test_sanitize_request_body_for_spend_logs_payload_long_string():
+    long_string = "a" * 2000  # Create a string longer than MAX_STRING_LENGTH
+    request_body = {"text": long_string, "normal_text": "short text"}
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert len(sanitized["text"]) == 1000 + len("... (truncated 1000 chars)")
+    assert sanitized["normal_text"] == "short text"
+
+
+def test_sanitize_request_body_for_spend_logs_payload_nested_dict():
+    request_body = {"outer": {"inner": {"text": "a" * 2000, "normal": "short"}}}
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert len(sanitized["outer"]["inner"]["text"]) == 1000 + len(
+        "... (truncated 1000 chars)"
+    )
+    assert sanitized["outer"]["inner"]["normal"] == "short"
+
+
+def test_sanitize_request_body_for_spend_logs_payload_nested_list():
+    request_body = {
+        "items": [{"text": "a" * 2000}, {"text": "short"}, [{"text": "a" * 2000}]]
+    }
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert len(sanitized["items"][0]["text"]) == 1000 + len(
+        "... (truncated 1000 chars)"
+    )
+    assert sanitized["items"][1]["text"] == "short"
+    assert len(sanitized["items"][2][0]["text"]) == 1000 + len(
+        "... (truncated 1000 chars)"
+    )
+
+
+def test_sanitize_request_body_for_spend_logs_payload_non_string_values():
+    request_body = {"number": 42, "boolean": True, "none": None, "float": 3.14}
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert sanitized == request_body
+
+
+def test_sanitize_request_body_for_spend_logs_payload_empty():
+    request_body: dict[str, Any] = {}
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert sanitized == request_body
+
+
+def test_sanitize_request_body_for_spend_logs_payload_mixed_types():
+    request_body = {
+        "text": "a" * 2000,
+        "number": 42,
+        "nested": {"list": ["short", "a" * 2000], "dict": {"key": "a" * 2000}},
+    }
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert len(sanitized["text"]) == 1000 + len("... (truncated 1000 chars)")
+    assert sanitized["number"] == 42
+    assert sanitized["nested"]["list"][0] == "short"
+    assert len(sanitized["nested"]["list"][1]) == 1000 + len(
+        "... (truncated 1000 chars)"
+    )
+    assert len(sanitized["nested"]["dict"]["key"]) == 1000 + len(
+        "... (truncated 1000 chars)"
+    )
+
+
+def test_sanitize_request_body_for_spend_logs_payload_circular_reference():
+    # Create a circular reference
+    a: dict[str, Any] = {}
+    b: dict[str, Any] = {"a": a}
+    a["b"] = b
+
+    # Test that it handles circular reference without infinite recursion
+    sanitized = _sanitize_request_body_for_spend_logs_payload(a)
+    assert sanitized == {
+        "b": {"a": {}}
+    }  # Should return empty dict for circular reference