litellm-mirror/litellm/proxy/spend_tracking/spend_tracking_utils.py
Krish Dholakia 1e011b66d3
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 13s
Ollama ssl verify = False + Spend Logs reliability fixes (#7931)
* fix(http_handler.py): support passing ssl verify dynamically and using the correct httpx client based on passed ssl verify param

Fixes https://github.com/BerriAI/litellm/issues/6499

* feat(llm_http_handler.py): support passing `ssl_verify=False` dynamically in call args

Closes https://github.com/BerriAI/litellm/issues/6499

* fix(proxy/utils.py): prevent bad logs from breaking all cost tracking + reset list regardless of success/failure

Prevents a malformed log from breaking all spend tracking: previously the bad log was retried indefinitely, blocking every later write.

* test(test_proxy_utils.py): add test to ensure bad log is dropped

* test(test_proxy_utils.py): ensure in-memory spend logs reset after bad log error

* test(test_user_api_key_auth.py): add unit test to ensure end user id as str works

* fix(auth_utils.py): ensure extracted end user id is always a str

prevents db cost tracking errors

* test(test_auth_utils.py): ensure get end user id from request body always returns a string

* test: update tests

* test: skip bedrock test - behaviour now supported

* test: fix testing

* refactor(spend_tracking_utils.py): reduce size of get_logging_payload

* test: fix test

* bump: version 1.59.4 → 1.59.5

* Revert "bump: version 1.59.4 → 1.59.5"

This reverts commit 1182b46b2e.

* fix(utils.py): fix spend logs retry logic

* fix(spend_tracking_utils.py): fix get tags

* fix(spend_tracking_utils.py): fix end user id spend tracking on pass-through endpoints
2025-01-23 23:05:41 -08:00

276 lines
9.1 KiB
Python

import json
import secrets
from datetime import datetime as dt
from typing import Optional, cast
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
from litellm.proxy.utils import PrismaClient, hash_token
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import get_end_user_id_for_cost_tracking
def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
if _master_key is None:
return False
## string comparison
is_master_key = secrets.compare_digest(api_key, _master_key)
if is_master_key:
return True
## hash comparison
is_master_key = secrets.compare_digest(api_key, hash_token(_master_key))
if is_master_key:
return True
return False
def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
    """Reduce request metadata to the fields declared on SpendLogsMetadata.

    Args:
        metadata: Raw metadata dict for the request, or None.

    Returns:
        A SpendLogsMetadata holding only the keys of ``metadata`` that are
        annotated on SpendLogsMetadata. When ``metadata`` is None, every
        field listed below is set to None.
    """
    if metadata is not None:
        verbose_proxy_logger.debug(
            "getting payload for SpendLogs, available keys in metadata: "
            + str(list(metadata.keys()))
        )
        # Keep only the keys SpendLogsMetadata declares; drop everything else.
        allowed_fields = {}
        for field_name in SpendLogsMetadata.__annotations__.keys():
            if field_name in metadata:
                allowed_fields[field_name] = metadata[field_name]
        return SpendLogsMetadata(**allowed_fields)  # type: ignore

    # No metadata at all: return the all-None placeholder.
    return SpendLogsMetadata(
        user_api_key=None,
        user_api_key_alias=None,
        user_api_key_team_id=None,
        user_api_key_user_id=None,
        user_api_key_team_alias=None,
        spend_logs_metadata=None,
        requester_ip_address=None,
        additional_usage_values=None,
    )
def get_logging_payload(kwargs, response_obj, start_time, end_time) -> SpendLogsPayload:
    """Build the SpendLogs payload for a single logged call.

    Args:
        kwargs: Logging kwargs for the call; None is treated as an empty
            dict. Keys read here: ``litellm_params``, ``call_type``,
            ``cache_hit``, ``completion_start_time``, ``litellm_call_id``,
            ``standard_logging_object``, ``model``, ``response_cost`` and
            ``custom_llm_provider``. The whole dict is also forwarded to
            ``litellm.cache.get_cache_key`` when a cache is configured.
        response_obj: Response for the call. Anything that is not a pydantic
            ``BaseModel`` or a ``dict`` is treated as an empty dict.
        start_time: Stored as ``startTime``.
        end_time: Stored as ``endTime``; also the fallback for
            ``completionStartTime``.

    Returns:
        The assembled SpendLogsPayload.

    Raises:
        Exception: Any error raised while assembling the payload is logged
            and re-raised unchanged.
    """
    import time

    from litellm.proxy.proxy_server import general_settings, master_key

    if kwargs is None:
        kwargs = {}
    if response_obj is None or not isinstance(response_obj, (BaseModel, dict)):
        response_obj = {}

    # standardize this function to be used across, s3, dynamoDB, langfuse logging
    # `or {}` also covers keys that are present but explicitly set to None.
    litellm_params = kwargs.get("litellm_params") or {}
    metadata = litellm_params.get("metadata") or {}
    completion_start_time = kwargs.get("completion_start_time", end_time)
    call_type = kwargs.get("call_type")
    cache_hit = kwargs.get("cache_hit", False)
    usage = cast(dict, response_obj).get("usage", None) or {}
    if isinstance(usage, litellm.Usage):
        usage = dict(usage)
    request_id = cast(dict, response_obj).get("id") or kwargs.get("litellm_call_id")
    standard_logging_payload = cast(
        Optional[StandardLoggingPayload], kwargs.get("standard_logging_object", None)
    )

    end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
    if standard_logging_payload is not None:
        api_key = standard_logging_payload["metadata"].get("user_api_key_hash") or ""
        end_user_id = end_user_id or standard_logging_payload["metadata"].get(
            "user_api_key_end_user_id"
        )
    else:
        api_key = ""

    # Only list-valued tags are stored; anything else becomes an empty list.
    tags = metadata.get("tags", [])
    request_tags = json.dumps(tags) if isinstance(tags, list) else "[]"

    if (
        _is_master_key(api_key=api_key, _master_key=master_key)
        and general_settings.get("disable_adding_master_key_hash_to_db") is True
    ):
        # use a known alias, if the user disabled storing master key in db
        api_key = "litellm_proxy_master_key"

    _model_id = (metadata.get("model_info") or {}).get("id", "")
    _model_group = metadata.get("model_group", "")

    # clean up litellm metadata
    clean_metadata = _get_spend_logs_metadata(metadata)

    # Usage entries beyond the three token counters are kept in the metadata
    # blob; pydantic values are dumped to plain dicts so they serialize.
    special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
    clean_metadata["additional_usage_values"] = {
        key: value.model_dump() if isinstance(value, BaseModel) else value
        for key, value in usage.items()
        if key not in special_usage_fields
    }

    if litellm.cache is not None:
        cache_key = litellm.cache.get_cache_key(**kwargs)
    else:
        cache_key = "Cache OFF"

    if cache_hit is True:
        # SpendLogs does not allow duplicate request_id
        request_id = f"{request_id}_cache_hit{time.time()}"

    try:
        payload: SpendLogsPayload = SpendLogsPayload(
            request_id=str(request_id),
            call_type=call_type or "",
            api_key=str(api_key),
            cache_hit=str(cache_hit),
            startTime=start_time,
            endTime=end_time,
            completionStartTime=completion_start_time,
            model=kwargs.get("model", "") or "",
            # Read from the normalised `metadata` so an explicit
            # metadata=None cannot raise AttributeError here.
            user=metadata.get("user_api_key_user_id", "") or "",
            team_id=metadata.get("user_api_key_team_id", "") or "",
            metadata=json.dumps(clean_metadata),
            cache_key=cache_key,
            spend=kwargs.get("response_cost", 0),
            total_tokens=usage.get("total_tokens", 0),
            prompt_tokens=usage.get("prompt_tokens", 0),
            completion_tokens=usage.get("completion_tokens", 0),
            request_tags=request_tags,
            end_user=end_user_id or "",
            api_base=litellm_params.get("api_base", ""),
            model_group=_model_group,
            model_id=_model_id,
            requester_ip_address=clean_metadata.get("requester_ip_address", None),
            custom_llm_provider=kwargs.get("custom_llm_provider", ""),
            messages=_get_messages_for_spend_logs_payload(standard_logging_payload),
            response=_get_response_for_spend_logs_payload(standard_logging_payload),
        )

        verbose_proxy_logger.debug(
            "SpendTable: created payload - payload: %s\n\n",
            json.dumps(payload, indent=4, default=str),
        )

        return payload
    except Exception as e:
        verbose_proxy_logger.exception(
            "Error creating spendlogs object - {}".format(str(e))
        )
        # Bare raise re-raises the active exception as-is.
        raise
async def get_spend_by_team_and_customer(
    start_date: dt,
    end_date: dt,
    team_id: str,
    customer_id: str,
    prisma_client: PrismaClient,
):
    """Query per-day spend for one team / customer pair.

    The query filters "LiteLLM_SpendLogs" by start time, ``team_id`` and
    ``end_user``, sums spend and total tokens per day, model and api key,
    then aggregates those rows into one JSON array per day.

    Args:
        start_date: Bound to ``$1`` (lower bound of the BETWEEN filter).
        end_date: Bound to ``$2`` (upper bound of the BETWEEN filter).
        team_id: Bound to ``$3``; matched against ``sl.team_id``.
        customer_id: Bound to ``$4``; matched against ``sl.end_user``.
        prisma_client: Client whose ``db.query_raw`` executes the query.

    Returns:
        Whatever ``query_raw`` returns, or an empty list when it returns None.
    """
    spend_query = """
WITH SpendByModelApiKey AS (
SELECT
date_trunc('day', sl."startTime") AS group_by_day,
COALESCE(tt.team_alias, 'Unassigned Team') AS team_name,
sl.end_user AS customer,
sl.model,
sl.api_key,
SUM(sl.spend) AS model_api_spend,
SUM(sl.total_tokens) AS model_api_tokens
FROM
"LiteLLM_SpendLogs" sl
LEFT JOIN
"LiteLLM_TeamTable" tt
ON
sl.team_id = tt.team_id
WHERE
sl."startTime" BETWEEN $1::date AND $2::date
AND sl.team_id = $3
AND sl.end_user = $4
GROUP BY
date_trunc('day', sl."startTime"),
tt.team_alias,
sl.end_user,
sl.model,
sl.api_key
)
SELECT
group_by_day,
jsonb_agg(jsonb_build_object(
'team_name', team_name,
'customer', customer,
'total_spend', total_spend,
'metadata', metadata
)) AS teams_customers
FROM (
SELECT
group_by_day,
team_name,
customer,
SUM(model_api_spend) AS total_spend,
jsonb_agg(jsonb_build_object(
'model', model,
'api_key', api_key,
'spend', model_api_spend,
'total_tokens', model_api_tokens
)) AS metadata
FROM
SpendByModelApiKey
GROUP BY
group_by_day,
team_name,
customer
) AS aggregated
GROUP BY
group_by_day
ORDER BY
group_by_day;
"""

    # Values are passed as bind parameters, never interpolated into the SQL.
    rows = await prisma_client.db.query_raw(
        spend_query, start_date, end_date, team_id, customer_id
    )
    return [] if rows is None else rows
def _get_messages_for_spend_logs_payload(
    payload: Optional[StandardLoggingPayload],
) -> str:
    """Return the JSON string stored in the spend log's ``messages`` field.

    Args:
        payload: Standard logging payload for the call, or None.

    Returns:
        ``json.dumps`` of ``payload["messages"]`` (``{}`` when the key is
        missing) if prompt storage is enabled; otherwise the literal "{}".
    """
    # The None check runs first, so the settings lookup is skipped for
    # calls that carry no payload.
    if payload is None or not _should_store_prompts_and_responses_in_spend_logs():
        return "{}"

    stored_messages = payload.get("messages", {})
    return json.dumps(stored_messages)
def _get_response_for_spend_logs_payload(
    payload: Optional[StandardLoggingPayload],
) -> str:
    """Return the JSON string stored in the spend log's ``response`` field.

    Args:
        payload: Standard logging payload for the call, or None.

    Returns:
        ``json.dumps`` of ``payload["response"]`` (``{}`` when the key is
        missing) if prompt storage is enabled; otherwise the literal "{}".
    """
    # The None check runs first, so the settings lookup is skipped for
    # calls that carry no payload.
    if payload is None or not _should_store_prompts_and_responses_in_spend_logs():
        return "{}"

    stored_response = payload.get("response", {})
    return json.dumps(stored_response)
def _should_store_prompts_and_responses_in_spend_logs() -> bool:
    """Report whether ``store_prompts_in_spend_logs`` is enabled.

    Returns:
        True only when ``general_settings["store_prompts_in_spend_logs"]``
        is exactly ``True``; any other value, or a missing key, gives False.
    """
    # Imported inside the function, as the original does.
    from litellm.proxy.proxy_server import general_settings

    store_flag = general_settings.get("store_prompts_in_spend_logs")
    return store_flag is True