Add aggregate spend by tag (#10071)

* feat: initial commit adding daily tag spend table to db

* feat(db_spend_update_writer.py): correctly log tag spend transactions

* build(schema.prisma): add new tag table to root

* build: add new migration file
Krish Dholakia 2025-04-16 12:26:21 -07:00 committed by GitHub
parent 47e811d6ce
commit d8a1071bc4
11 changed files with 565 additions and 373 deletions
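For context on how rows end up in the new LiteLLM_DailyTagSpend table: tags arrive on a proxy request via request metadata and are serialized into the spend log's request_tags field, which the writer below fans out per tag. A minimal client-side sketch — the endpoint, key, and the metadata "tags" convention here are illustrative assumptions, not part of this diff:

import openai

# Hypothetical proxy endpoint and virtual key; "metadata.tags" follows
# LiteLLM's request-tagging convention and lands in SpendLogsPayload["request_tags"].
client = openai.OpenAI(base_url="http://localhost:4000", api_key="sk-1234")
client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "hello"}],
    extra_body={"metadata": {"tags": ["prod", "batch-jobs"]}},
)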

View file

@@ -0,0 +1,45 @@
-- AlterTable
ALTER TABLE "LiteLLM_DailyTeamSpend" ADD COLUMN "cache_creation_input_tokens" INTEGER NOT NULL DEFAULT 0,
ADD COLUMN "cache_read_input_tokens" INTEGER NOT NULL DEFAULT 0;
-- CreateTable
CREATE TABLE "LiteLLM_DailyTagSpend" (
"id" TEXT NOT NULL,
"tag" TEXT NOT NULL,
"date" TEXT NOT NULL,
"api_key" TEXT NOT NULL,
"model" TEXT NOT NULL,
"model_group" TEXT,
"custom_llm_provider" TEXT,
"prompt_tokens" INTEGER NOT NULL DEFAULT 0,
"completion_tokens" INTEGER NOT NULL DEFAULT 0,
"cache_read_input_tokens" INTEGER NOT NULL DEFAULT 0,
"cache_creation_input_tokens" INTEGER NOT NULL DEFAULT 0,
"spend" DOUBLE PRECISION NOT NULL DEFAULT 0.0,
"api_requests" INTEGER NOT NULL DEFAULT 0,
"successful_requests" INTEGER NOT NULL DEFAULT 0,
"failed_requests" INTEGER NOT NULL DEFAULT 0,
"created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMP(3) NOT NULL,
CONSTRAINT "LiteLLM_DailyTagSpend_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "LiteLLM_DailyTagSpend_tag_key" ON "LiteLLM_DailyTagSpend"("tag");
-- CreateIndex
CREATE INDEX "LiteLLM_DailyTagSpend_date_idx" ON "LiteLLM_DailyTagSpend"("date");
-- CreateIndex
CREATE INDEX "LiteLLM_DailyTagSpend_tag_idx" ON "LiteLLM_DailyTagSpend"("tag");
-- CreateIndex
CREATE INDEX "LiteLLM_DailyTagSpend_api_key_idx" ON "LiteLLM_DailyTagSpend"("api_key");
-- CreateIndex
CREATE INDEX "LiteLLM_DailyTagSpend_model_idx" ON "LiteLLM_DailyTagSpend"("model");
-- CreateIndex
CREATE UNIQUE INDEX "LiteLLM_DailyTagSpend_tag_date_api_key_model_custom_llm_pro_key" ON "LiteLLM_DailyTagSpend"("tag", "date", "api_key", "model", "custom_llm_provider");

View file

@@ -342,6 +342,60 @@ model LiteLLM_DailyUserSpend {
@@index([model])
}
// Track daily team spend metrics per model and key
model LiteLLM_DailyTeamSpend {
id String @id @default(uuid())
team_id String
date String
api_key String
model String
model_group String?
custom_llm_provider String?
prompt_tokens Int @default(0)
completion_tokens Int @default(0)
cache_read_input_tokens Int @default(0)
cache_creation_input_tokens Int @default(0)
spend Float @default(0.0)
api_requests Int @default(0)
successful_requests Int @default(0)
failed_requests Int @default(0)
created_at DateTime @default(now())
updated_at DateTime @updatedAt
@@unique([team_id, date, api_key, model, custom_llm_provider])
@@index([date])
@@index([team_id])
@@index([api_key])
@@index([model])
}
// Track daily tag spend metrics per model and key
model LiteLLM_DailyTagSpend {
id String @id @default(uuid())
tag String @unique
date String
api_key String
model String
model_group String?
custom_llm_provider String?
prompt_tokens Int @default(0)
completion_tokens Int @default(0)
cache_read_input_tokens Int @default(0)
cache_creation_input_tokens Int @default(0)
spend Float @default(0.0)
api_requests Int @default(0)
successful_requests Int @default(0)
failed_requests Int @default(0)
created_at DateTime @default(now())
updated_at DateTime @updatedAt
@@unique([tag, date, api_key, model, custom_llm_provider])
@@index([date])
@@index([tag])
@@index([api_key])
@@index([model])
}
// Track the status of cron jobs running. Only allow one pod to run the job at a time
model LiteLLM_CronJob {

View file

@@ -28,6 +28,7 @@ _DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client fo
REDIS_UPDATE_BUFFER_KEY = "litellm_spend_update_buffer"
REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_spend_update_buffer"
REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_team_spend_update_buffer"
REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_tag_spend_update_buffer"
MAX_REDIS_BUFFER_DEQUEUE_COUNT = 100
MAX_SIZE_IN_MEMORY_QUEUE = 10000
MAX_IN_MEMORY_QUEUE_FLUSH_COUNT = 1000

File diff suppressed because one or more lines are too long

View file

@@ -650,9 +650,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase):
allowed_cache_controls: Optional[list] = []
config: Optional[dict] = {}
permissions: Optional[dict] = {}
model_max_budget: Optional[dict] = (
{}
) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
model_max_budget: Optional[
dict
] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
model_config = ConfigDict(protected_namespaces=())
model_rpm_limit: Optional[dict] = None
@@ -908,12 +908,12 @@ class NewCustomerRequest(BudgetNewRequest):
alias: Optional[str] = None # human-friendly alias
blocked: bool = False # allow/disallow requests for this end-user
budget_id: Optional[str] = None # give either a budget_id or max_budget
allowed_model_region: Optional[AllowedModelRegion] = (
None # require all user requests to use models in this specific region
)
default_model: Optional[str] = (
None # if no equivalent model in allowed region - default all requests to this model
)
allowed_model_region: Optional[
AllowedModelRegion
] = None # require all user requests to use models in this specific region
default_model: Optional[
str
] = None # if no equivalent model in allowed region - default all requests to this model
@model_validator(mode="before")
@classmethod
@@ -935,12 +935,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase):
blocked: bool = False # allow/disallow requests for this end-user
max_budget: Optional[float] = None
budget_id: Optional[str] = None # give either a budget_id or max_budget
allowed_model_region: Optional[AllowedModelRegion] = (
None # require all user requests to use models in this specific region
)
default_model: Optional[str] = (
None # if no equivalent model in allowed region - default all requests to this model
)
allowed_model_region: Optional[
AllowedModelRegion
] = None # require all user requests to use models in this specific region
default_model: Optional[
str
] = None # if no equivalent model in allowed region - default all requests to this model
class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
@@ -1076,9 +1076,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase):
class AddTeamCallback(LiteLLMPydanticObjectBase):
callback_name: str
callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = (
"success_and_failure"
)
callback_type: Optional[
Literal["success", "failure", "success_and_failure"]
] = "success_and_failure"
callback_vars: Dict[str, str]
@model_validator(mode="before")
@@ -1335,9 +1335,9 @@ class ConfigList(LiteLLMPydanticObjectBase):
stored_in_db: Optional[bool]
field_default_value: Any
premium_field: bool = False
nested_fields: Optional[List[FieldDetail]] = (
None # For nested dictionary or Pydantic fields
)
nested_fields: Optional[
List[FieldDetail]
] = None # For nested dictionary or Pydantic fields
class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
@@ -1604,9 +1604,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase):
budget_id: Optional[str] = None
created_at: datetime
updated_at: datetime
user: Optional[Any] = (
None # You might want to replace 'Any' with a more specific type if available
)
user: Optional[
Any
] = None # You might want to replace 'Any' with a more specific type if available
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
model_config = ConfigDict(protected_namespaces=())
@@ -2354,9 +2354,9 @@ class TeamModelDeleteRequest(BaseModel):
# Organization Member Requests
class OrganizationMemberAddRequest(OrgMemberAddRequest):
organization_id: str
max_budget_in_organization: Optional[float] = (
None # Users max budget within the organization
)
max_budget_in_organization: Optional[
float
] = None # Users max budget within the organization
class OrganizationMemberDeleteRequest(MemberDeleteRequest):
@@ -2545,9 +2545,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase):
Maps provider names to their budget configs.
"""
providers: Dict[str, ProviderBudgetResponseObject] = (
{}
) # Dictionary mapping provider names to their budget configurations
providers: Dict[
str, ProviderBudgetResponseObject
] = {} # Dictionary mapping provider names to their budget configurations
class ProxyStateVariables(TypedDict):
@@ -2675,9 +2675,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
enforce_rbac: bool = False
roles_jwt_field: Optional[str] = None # v2 on role mappings
role_mappings: Optional[List[RoleMapping]] = None
object_id_jwt_field: Optional[str] = (
None # can be either user / team, inferred from the role mapping
)
object_id_jwt_field: Optional[
str
] = None # can be either user / team, inferred from the role mapping
scope_mappings: Optional[List[ScopeMapping]] = None
enforce_scope_based_access: bool = False
enforce_team_based_model_access: bool = False
@@ -2799,6 +2799,10 @@ class DailyUserSpendTransaction(BaseDailySpendTransaction):
user_id: str
class DailyTagSpendTransaction(BaseDailySpendTransaction):
tag: str
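A tag transaction is just the shared daily-spend fields plus the tag itself. For illustration, one buffered entry might look like the following sketch (values invented; the base fields are inferred from how _update_daily_spend reads them later in this diff):

example_transaction = DailyTagSpendTransaction(
    tag="prod",
    date="2025-04-16",
    api_key="hashed-token",
    model="gpt-4",
    model_group="gpt-4",
    custom_llm_provider="openai",
    prompt_tokens=120,
    completion_tokens=40,
    spend=0.0042,
    api_requests=1,
    successful_requests=1,
    failed_requests=0,
    # cache_read_input_tokens / cache_creation_input_tokens may also be present
)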
class DBSpendUpdateTransactions(TypedDict):
"""
Internal Data Structure for buffering spend updates in Redis or in memory before committing them to the database

View file

@@ -11,7 +11,7 @@ import os
import time
import traceback
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast, overload
import litellm
from litellm._logging import verbose_proxy_logger
@@ -20,6 +20,7 @@ from litellm.constants import DB_SPEND_UPDATE_JOB_NAME
from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES,
BaseDailySpendTransaction,
DailyTagSpendTransaction,
DailyTeamSpendTransaction,
DailyUserSpendTransaction,
DBSpendUpdateTransactions,
@@ -61,6 +62,7 @@ class DBSpendUpdateWriter:
self.spend_update_queue = SpendUpdateQueue()
self.daily_spend_update_queue = DailySpendUpdateQueue()
self.daily_team_spend_update_queue = DailySpendUpdateQueue()
self.daily_tag_spend_update_queue = DailySpendUpdateQueue()
async def update_database(
# LiteLLM management object fields
@@ -170,6 +172,13 @@
)
)
asyncio.create_task(
self.add_spend_log_transaction_to_daily_tag_transaction(
payload=payload,
prisma_client=prisma_client,
)
)
verbose_proxy_logger.debug("Runs spend update on all tables")
except Exception:
verbose_proxy_logger.debug(
@@ -394,6 +403,7 @@
spend_update_queue=self.spend_update_queue,
daily_spend_update_queue=self.daily_spend_update_queue,
daily_team_spend_update_queue=self.daily_team_spend_update_queue,
daily_tag_spend_update_queue=self.daily_tag_spend_update_queue,
)
# Only commit from redis to db if this pod is the leader
@@ -495,6 +505,20 @@
daily_spend_transactions=daily_team_spend_update_transactions,
)
################## Daily Tag Spend Update Transactions ##################
# Aggregate all in memory daily tag spend transactions and commit to db
daily_tag_spend_update_transactions = cast(
Dict[str, DailyTagSpendTransaction],
await self.daily_tag_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions(),
)
await DBSpendUpdateWriter.update_daily_tag_spend(
n_retry_times=n_retry_times,
prisma_client=prisma_client,
proxy_logging_obj=proxy_logging_obj,
daily_spend_transactions=daily_tag_spend_update_transactions,
)
async def _commit_spend_updates_to_db( # noqa: PLR0915
self,
prisma_client: PrismaClient,
@@ -740,6 +764,208 @@
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
@overload
@staticmethod
async def _update_daily_spend(
n_retry_times: int,
prisma_client: PrismaClient,
proxy_logging_obj: ProxyLogging,
daily_spend_transactions: Dict[str, DailyUserSpendTransaction],
entity_type: Literal["user"],
entity_id_field: str,
table_name: str,
unique_constraint_name: str,
) -> None:
...
@overload
@staticmethod
async def _update_daily_spend(
n_retry_times: int,
prisma_client: PrismaClient,
proxy_logging_obj: ProxyLogging,
daily_spend_transactions: Dict[str, DailyTeamSpendTransaction],
entity_type: Literal["team"],
entity_id_field: str,
table_name: str,
unique_constraint_name: str,
) -> None:
...
@overload
@staticmethod
async def _update_daily_spend(
n_retry_times: int,
prisma_client: PrismaClient,
proxy_logging_obj: ProxyLogging,
daily_spend_transactions: Dict[str, DailyTagSpendTransaction],
entity_type: Literal["tag"],
entity_id_field: str,
table_name: str,
unique_constraint_name: str,
) -> None:
...
@staticmethod
async def _update_daily_spend(
n_retry_times: int,
prisma_client: PrismaClient,
proxy_logging_obj: ProxyLogging,
daily_spend_transactions: Union[
Dict[str, DailyUserSpendTransaction],
Dict[str, DailyTeamSpendTransaction],
Dict[str, DailyTagSpendTransaction],
],
entity_type: Literal["user", "team", "tag"],
entity_id_field: str,
table_name: str,
unique_constraint_name: str,
) -> None:
"""
Generic function to update daily spend for any entity type (user, team, tag)
"""
from litellm.proxy.utils import _raise_failed_update_spend_exception
verbose_proxy_logger.debug(
f"Daily {entity_type.capitalize()} Spend transactions: {len(daily_spend_transactions)}"
)
BATCH_SIZE = 100
start_time = time.time()
try:
for i in range(n_retry_times + 1):
try:
transactions_to_process = dict(
list(daily_spend_transactions.items())[:BATCH_SIZE]
)
if len(transactions_to_process) == 0:
verbose_proxy_logger.debug(
f"No new transactions to process for daily {entity_type} spend update"
)
break
async with prisma_client.db.batch_() as batcher:
for _, transaction in transactions_to_process.items():
entity_id = transaction.get(entity_id_field)
if not entity_id:
continue
# Construct the where clause dynamically
where_clause = {
unique_constraint_name: {
entity_id_field: entity_id,
"date": transaction["date"],
"api_key": transaction["api_key"],
"model": transaction["model"],
"custom_llm_provider": transaction.get(
"custom_llm_provider"
),
}
}
# Get the table dynamically
table = getattr(batcher, table_name)
# Common data structure for both create and update
common_data = {
entity_id_field: entity_id,
"date": transaction["date"],
"api_key": transaction["api_key"],
"model": transaction["model"],
"model_group": transaction.get("model_group"),
"custom_llm_provider": transaction.get(
"custom_llm_provider"
),
"prompt_tokens": transaction["prompt_tokens"],
"completion_tokens": transaction["completion_tokens"],
"spend": transaction["spend"],
"api_requests": transaction["api_requests"],
"successful_requests": transaction[
"successful_requests"
],
"failed_requests": transaction["failed_requests"],
}
# Add cache-related fields if they exist
if "cache_read_input_tokens" in transaction:
common_data[
"cache_read_input_tokens"
] = transaction.get("cache_read_input_tokens", 0)
if "cache_creation_input_tokens" in transaction:
common_data[
"cache_creation_input_tokens"
] = transaction.get("cache_creation_input_tokens", 0)
# Create update data structure
update_data = {
"prompt_tokens": {
"increment": transaction["prompt_tokens"]
},
"completion_tokens": {
"increment": transaction["completion_tokens"]
},
"spend": {"increment": transaction["spend"]},
"api_requests": {
"increment": transaction["api_requests"]
},
"successful_requests": {
"increment": transaction["successful_requests"]
},
"failed_requests": {
"increment": transaction["failed_requests"]
},
}
# Add cache-related fields to update if they exist
if "cache_read_input_tokens" in transaction:
update_data["cache_read_input_tokens"] = {
"increment": transaction.get(
"cache_read_input_tokens", 0
)
}
if "cache_creation_input_tokens" in transaction:
update_data["cache_creation_input_tokens"] = {
"increment": transaction.get(
"cache_creation_input_tokens", 0
)
}
table.upsert(
where=where_clause,
data={
"create": common_data,
"update": update_data,
},
)
verbose_proxy_logger.info(
f"Processed {len(transactions_to_process)} daily {entity_type} transactions in {time.time() - start_time:.2f}s"
)
# Remove processed transactions
for key in transactions_to_process.keys():
daily_spend_transactions.pop(key, None)
break
except DB_CONNECTION_ERROR_TYPES as e:
if i >= n_retry_times:
_raise_failed_update_spend_exception(
e=e,
start_time=start_time,
proxy_logging_obj=proxy_logging_obj,
)
await asyncio.sleep(2**i)
except Exception as e:
if "transactions_to_process" in locals():
for key in transactions_to_process.keys(): # type: ignore
daily_spend_transactions.pop(key, None)
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
@staticmethod
async def update_daily_user_spend(
n_retry_times: int,
@@ -750,144 +976,16 @@
"""
Batch job to update LiteLLM_DailyUserSpend table using in-memory daily_spend_transactions
"""
from litellm.proxy.utils import _raise_failed_update_spend_exception
### UPDATE DAILY USER SPEND ###
verbose_proxy_logger.debug(
"Daily User Spend transactions: {}".format(len(daily_spend_transactions))
await DBSpendUpdateWriter._update_daily_spend(
n_retry_times=n_retry_times,
prisma_client=prisma_client,
proxy_logging_obj=proxy_logging_obj,
daily_spend_transactions=daily_spend_transactions,
entity_type="user",
entity_id_field="user_id",
table_name="litellm_dailyuserspend",
unique_constraint_name="user_id_date_api_key_model_custom_llm_provider",
)
BATCH_SIZE = (
100 # Number of aggregated records to update in each database operation
)
start_time = time.time()
try:
for i in range(n_retry_times + 1):
try:
# Get transactions to process
transactions_to_process = dict(
list(daily_spend_transactions.items())[:BATCH_SIZE]
)
if len(transactions_to_process) == 0:
verbose_proxy_logger.debug(
"No new transactions to process for daily spend update"
)
break
# Update DailyUserSpend table in batches
async with prisma_client.db.batch_() as batcher:
for _, transaction in transactions_to_process.items():
user_id = transaction.get("user_id")
if not user_id: # Skip if no user_id
continue
batcher.litellm_dailyuserspend.upsert(
where={
"user_id_date_api_key_model_custom_llm_provider": {
"user_id": user_id,
"date": transaction["date"],
"api_key": transaction["api_key"],
"model": transaction["model"],
"custom_llm_provider": transaction.get(
"custom_llm_provider"
),
}
},
data={
"create": {
"user_id": user_id,
"date": transaction["date"],
"api_key": transaction["api_key"],
"model": transaction["model"],
"model_group": transaction.get("model_group"),
"custom_llm_provider": transaction.get(
"custom_llm_provider"
),
"prompt_tokens": transaction["prompt_tokens"],
"completion_tokens": transaction[
"completion_tokens"
],
"cache_read_input_tokens": transaction.get(
"cache_read_input_tokens", 0
),
"cache_creation_input_tokens": transaction.get(
"cache_creation_input_tokens", 0
),
"spend": transaction["spend"],
"api_requests": transaction["api_requests"],
"successful_requests": transaction[
"successful_requests"
],
"failed_requests": transaction[
"failed_requests"
],
},
"update": {
"prompt_tokens": {
"increment": transaction["prompt_tokens"]
},
"completion_tokens": {
"increment": transaction[
"completion_tokens"
]
},
"cache_read_input_tokens": {
"increment": transaction.get(
"cache_read_input_tokens", 0
)
},
"cache_creation_input_tokens": {
"increment": transaction.get(
"cache_creation_input_tokens", 0
)
},
"spend": {"increment": transaction["spend"]},
"api_requests": {
"increment": transaction["api_requests"]
},
"successful_requests": {
"increment": transaction[
"successful_requests"
]
},
"failed_requests": {
"increment": transaction["failed_requests"]
},
},
},
)
verbose_proxy_logger.info(
f"Processed {len(transactions_to_process)} daily spend transactions in {time.time() - start_time:.2f}s"
)
# Remove processed transactions
for key in transactions_to_process.keys():
daily_spend_transactions.pop(key, None)
verbose_proxy_logger.debug(
f"Processed {len(transactions_to_process)} daily spend transactions in {time.time() - start_time:.2f}s"
)
break
except DB_CONNECTION_ERROR_TYPES as e:
if i >= n_retry_times:
_raise_failed_update_spend_exception(
e=e,
start_time=start_time,
proxy_logging_obj=proxy_logging_obj,
)
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
# Remove processed transactions even if there was an error
if "transactions_to_process" in locals():
for key in transactions_to_process.keys(): # type: ignore
daily_spend_transactions.pop(key, None)
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
@staticmethod
async def update_daily_team_spend(
@@ -899,140 +997,53 @@
"""
Batch job to update LiteLLM_DailyTeamSpend table using in-memory daily_spend_transactions
"""
from litellm.proxy.utils import _raise_failed_update_spend_exception
### UPDATE DAILY USER SPEND ###
verbose_proxy_logger.debug(
"Daily Team Spend transactions: {}".format(len(daily_spend_transactions))
await DBSpendUpdateWriter._update_daily_spend(
n_retry_times=n_retry_times,
prisma_client=prisma_client,
proxy_logging_obj=proxy_logging_obj,
daily_spend_transactions=daily_spend_transactions,
entity_type="team",
entity_id_field="team_id",
table_name="litellm_dailyteamspend",
unique_constraint_name="team_id_date_api_key_model_custom_llm_provider",
)
BATCH_SIZE = (
100 # Number of aggregated records to update in each database operation
@staticmethod
async def update_daily_tag_spend(
n_retry_times: int,
prisma_client: PrismaClient,
proxy_logging_obj: ProxyLogging,
daily_spend_transactions: Dict[str, DailyTagSpendTransaction],
):
"""
Batch job to update LiteLLM_DailyTagSpend table using in-memory daily_spend_transactions
"""
await DBSpendUpdateWriter._update_daily_spend(
n_retry_times=n_retry_times,
prisma_client=prisma_client,
proxy_logging_obj=proxy_logging_obj,
daily_spend_transactions=daily_spend_transactions,
entity_type="tag",
entity_id_field="tag",
table_name="litellm_dailytagspend",
unique_constraint_name="tag_date_api_key_model_custom_llm_provider",
)
start_time = time.time()
try:
for i in range(n_retry_times + 1):
try:
# Get transactions to process
transactions_to_process = dict(
list(daily_spend_transactions.items())[:BATCH_SIZE]
)
if len(transactions_to_process) == 0:
verbose_proxy_logger.debug(
"No new transactions to process for daily spend update"
)
break
# Update DailyUserSpend table in batches
async with prisma_client.db.batch_() as batcher:
for _, transaction in transactions_to_process.items():
team_id = transaction.get("team_id")
if not team_id: # Skip if no team_id
continue
batcher.litellm_dailyteamspend.upsert(
where={
"team_id_date_api_key_model_custom_llm_provider": {
"team_id": team_id,
"date": transaction["date"],
"api_key": transaction["api_key"],
"model": transaction["model"],
"custom_llm_provider": transaction.get(
"custom_llm_provider"
),
}
},
data={
"create": {
"team_id": team_id,
"date": transaction["date"],
"api_key": transaction["api_key"],
"model": transaction["model"],
"model_group": transaction.get("model_group"),
"custom_llm_provider": transaction.get(
"custom_llm_provider"
),
"prompt_tokens": transaction["prompt_tokens"],
"completion_tokens": transaction[
"completion_tokens"
],
"spend": transaction["spend"],
"api_requests": transaction["api_requests"],
"successful_requests": transaction[
"successful_requests"
],
"failed_requests": transaction[
"failed_requests"
],
},
"update": {
"prompt_tokens": {
"increment": transaction["prompt_tokens"]
},
"completion_tokens": {
"increment": transaction[
"completion_tokens"
]
},
"spend": {"increment": transaction["spend"]},
"api_requests": {
"increment": transaction["api_requests"]
},
"successful_requests": {
"increment": transaction[
"successful_requests"
]
},
"failed_requests": {
"increment": transaction["failed_requests"]
},
},
},
)
verbose_proxy_logger.info(
f"Processed {len(transactions_to_process)} daily team transactions in {time.time() - start_time:.2f}s"
)
# Remove processed transactions
for key in transactions_to_process.keys():
daily_spend_transactions.pop(key, None)
verbose_proxy_logger.debug(
f"Processed {len(transactions_to_process)} daily spend transactions in {time.time() - start_time:.2f}s"
)
break
except DB_CONNECTION_ERROR_TYPES as e:
if i >= n_retry_times:
_raise_failed_update_spend_exception(
e=e,
start_time=start_time,
proxy_logging_obj=proxy_logging_obj,
)
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
# Remove processed transactions even if there was an error
if "transactions_to_process" in locals():
for key in transactions_to_process.keys(): # type: ignore
daily_spend_transactions.pop(key, None)
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
async def _common_add_spend_log_transaction_to_daily_transaction(
self,
payload: Union[dict, SpendLogsPayload],
prisma_client: PrismaClient,
type: Literal["user", "team"] = "user",
type: Literal["user", "team", "request_tags"] = "user",
) -> Optional[BaseDailySpendTransaction]:
common_expected_keys = ["startTime", "api_key", "model", "custom_llm_provider"]
if type == "user":
expected_keys = ["user", *common_expected_keys]
else:
elif type == "team":
expected_keys = ["team_id", *common_expected_keys]
elif type == "request_tags":
expected_keys = ["request_tags", *common_expected_keys]
else:
raise ValueError(f"Invalid type: {type}")
if not all(key in payload for key in expected_keys):
verbose_proxy_logger.debug(
@@ -1143,3 +1154,44 @@ class DBSpendUpdateWriter:
await self.daily_team_spend_update_queue.add_update(
update={daily_transaction_key: daily_transaction}
)
async def add_spend_log_transaction_to_daily_tag_transaction(
self,
payload: SpendLogsPayload,
prisma_client: Optional[PrismaClient] = None,
) -> None:
if prisma_client is None:
verbose_proxy_logger.debug(
"prisma_client is None. Skipping writing spend logs to db."
)
return
base_daily_transaction = (
await self._common_add_spend_log_transaction_to_daily_transaction(
payload, prisma_client, "request_tags"
)
)
if base_daily_transaction is None:
return
if payload["request_tags"] is None:
verbose_proxy_logger.debug(
"request_tags is None for request. Skipping incrementing tag spend."
)
return
request_tags = []
if isinstance(payload["request_tags"], str):
request_tags = json.loads(payload["request_tags"])
elif isinstance(payload["request_tags"], list):
request_tags = payload["request_tags"]
else:
raise ValueError(f"Invalid request_tags: {payload['request_tags']}")
for tag in request_tags:
daily_transaction_key = f"{tag}_{base_daily_transaction['date']}_{payload['api_key']}_{payload['model']}_{payload['custom_llm_provider']}"
daily_transaction = DailyTagSpendTransaction(
tag=tag, **base_daily_transaction
)
await self.daily_tag_spend_update_queue.add_update(
update={daily_transaction_key: daily_transaction}
)

View file

@@ -13,6 +13,7 @@ from litellm.caching import RedisCache
from litellm.constants import (
MAX_REDIS_BUFFER_DEQUEUE_COUNT,
REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
REDIS_UPDATE_BUFFER_KEY,
)
@@ -68,11 +69,41 @@ class RedisUpdateBuffer:
return False
return _use_redis_transaction_buffer
async def _store_transactions_in_redis(
self,
transactions: Any,
redis_key: str,
service_type: ServiceTypes,
) -> None:
"""
Helper method to store transactions in Redis and emit an event
Args:
transactions: The transactions to store
redis_key: The Redis key to store under
service_type: The service type for event emission
"""
if transactions is None or len(transactions) == 0:
return
list_of_transactions = [safe_dumps(transactions)]
if self.redis_cache is None:
return
current_redis_buffer_size = await self.redis_cache.async_rpush(
key=redis_key,
values=list_of_transactions,
)
await self._emit_new_item_added_to_redis_buffer_event(
queue_size=current_redis_buffer_size,
service=service_type,
)
async def store_in_memory_spend_updates_in_redis(
self,
spend_update_queue: SpendUpdateQueue,
daily_spend_update_queue: DailySpendUpdateQueue,
daily_team_spend_update_queue: DailySpendUpdateQueue,
daily_tag_spend_update_queue: DailySpendUpdateQueue,
):
"""
Stores the in-memory spend updates to Redis
@@ -124,18 +155,23 @@ )
)
return
# Get all transactions
db_spend_update_transactions = (
await spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions()
)
verbose_proxy_logger.debug(
"ALL DB SPEND UPDATE TRANSACTIONS: %s", db_spend_update_transactions
)
daily_spend_update_transactions = (
await daily_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
)
daily_team_spend_update_transactions = (
await daily_team_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
)
daily_tag_spend_update_transactions = (
await daily_tag_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
)
verbose_proxy_logger.debug(
"ALL DB SPEND UPDATE TRANSACTIONS: %s", db_spend_update_transactions
)
verbose_proxy_logger.debug(
"ALL DAILY SPEND UPDATE TRANSACTIONS: %s", daily_spend_update_transactions
)
@@ -147,40 +183,29 @@ ):
):
return
list_of_transactions = [safe_dumps(db_spend_update_transactions)]
current_redis_buffer_size = await self.redis_cache.async_rpush(
key=REDIS_UPDATE_BUFFER_KEY,
values=list_of_transactions,
)
await self._emit_new_item_added_to_redis_buffer_event(
queue_size=current_redis_buffer_size,
service=ServiceTypes.REDIS_SPEND_UPDATE_QUEUE,
# Store all transaction types using the helper method
await self._store_transactions_in_redis(
transactions=db_spend_update_transactions,
redis_key=REDIS_UPDATE_BUFFER_KEY,
service_type=ServiceTypes.REDIS_SPEND_UPDATE_QUEUE,
)
list_of_daily_spend_update_transactions = [
safe_dumps(daily_spend_update_transactions)
]
current_redis_buffer_size = await self.redis_cache.async_rpush(
key=REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
values=list_of_daily_spend_update_transactions,
)
await self._emit_new_item_added_to_redis_buffer_event(
queue_size=current_redis_buffer_size,
service=ServiceTypes.REDIS_DAILY_SPEND_UPDATE_QUEUE,
await self._store_transactions_in_redis(
transactions=daily_spend_update_transactions,
redis_key=REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
service_type=ServiceTypes.REDIS_DAILY_SPEND_UPDATE_QUEUE,
)
list_of_daily_team_spend_update_transactions = [
safe_dumps(daily_team_spend_update_transactions)
]
current_redis_buffer_size = await self.redis_cache.async_rpush(
key=REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
values=list_of_daily_team_spend_update_transactions,
await self._store_transactions_in_redis(
transactions=daily_team_spend_update_transactions,
redis_key=REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
service_type=ServiceTypes.REDIS_DAILY_TEAM_SPEND_UPDATE_QUEUE,
)
await self._emit_new_item_added_to_redis_buffer_event(
queue_size=current_redis_buffer_size,
service=ServiceTypes.REDIS_DAILY_TEAM_SPEND_UPDATE_QUEUE,
await self._store_transactions_in_redis(
transactions=daily_tag_spend_update_transactions,
redis_key=REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
service_type=ServiceTypes.REDIS_DAILY_TAG_SPEND_UPDATE_QUEUE,
)
@staticmethod

View file

@@ -353,6 +353,8 @@ model LiteLLM_DailyTeamSpend {
custom_llm_provider String?
prompt_tokens Int @default(0)
completion_tokens Int @default(0)
cache_read_input_tokens Int @default(0)
cache_creation_input_tokens Int @default(0)
spend Float @default(0.0)
api_requests Int @default(0)
successful_requests Int @default(0)
@@ -367,6 +369,33 @@ model LiteLLM_DailyTeamSpend {
@@index([model])
}
// Track daily tag spend metrics per model and key
model LiteLLM_DailyTagSpend {
id String @id @default(uuid())
tag String @unique
date String
api_key String
model String
model_group String?
custom_llm_provider String?
prompt_tokens Int @default(0)
completion_tokens Int @default(0)
cache_read_input_tokens Int @default(0)
cache_creation_input_tokens Int @default(0)
spend Float @default(0.0)
api_requests Int @default(0)
successful_requests Int @default(0)
failed_requests Int @default(0)
created_at DateTime @default(now())
updated_at DateTime @updatedAt
@@unique([tag, date, api_key, model, custom_llm_provider])
@@index([date])
@@index([tag])
@@index([api_key])
@@index([model])
}
// Track the status of cron jobs running. Only allow one pod to run the job at a time
model LiteLLM_CronJob {

View file

@@ -2796,50 +2796,3 @@ def _premium_user_check():
"error": f"This feature is only available for LiteLLM Enterprise users. {CommonProxyErrors.not_premium_user.value}"
},
)
async def _update_daily_spend_batch(prisma_client, spend_aggregates):
"""Helper function to update daily spend in batches"""
async with prisma_client.db.batch_() as batcher:
for (
user_id,
date,
api_key,
model,
model_group,
provider,
), metrics in spend_aggregates.items():
if not user_id: # Skip if no user_id
continue
batcher.litellm_dailyuserspend.upsert(
where={
"user_id_date_api_key_model_custom_llm_provider": {
"user_id": user_id,
"date": date,
"api_key": api_key,
"model": model,
"custom_llm_provider": provider,
}
},
data={
"create": {
"user_id": user_id,
"date": date,
"api_key": api_key,
"model": model,
"model_group": model_group,
"custom_llm_provider": provider,
"prompt_tokens": metrics["prompt_tokens"],
"completion_tokens": metrics["completion_tokens"],
"spend": metrics["spend"],
},
"update": {
"prompt_tokens": {"increment": metrics["prompt_tokens"]},
"completion_tokens": {
"increment": metrics["completion_tokens"]
},
"spend": {"increment": metrics["spend"]},
},
},
)

View file

@@ -34,6 +34,7 @@ class ServiceTypes(str, enum.Enum):
IN_MEMORY_DAILY_SPEND_UPDATE_QUEUE = "in_memory_daily_spend_update_queue"
REDIS_DAILY_SPEND_UPDATE_QUEUE = "redis_daily_spend_update_queue"
REDIS_DAILY_TEAM_SPEND_UPDATE_QUEUE = "redis_daily_team_spend_update_queue"
REDIS_DAILY_TAG_SPEND_UPDATE_QUEUE = "redis_daily_tag_spend_update_queue"
# spend update queue - current spend of key, user, team
IN_MEMORY_SPEND_UPDATE_QUEUE = "in_memory_spend_update_queue"
REDIS_SPEND_UPDATE_QUEUE = "redis_spend_update_queue"

View file

@@ -353,6 +353,8 @@ model LiteLLM_DailyTeamSpend {
custom_llm_provider String?
prompt_tokens Int @default(0)
completion_tokens Int @default(0)
cache_read_input_tokens Int @default(0)
cache_creation_input_tokens Int @default(0)
spend Float @default(0.0)
api_requests Int @default(0)
successful_requests Int @default(0)
@@ -367,6 +369,33 @@ model LiteLLM_DailyTeamSpend {
@@index([model])
}
// Track daily tag spend metrics per model and key
model LiteLLM_DailyTagSpend {
id String @id @default(uuid())
tag String @unique
date String
api_key String
model String
model_group String?
custom_llm_provider String?
prompt_tokens Int @default(0)
completion_tokens Int @default(0)
cache_read_input_tokens Int @default(0)
cache_creation_input_tokens Int @default(0)
spend Float @default(0.0)
api_requests Int @default(0)
successful_requests Int @default(0)
failed_requests Int @default(0)
created_at DateTime @default(now())
updated_at DateTime @updatedAt
@@unique([tag, date, api_key, model, custom_llm_provider])
@@index([date])
@@index([tag])
@@index([api_key])
@@index([model])
}
// Track the status of cron jobs running. Only allow one pod to run the job at a time
model LiteLLM_CronJob {