diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py
index 44567facf9..ebb8650044 100644
--- a/litellm/llms/anthropic/chat/handler.py
+++ b/litellm/llms/anthropic/chat/handler.py
@@ -4,7 +4,7 @@ Calling + translation logic for anthropic's `/v1/messages` endpoint
 
 import copy
 import json
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
 
 import httpx  # type: ignore
 
@@ -492,29 +492,10 @@ class ModelResponseIterator:
         return False
 
     def _handle_usage(self, anthropic_usage_chunk: Union[dict, UsageDelta]) -> Usage:
-        usage_block = Usage(
-            prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
-            completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
-            total_tokens=anthropic_usage_chunk.get("input_tokens", 0)
-            + anthropic_usage_chunk.get("output_tokens", 0),
+        return AnthropicConfig().calculate_usage(
+            usage_object=cast(dict, anthropic_usage_chunk), reasoning_content=None
         )
-
-        cache_creation_input_tokens = anthropic_usage_chunk.get(
-            "cache_creation_input_tokens"
-        )
-        if cache_creation_input_tokens is not None and isinstance(
-            cache_creation_input_tokens, int
-        ):
-            usage_block["cache_creation_input_tokens"] = cache_creation_input_tokens
-
-        cache_read_input_tokens = anthropic_usage_chunk.get("cache_read_input_tokens")
-        if cache_read_input_tokens is not None and isinstance(
-            cache_read_input_tokens, int
-        ):
-            usage_block["cache_read_input_tokens"] = cache_read_input_tokens
-
-        return usage_block
 
     def _content_block_delta_helper(
         self, chunk: dict
     ) -> Tuple[
diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py
index 9b66249630..96da34a855 100644
--- a/litellm/llms/anthropic/chat/transformation.py
+++ b/litellm/llms/anthropic/chat/transformation.py
@@ -682,6 +682,45 @@ class AnthropicConfig(BaseConfig):
                     reasoning_content += block["thinking"]
         return text_content, citations, thinking_blocks, reasoning_content, tool_calls
 
+    def calculate_usage(
+        self, usage_object: dict, reasoning_content: Optional[str]
+    ) -> Usage:
+        prompt_tokens = usage_object.get("input_tokens", 0)
+        completion_tokens = usage_object.get("output_tokens", 0)
+        _usage = usage_object
+        cache_creation_input_tokens: int = 0
+        cache_read_input_tokens: int = 0
+
+        if "cache_creation_input_tokens" in _usage:
+            cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
+        if "cache_read_input_tokens" in _usage:
+            cache_read_input_tokens = _usage["cache_read_input_tokens"]
+            prompt_tokens += cache_read_input_tokens
+
+        prompt_tokens_details = PromptTokensDetailsWrapper(
+            cached_tokens=cache_read_input_tokens
+        )
+        completion_token_details = (
+            CompletionTokensDetailsWrapper(
+                reasoning_tokens=token_counter(
+                    text=reasoning_content, count_response_tokens=True
+                )
+            )
+            if reasoning_content
+            else None
+        )
+        total_tokens = prompt_tokens + completion_tokens
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+            prompt_tokens_details=prompt_tokens_details,
+            cache_creation_input_tokens=cache_creation_input_tokens,
+            cache_read_input_tokens=cache_read_input_tokens,
+            completion_tokens_details=completion_token_details,
+        )
+        return usage
+
     def transform_response(
         self,
         model: str,
@@ -772,45 +811,14 @@ class AnthropicConfig(BaseConfig):
             )
 
         ## CALCULATING USAGE
-        prompt_tokens = completion_response["usage"]["input_tokens"]
-        completion_tokens = completion_response["usage"]["output_tokens"]
-        _usage = completion_response["usage"]
-        cache_creation_input_tokens: int = 0
-        cache_read_input_tokens: int = 0
+        usage = self.calculate_usage(
+            usage_object=completion_response["usage"],
+            reasoning_content=reasoning_content,
+        )
+        setattr(model_response, "usage", usage)  # type: ignore
 
         model_response.created = int(time.time())
         model_response.model = completion_response["model"]
-        if "cache_creation_input_tokens" in _usage:
-            cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
-            prompt_tokens += cache_creation_input_tokens
-        if "cache_read_input_tokens" in _usage:
-            cache_read_input_tokens = _usage["cache_read_input_tokens"]
-            prompt_tokens += cache_read_input_tokens
-
-        prompt_tokens_details = PromptTokensDetailsWrapper(
-            cached_tokens=cache_read_input_tokens
-        )
-        completion_token_details = (
-            CompletionTokensDetailsWrapper(
-                reasoning_tokens=token_counter(
-                    text=reasoning_content, count_response_tokens=True
-                )
-            )
-            if reasoning_content
-            else None
-        )
-        total_tokens = prompt_tokens + completion_tokens
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-            prompt_tokens_details=prompt_tokens_details,
-            cache_creation_input_tokens=cache_creation_input_tokens,
-            cache_read_input_tokens=cache_read_input_tokens,
-            completion_tokens_details=completion_token_details,
-        )
-
-        setattr(model_response, "usage", usage)  # type: ignore
+
         model_response._hidden_params = _hidden_params
         return model_response
diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py
index fbe2dc4937..76ea51f435 100644
--- a/litellm/llms/bedrock/chat/converse_transformation.py
+++ b/litellm/llms/bedrock/chat/converse_transformation.py
@@ -653,8 +653,10 @@ class AmazonConverseConfig(BaseConfig):
             cache_read_input_tokens = usage["cacheReadInputTokens"]
             input_tokens += cache_read_input_tokens
         if "cacheWriteInputTokens" in usage:
+            """
+            Do not increment prompt_tokens with cacheWriteInputTokens
+            """
             cache_creation_input_tokens = usage["cacheWriteInputTokens"]
-            input_tokens += cache_creation_input_tokens
 
         prompt_tokens_details = PromptTokensDetailsWrapper(
             cached_tokens=cache_read_input_tokens
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 511a8cde94..f64e243489 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -44,4 +44,7 @@ litellm_settings:
 
 files_settings:
   - custom_llm_provider: gemini
-    api_key: os.environ/GEMINI_API_KEY
\ No newline at end of file
+    api_key: os.environ/GEMINI_API_KEY
+
+general_settings:
+  store_prompts_in_spend_logs: true
\ No newline at end of file
diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py
index 096c5191b1..51991e339e 100644
--- a/litellm/proxy/spend_tracking/spend_tracking_utils.py
+++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py
@@ -360,6 +360,39 @@ def _get_messages_for_spend_logs_payload(
         return "{}"
 
 
+def _sanitize_request_body_for_spend_logs_payload(
+    request_body: dict,
+    visited: Optional[set] = None,
+) -> dict:
+    """
+    Recursively sanitize request body to prevent logging large base64 strings or other large values.
+    Truncates strings longer than 1000 characters and handles nested dictionaries.
+    """
+    MAX_STRING_LENGTH = 1000
+
+    if visited is None:
+        visited = set()
+
+    # Get the object's memory address to track visited objects
+    obj_id = id(request_body)
+    if obj_id in visited:
+        return {}
+    visited.add(obj_id)
+
+    def _sanitize_value(value: Any) -> Any:
+        if isinstance(value, dict):
+            return _sanitize_request_body_for_spend_logs_payload(value, visited)
+        elif isinstance(value, list):
+            return [_sanitize_value(item) for item in value]
+        elif isinstance(value, str):
+            if len(value) > MAX_STRING_LENGTH:
+                return f"{value[:MAX_STRING_LENGTH]}... (truncated {len(value) - MAX_STRING_LENGTH} chars)"
+            return value
+        return value
+
+    return {k: _sanitize_value(v) for k, v in request_body.items()}
+
+
 def _add_proxy_server_request_to_metadata(
     metadata: dict,
     litellm_params: dict,
@@ -373,6 +406,7 @@ def _add_proxy_server_request_to_metadata(
     )
     if _proxy_server_request is not None:
         _request_body = _proxy_server_request.get("body", {}) or {}
+        _request_body = _sanitize_request_body_for_spend_logs_payload(_request_body)
         _request_body_json_str = json.dumps(_request_body, default=str)
         metadata["proxy_server_request"] = _request_body_json_str
     return metadata
diff --git a/tests/code_coverage_tests/recursive_detector.py b/tests/code_coverage_tests/recursive_detector.py
index 48fe604dbc..c0761975a2 100644
--- a/tests/code_coverage_tests/recursive_detector.py
+++ b/tests/code_coverage_tests/recursive_detector.py
@@ -16,6 +16,8 @@ IGNORE_FUNCTIONS = [
     "_transform_prompt",
     "mask_dict",
     "_serialize",  # we now set a max depth for this
+    "_sanitize_request_body_for_spend_logs_payload",  # testing added for circular reference
+    "_sanitize_value",  # testing added for circular reference
 ]
 
 
diff --git a/tests/litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py b/tests/litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py
index 04f2728284..9f672110a4 100644
--- a/tests/litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py
+++ b/tests/litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py
@@ -33,3 +33,26 @@ def test_response_format_transformation_unit_test():
         "agent_doing": {"title": "Agent Doing", "type": "string"}
     }
     print(result)
+
+
+def test_calculate_usage():
+    """
+    Do not include cache_creation_input_tokens in the prompt_tokens
+
+    Fixes https://github.com/BerriAI/litellm/issues/9812
+    """
+    config = AnthropicConfig()
+
+    usage_object = {
+        "input_tokens": 3,
+        "cache_creation_input_tokens": 12304,
+        "cache_read_input_tokens": 0,
+        "output_tokens": 550,
+    }
+    usage = config.calculate_usage(usage_object=usage_object, reasoning_content=None)
+    assert usage.prompt_tokens == 3
+    assert usage.completion_tokens == 550
+    assert usage.total_tokens == 3 + 550
+    assert usage.prompt_tokens_details.cached_tokens == 0
+    assert usage._cache_creation_input_tokens == 12304
+    assert usage._cache_read_input_tokens == 0
diff --git a/tests/litellm/llms/bedrock/chat/test_converse_transformation.py b/tests/litellm/llms/bedrock/chat/test_converse_transformation.py
index e912ada8ff..5390daa1a0 100644
--- a/tests/litellm/llms/bedrock/chat/test_converse_transformation.py
+++ b/tests/litellm/llms/bedrock/chat/test_converse_transformation.py
@@ -30,9 +30,7 @@ def test_transform_usage():
     openai_usage = config._transform_usage(usage)
     assert (
         openai_usage.prompt_tokens
-        == usage["inputTokens"]
-        + usage["cacheWriteInputTokens"]
-        + usage["cacheReadInputTokens"]
+        == usage["inputTokens"] + usage["cacheReadInputTokens"]
     )
     assert openai_usage.completion_tokens == usage["outputTokens"]
     assert openai_usage.total_tokens == usage["totalTokens"]
diff --git a/tests/litellm/proxy/spend_tracking/test_spend_tracking_utils.py b/tests/litellm/proxy/spend_tracking/test_spend_tracking_utils.py
new file mode 100644
index 0000000000..2bef2512f3
--- /dev/null
+++ b/tests/litellm/proxy/spend_tracking/test_spend_tracking_utils.py
@@ -0,0 +1,102 @@
+import asyncio
+import datetime
+import json
+import os
+import sys
+from datetime import timezone
+from typing import Any
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../../../..")
+)  # Adds the parent directory to the system path
+
+from unittest.mock import MagicMock, patch
+
+import litellm
+from litellm.proxy.spend_tracking.spend_tracking_utils import (
+    _sanitize_request_body_for_spend_logs_payload,
+)
+
+
+def test_sanitize_request_body_for_spend_logs_payload_basic():
+    request_body = {
+        "messages": [{"role": "user", "content": "Hello, how are you?"}],
+    }
+    assert _sanitize_request_body_for_spend_logs_payload(request_body) == request_body
+
+
+def test_sanitize_request_body_for_spend_logs_payload_long_string():
+    long_string = "a" * 2000  # Create a string longer than MAX_STRING_LENGTH
+    request_body = {"text": long_string, "normal_text": "short text"}
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert len(sanitized["text"]) == 1000 + len("... (truncated 1000 chars)")
+    assert sanitized["normal_text"] == "short text"
+
+
+def test_sanitize_request_body_for_spend_logs_payload_nested_dict():
+    request_body = {"outer": {"inner": {"text": "a" * 2000, "normal": "short"}}}
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert len(sanitized["outer"]["inner"]["text"]) == 1000 + len(
+        "... (truncated 1000 chars)"
+    )
+    assert sanitized["outer"]["inner"]["normal"] == "short"
+
+
+def test_sanitize_request_body_for_spend_logs_payload_nested_list():
+    request_body = {
+        "items": [{"text": "a" * 2000}, {"text": "short"}, [{"text": "a" * 2000}]]
+    }
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert len(sanitized["items"][0]["text"]) == 1000 + len(
+        "... (truncated 1000 chars)"
+    )
+    assert sanitized["items"][1]["text"] == "short"
+    assert len(sanitized["items"][2][0]["text"]) == 1000 + len(
+        "... (truncated 1000 chars)"
+    )
+
+
+def test_sanitize_request_body_for_spend_logs_payload_non_string_values():
+    request_body = {"number": 42, "boolean": True, "none": None, "float": 3.14}
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert sanitized == request_body
+
+
+def test_sanitize_request_body_for_spend_logs_payload_empty():
+    request_body: dict[str, Any] = {}
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert sanitized == request_body
+
+
+def test_sanitize_request_body_for_spend_logs_payload_mixed_types():
+    request_body = {
+        "text": "a" * 2000,
+        "number": 42,
+        "nested": {"list": ["short", "a" * 2000], "dict": {"key": "a" * 2000}},
+    }
+    sanitized = _sanitize_request_body_for_spend_logs_payload(request_body)
+    assert len(sanitized["text"]) == 1000 + len("... (truncated 1000 chars)")
+    assert sanitized["number"] == 42
+    assert sanitized["nested"]["list"][0] == "short"
+    assert len(sanitized["nested"]["list"][1]) == 1000 + len(
+        "... (truncated 1000 chars)"
+    )
+    assert len(sanitized["nested"]["dict"]["key"]) == 1000 + len(
+        "... (truncated 1000 chars)"
+    )
+
+
+def test_sanitize_request_body_for_spend_logs_payload_circular_reference():
+    # Create a circular reference
+    a: dict[str, Any] = {}
+    b: dict[str, Any] = {"a": a}
+    a["b"] = b
+
+    # Test that it handles circular reference without infinite recursion
+    sanitized = _sanitize_request_body_for_spend_logs_payload(a)
+    assert sanitized == {
+        "b": {"a": {}}
+    }  # Should return empty dict for circular reference