import hashlib
import json
import secrets
from datetime import datetime
from datetime import datetime as dt
from datetime import timezone
from typing import Any, List, Optional, cast

from pydantic import BaseModel

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
from litellm.proxy.utils import PrismaClient, hash_token
from litellm.types.utils import StandardLoggingMCPToolCall, StandardLoggingPayload
from litellm.utils import get_end_user_id_for_cost_tracking


def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
    if _master_key is None:
        return False

    ## string comparison
    is_master_key = secrets.compare_digest(api_key, _master_key)
    if is_master_key:
        return True

    ## hash comparison
    is_master_key = secrets.compare_digest(api_key, hash_token(_master_key))
    if is_master_key:
        return True

    return False

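# Illustrative behavior (not executed; values hypothetical): the check accepts
# either the raw master key or its hashed form, both via constant-time
# secrets.compare_digest:
#
#   >>> _is_master_key("sk-1234", "sk-1234")
#   True
#   >>> _is_master_key(hash_token("sk-1234"), "sk-1234")
#   True
#   >>> _is_master_key("sk-1234", None)
#   False

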
def _get_spend_logs_metadata(
    metadata: Optional[dict],
    applied_guardrails: Optional[List[str]] = None,
    batch_models: Optional[List[str]] = None,
    mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None,
) -> SpendLogsMetadata:
    if metadata is None:
        return SpendLogsMetadata(
            user_api_key=None,
            user_api_key_alias=None,
            user_api_key_team_id=None,
            user_api_key_org_id=None,
            user_api_key_user_id=None,
            user_api_key_team_alias=None,
            spend_logs_metadata=None,
            requester_ip_address=None,
            additional_usage_values=None,
            applied_guardrails=None,
            status="success",
            error_information=None,
            proxy_server_request=None,
            batch_models=None,
            mcp_tool_call_metadata=None,
        )
    verbose_proxy_logger.debug(
        "getting payload for SpendLogs, available keys in metadata: "
        + str(list(metadata.keys()))
    )

    # Filter the metadata dictionary to include only the specified keys
    clean_metadata = SpendLogsMetadata(
        **{  # type: ignore
            key: metadata[key]
            for key in SpendLogsMetadata.__annotations__.keys()
            if key in metadata
        }
    )
    clean_metadata["applied_guardrails"] = applied_guardrails
    clean_metadata["batch_models"] = batch_models
    clean_metadata["mcp_tool_call_metadata"] = mcp_tool_call_metadata
    return clean_metadata

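# Illustrative sketch (values hypothetical): only keys declared on
# SpendLogsMetadata survive the filter, so unrelated metadata entries are
# dropped:
#
#   >>> cleaned = _get_spend_logs_metadata(
#   ...     {"user_api_key": "hashed-key", "internal_debug_blob": "..."}
#   ... )
#   >>> "internal_debug_blob" in cleaned
#   False
#   >>> cleaned["user_api_key"]
#   'hashed-key'

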
def generate_hash_from_response(response_obj: Any) -> str:
    """
    Generate a stable hash from a response object.

    Args:
        response_obj: The response object to hash (can be dict, list, etc.)

    Returns:
        A hex string representation of the MD5 hash
    """
    try:
        # Create a stable JSON string of the entire response object.
        # Sort keys to ensure consistent ordering.
        json_str = json.dumps(response_obj, sort_keys=True)

        # Generate a hash of the response object
        unique_hash = hashlib.md5(json_str.encode()).hexdigest()
        return unique_hash
    except Exception:
        # Return a fallback hash if serialization fails
        return hashlib.md5(str(response_obj).encode()).hexdigest()

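# Illustrative: because the JSON dump uses sort_keys=True, key order does not
# affect the hash; objects that fail JSON serialization fall back to hashing
# their str() form.
#
#   >>> generate_hash_from_response({"a": 1, "b": 2}) == \
#   ...     generate_hash_from_response({"b": 2, "a": 1})
#   True

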
def get_spend_logs_id(
    call_type: str, response_obj: dict, kwargs: dict
) -> Optional[str]:
    if call_type == "aretrieve_batch":
        # Generate a hash from the response object
        id: Optional[str] = generate_hash_from_response(response_obj)
    else:
        id = cast(Optional[str], response_obj.get("id")) or cast(
            Optional[str], kwargs.get("litellm_call_id")
        )
    return id

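# Illustrative fallback order (values hypothetical): for non-batch calls the
# response id wins, then the litellm call id; "aretrieve_batch" responses are
# hashed via generate_hash_from_response instead.
#
#   >>> get_spend_logs_id("acompletion", {"id": "chatcmpl-123"}, {})
#   'chatcmpl-123'
#   >>> get_spend_logs_id("acompletion", {}, {"litellm_call_id": "call-abc"})
#   'call-abc'

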
def get_logging_payload(  # noqa: PLR0915
    kwargs, response_obj, start_time, end_time
) -> SpendLogsPayload:
    from litellm.proxy.proxy_server import general_settings, master_key

    if kwargs is None:
        kwargs = {}
    if response_obj is None or (
        not isinstance(response_obj, BaseModel) and not isinstance(response_obj, dict)
    ):
        response_obj = {}
    # standardize this function to be used across s3, dynamoDB, and langfuse logging
    litellm_params = kwargs.get("litellm_params", {})
    metadata = get_litellm_metadata_from_kwargs(kwargs)
    metadata = _add_proxy_server_request_to_metadata(
        metadata=metadata, litellm_params=litellm_params
    )
    completion_start_time = kwargs.get("completion_start_time", end_time)
    call_type = kwargs.get("call_type")
    cache_hit = kwargs.get("cache_hit", False)
    usage = cast(dict, response_obj).get("usage", None) or {}
    if isinstance(usage, litellm.Usage):
        usage = dict(usage)

    if isinstance(response_obj, dict):
        response_obj_dict = response_obj
    elif isinstance(response_obj, BaseModel):
        response_obj_dict = response_obj.model_dump()
    else:
        response_obj_dict = {}

    id = get_spend_logs_id(call_type or "acompletion", response_obj_dict, kwargs)
    standard_logging_payload = cast(
        Optional[StandardLoggingPayload], kwargs.get("standard_logging_object", None)
    )

    end_user_id = get_end_user_id_for_cost_tracking(litellm_params)

    api_key = metadata.get("user_api_key", "")

    if api_key is not None and isinstance(api_key, str):
        if api_key.startswith("sk-"):
            # hash the api_key
            api_key = hash_token(api_key)
        if (
            _is_master_key(api_key=api_key, _master_key=master_key)
            and general_settings.get("disable_adding_master_key_hash_to_db") is True
        ):
            api_key = "litellm_proxy_master_key"  # use a known alias, if the user disabled storing master key in db

    if (
        standard_logging_payload is not None
    ):  # [TODO] migrate completely to sl payload. currently missing pass-through endpoint data
        api_key = (
            api_key
            or standard_logging_payload["metadata"].get("user_api_key_hash")
            or ""
        )
        end_user_id = end_user_id or standard_logging_payload["metadata"].get(
            "user_api_key_end_user_id"
        )
    else:
        api_key = ""
    request_tags = (
        json.dumps(metadata.get("tags", []))
        if isinstance(metadata.get("tags", []), list)
        else "[]"
    )
    # re-check: the fallback above may have filled api_key with the master key hash
    if (
        _is_master_key(api_key=api_key, _master_key=master_key)
        and general_settings.get("disable_adding_master_key_hash_to_db") is True
    ):
        api_key = "litellm_proxy_master_key"  # use a known alias, if the user disabled storing master key in db

    _model_id = metadata.get("model_info", {}).get("id", "")
    _model_group = metadata.get("model_group", "")

    # clean up litellm metadata
    clean_metadata = _get_spend_logs_metadata(
        metadata,
        applied_guardrails=(
            standard_logging_payload["metadata"].get("applied_guardrails", None)
            if standard_logging_payload is not None
            else None
        ),
        batch_models=(
            standard_logging_payload.get("hidden_params", {}).get("batch_models", None)
            if standard_logging_payload is not None
            else None
        ),
        mcp_tool_call_metadata=(
            standard_logging_payload["metadata"].get("mcp_tool_call_metadata", None)
            if standard_logging_payload is not None
            else None
        ),
    )

    special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
    additional_usage_values = {}
    for k, v in usage.items():
        if k not in special_usage_fields:
            if isinstance(v, BaseModel):
                v = v.model_dump()
            additional_usage_values.update({k: v})
    clean_metadata["additional_usage_values"] = additional_usage_values

    if litellm.cache is not None:
        cache_key = litellm.cache.get_cache_key(**kwargs)
    else:
        cache_key = "Cache OFF"
    if cache_hit is True:
        import time

        id = f"{id}_cache_hit{time.time()}"  # SpendLogs does not allow duplicate request_id
    try:
        payload: SpendLogsPayload = SpendLogsPayload(
            request_id=str(id),
            call_type=call_type or "",
            api_key=str(api_key),
            cache_hit=str(cache_hit),
            startTime=_ensure_datetime_utc(start_time),
            endTime=_ensure_datetime_utc(end_time),
            completionStartTime=_ensure_datetime_utc(completion_start_time),
            model=kwargs.get("model", "") or "",
            user=metadata.get("user_api_key_user_id", "") or "",
            team_id=metadata.get("user_api_key_team_id", "") or "",
            metadata=json.dumps(clean_metadata),
            cache_key=cache_key,
            spend=kwargs.get("response_cost", 0),
            total_tokens=usage.get("total_tokens", 0),
            prompt_tokens=usage.get("prompt_tokens", 0),
            completion_tokens=usage.get("completion_tokens", 0),
            request_tags=request_tags,
            end_user=end_user_id or "",
            api_base=litellm_params.get("api_base", ""),
            model_group=_model_group,
            model_id=_model_id,
            requester_ip_address=clean_metadata.get("requester_ip_address", None),
            custom_llm_provider=kwargs.get("custom_llm_provider", ""),
            messages=_get_messages_for_spend_logs_payload(
                standard_logging_payload=standard_logging_payload, metadata=metadata
            ),
            response=_get_response_for_spend_logs_payload(standard_logging_payload),
        )

        verbose_proxy_logger.debug(
            "SpendTable: created payload - payload: %s\n\n",
            json.dumps(payload, indent=4, default=str),
        )

        return payload
    except Exception as e:
        verbose_proxy_logger.exception(
            "Error creating spendlogs object - {}".format(str(e))
        )
        raise e

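# Illustrative usage sketch (not executed; values hypothetical; assumes the
# proxy server module is initialized so `general_settings` and `master_key`
# resolve at import time):
#
#   payload = get_logging_payload(
#       kwargs={
#           "model": "gpt-4o",
#           "call_type": "acompletion",
#           "litellm_params": {"metadata": {"user_api_key": "hashed-key"}},
#           "response_cost": 0.00042,
#       },
#       response_obj={
#           "id": "chatcmpl-123",
#           "usage": {"prompt_tokens": 7, "completion_tokens": 3, "total_tokens": 10},
#       },
#       start_time=datetime.now(timezone.utc),
#       end_time=datetime.now(timezone.utc),
#   )
#   # payload["request_id"] == "chatcmpl-123"; token counts mirror `usage`.

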
def _ensure_datetime_utc(timestamp: datetime) -> datetime:
    """Helper to ensure a datetime is in UTC"""
    timestamp = timestamp.astimezone(timezone.utc)
    return timestamp

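# Illustrative: an aware datetime is converted to UTC (e.g. 05:00 at UTC+5
# becomes 00:00 UTC). Note that astimezone() interprets a *naive* input as
# local time before converting.

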
async def get_spend_by_team_and_customer(
    start_date: dt,
    end_date: dt,
    team_id: str,
    customer_id: str,
    prisma_client: PrismaClient,
):
    sql_query = """
    WITH SpendByModelApiKey AS (
        SELECT
            date_trunc('day', sl."startTime") AS group_by_day,
            COALESCE(tt.team_alias, 'Unassigned Team') AS team_name,
            sl.end_user AS customer,
            sl.model,
            sl.api_key,
            SUM(sl.spend) AS model_api_spend,
            SUM(sl.total_tokens) AS model_api_tokens
        FROM
            "LiteLLM_SpendLogs" sl
        LEFT JOIN
            "LiteLLM_TeamTable" tt
        ON
            sl.team_id = tt.team_id
        WHERE
            sl."startTime" BETWEEN $1::date AND $2::date
            AND sl.team_id = $3
            AND sl.end_user = $4
        GROUP BY
            date_trunc('day', sl."startTime"),
            tt.team_alias,
            sl.end_user,
            sl.model,
            sl.api_key
    )
    SELECT
        group_by_day,
        jsonb_agg(jsonb_build_object(
            'team_name', team_name,
            'customer', customer,
            'total_spend', total_spend,
            'metadata', metadata
        )) AS teams_customers
    FROM (
        SELECT
            group_by_day,
            team_name,
            customer,
            SUM(model_api_spend) AS total_spend,
            jsonb_agg(jsonb_build_object(
                'model', model,
                'api_key', api_key,
                'spend', model_api_spend,
                'total_tokens', model_api_tokens
            )) AS metadata
        FROM
            SpendByModelApiKey
        GROUP BY
            group_by_day,
            team_name,
            customer
    ) AS aggregated
    GROUP BY
        group_by_day
    ORDER BY
        group_by_day;
    """

    db_response = await prisma_client.db.query_raw(
        sql_query, start_date, end_date, team_id, customer_id
    )
    if db_response is None:
        return []

    return db_response

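# Illustrative shape of one returned row (values hypothetical), following the
# jsonb_build_object keys in the query above:
#
#   {
#       "group_by_day": "2025-01-01T00:00:00",
#       "teams_customers": [
#           {
#               "team_name": "prod-team",
#               "customer": "customer-1",
#               "total_spend": 1.25,
#               "metadata": [
#                   {"model": "gpt-4o", "api_key": "hashed-key",
#                    "spend": 1.25, "total_tokens": 4200}
#               ],
#           }
#       ],
#   }

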
def _get_messages_for_spend_logs_payload(
    standard_logging_payload: Optional[StandardLoggingPayload],
    metadata: Optional[dict] = None,
) -> str:
    # Messages are intentionally not persisted to the spend logs table; this
    # avoids writing large request payloads (e.g. inline media files) to the DB.
    return "{}"

def _sanitize_request_body_for_spend_logs_payload(
    request_body: dict,
    visited: Optional[set] = None,
) -> dict:
    """
    Recursively sanitize a request body to prevent logging large base64 strings
    or other oversized values. Truncates strings longer than 1000 characters,
    recurses into nested dictionaries and lists, and uses the `visited` set to
    guard against circular references.
    """
    MAX_STRING_LENGTH = 1000

    if visited is None:
        visited = set()

    # Track objects by memory address; if the same dict is seen again, we are
    # in a cycle and return an empty dict instead of recursing forever.
    obj_id = id(request_body)
    if obj_id in visited:
        return {}
    visited.add(obj_id)

    def _sanitize_value(value: Any) -> Any:
        if isinstance(value, dict):
            return _sanitize_request_body_for_spend_logs_payload(value, visited)
        elif isinstance(value, list):
            return [_sanitize_value(item) for item in value]
        elif isinstance(value, str):
            if len(value) > MAX_STRING_LENGTH:
                return f"{value[:MAX_STRING_LENGTH]}... (truncated {len(value) - MAX_STRING_LENGTH} chars)"
            return value
        return value

    return {k: _sanitize_value(v) for k, v in request_body.items()}

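# Illustrative behavior (values hypothetical): long strings are truncated,
# nested structures are preserved:
#
#   >>> out = _sanitize_request_body_for_spend_logs_payload(
#   ...     {"file": "A" * 1500, "opts": {"model": "gpt-4o"}}
#   ... )
#   >>> out["opts"]["model"]
#   'gpt-4o'
#   >>> out["file"].endswith("... (truncated 500 chars)")
#   True

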
def _add_proxy_server_request_to_metadata(
    metadata: dict,
    litellm_params: dict,
) -> dict:
    """
    Attach the (sanitized) proxy server request body to metadata, but only if
    _should_store_prompts_and_responses_in_spend_logs() is True.
    """
    if _should_store_prompts_and_responses_in_spend_logs():
        _proxy_server_request = cast(
            Optional[dict], litellm_params.get("proxy_server_request", {})
        )
        if _proxy_server_request is not None:
            _request_body = _proxy_server_request.get("body", {}) or {}
            _request_body = _sanitize_request_body_for_spend_logs_payload(_request_body)
            _request_body_json_str = json.dumps(_request_body, default=str)
            metadata["proxy_server_request"] = _request_body_json_str
    return metadata

def _get_response_for_spend_logs_payload(
    payload: Optional[StandardLoggingPayload],
) -> str:
    if payload is None:
        return "{}"
    if _should_store_prompts_and_responses_in_spend_logs():
        return json.dumps(payload.get("response", {}))
    return "{}"

def _should_store_prompts_and_responses_in_spend_logs() -> bool:
    from litellm.proxy.proxy_server import general_settings

    return general_settings.get("store_prompts_in_spend_logs") is True
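
# Prompt/response storage is opt-in via the proxy config read above
# (illustrative YAML sketch):
#
#   general_settings:
#     store_prompts_in_spend_logs: true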