Add aggregate spend by tag (#10071)

* feat: initial commit adding daily tag spend table to db

* feat(db_spend_update_writer.py): correctly log tag spend transactions

* build(schema.prisma): add new tag table to root

* build: add new migration file
This commit is contained in:
Krish Dholakia 2025-04-16 12:26:21 -07:00 committed by GitHub
parent 3a3cc97fc8
commit 3f5c8ae000
11 changed files with 565 additions and 373 deletions

View file

@ -0,0 +1,45 @@
-- AlterTable
ALTER TABLE "LiteLLM_DailyTeamSpend" ADD COLUMN     "cache_creation_input_tokens" INTEGER NOT NULL DEFAULT 0,
ADD COLUMN     "cache_read_input_tokens" INTEGER NOT NULL DEFAULT 0;

-- CreateTable
-- Daily aggregated spend per request tag. One row per
-- (tag, date, api_key, model, custom_llm_provider) combination; counters are
-- incremented in place by the batch upsert writer.
CREATE TABLE "LiteLLM_DailyTagSpend" (
    "id" TEXT NOT NULL,
    "tag" TEXT NOT NULL,
    "date" TEXT NOT NULL,
    "api_key" TEXT NOT NULL,
    "model" TEXT NOT NULL,
    "model_group" TEXT,
    "custom_llm_provider" TEXT,
    "prompt_tokens" INTEGER NOT NULL DEFAULT 0,
    "completion_tokens" INTEGER NOT NULL DEFAULT 0,
    "cache_read_input_tokens" INTEGER NOT NULL DEFAULT 0,
    "cache_creation_input_tokens" INTEGER NOT NULL DEFAULT 0,
    "spend" DOUBLE PRECISION NOT NULL DEFAULT 0.0,
    "api_requests" INTEGER NOT NULL DEFAULT 0,
    "successful_requests" INTEGER NOT NULL DEFAULT 0,
    "failed_requests" INTEGER NOT NULL DEFAULT 0,
    "created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updated_at" TIMESTAMP(3) NOT NULL,

    CONSTRAINT "LiteLLM_DailyTagSpend_pkey" PRIMARY KEY ("id")
);

-- NOTE: intentionally NO unique index on ("tag") alone — that would allow
-- only a single row per tag for all time, breaking daily/per-model
-- aggregation. Uniqueness is the composite index below; the plain
-- "tag" index covers per-tag lookups.

-- CreateIndex
CREATE INDEX "LiteLLM_DailyTagSpend_date_idx" ON "LiteLLM_DailyTagSpend"("date");

-- CreateIndex
CREATE INDEX "LiteLLM_DailyTagSpend_tag_idx" ON "LiteLLM_DailyTagSpend"("tag");

-- CreateIndex
CREATE INDEX "LiteLLM_DailyTagSpend_api_key_idx" ON "LiteLLM_DailyTagSpend"("api_key");

-- CreateIndex
CREATE INDEX "LiteLLM_DailyTagSpend_model_idx" ON "LiteLLM_DailyTagSpend"("model");

-- CreateIndex
CREATE UNIQUE INDEX "LiteLLM_DailyTagSpend_tag_date_api_key_model_custom_llm_pro_key" ON "LiteLLM_DailyTagSpend"("tag", "date", "api_key", "model", "custom_llm_provider");

View file

@ -342,6 +342,60 @@ model LiteLLM_DailyUserSpend {
@@index([model]) @@index([model])
} }
// Track daily team spend metrics per model and key.
// One row per (team_id, date, api_key, model, custom_llm_provider);
// token/request/spend counters are incremented in place by the batch writer.
model LiteLLM_DailyTeamSpend {
  id                          String   @id @default(uuid())
  team_id                     String
  date                        String   // stored as a string date bucket, not DateTime
  api_key                     String
  model                       String
  model_group                 String?
  custom_llm_provider         String?
  prompt_tokens               Int      @default(0)
  completion_tokens           Int      @default(0)
  cache_read_input_tokens     Int      @default(0)
  cache_creation_input_tokens Int      @default(0)
  spend                       Float    @default(0.0)
  api_requests                Int      @default(0)
  successful_requests         Int      @default(0)
  failed_requests             Int      @default(0)
  created_at                  DateTime @default(now())
  updated_at                  DateTime @updatedAt

  // Upsert target used by the daily spend writer
  @@unique([team_id, date, api_key, model, custom_llm_provider])
  @@index([date])
  @@index([team_id])
  @@index([api_key])
  @@index([model])
}
// Track daily tag spend metrics per model and key.
// One row per (tag, date, api_key, model, custom_llm_provider);
// token/request/spend counters are incremented in place by the batch writer.
model LiteLLM_DailyTagSpend {
  id                          String   @id @default(uuid())
  // No `@unique` here: a unique on `tag` alone would allow only one row per
  // tag ever, conflicting with the composite @@unique below that this table's
  // daily aggregation depends on.
  tag                         String
  date                        String   // stored as a string date bucket, not DateTime
  api_key                     String
  model                       String
  model_group                 String?
  custom_llm_provider         String?
  prompt_tokens               Int      @default(0)
  completion_tokens           Int      @default(0)
  cache_read_input_tokens     Int      @default(0)
  cache_creation_input_tokens Int      @default(0)
  spend                       Float    @default(0.0)
  api_requests                Int      @default(0)
  successful_requests         Int      @default(0)
  failed_requests             Int      @default(0)
  created_at                  DateTime @default(now())
  updated_at                  DateTime @updatedAt

  // Upsert target used by the daily spend writer
  @@unique([tag, date, api_key, model, custom_llm_provider])
  @@index([date])
  @@index([tag])
  @@index([api_key])
  @@index([model])
}
// Track the status of cron jobs running. Only allow one pod to run the job at a time // Track the status of cron jobs running. Only allow one pod to run the job at a time
model LiteLLM_CronJob { model LiteLLM_CronJob {

View file

@ -28,6 +28,7 @@ _DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client fo
REDIS_UPDATE_BUFFER_KEY = "litellm_spend_update_buffer" REDIS_UPDATE_BUFFER_KEY = "litellm_spend_update_buffer"
REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_spend_update_buffer" REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_spend_update_buffer"
REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_team_spend_update_buffer" REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_team_spend_update_buffer"
REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_tag_spend_update_buffer"
MAX_REDIS_BUFFER_DEQUEUE_COUNT = 100 MAX_REDIS_BUFFER_DEQUEUE_COUNT = 100
MAX_SIZE_IN_MEMORY_QUEUE = 10000 MAX_SIZE_IN_MEMORY_QUEUE = 10000
MAX_IN_MEMORY_QUEUE_FLUSH_COUNT = 1000 MAX_IN_MEMORY_QUEUE_FLUSH_COUNT = 1000

File diff suppressed because one or more lines are too long

View file

@ -650,9 +650,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase):
allowed_cache_controls: Optional[list] = [] allowed_cache_controls: Optional[list] = []
config: Optional[dict] = {} config: Optional[dict] = {}
permissions: Optional[dict] = {} permissions: Optional[dict] = {}
model_max_budget: Optional[dict] = ( model_max_budget: Optional[
{} dict
) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} ] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
model_rpm_limit: Optional[dict] = None model_rpm_limit: Optional[dict] = None
@ -908,12 +908,12 @@ class NewCustomerRequest(BudgetNewRequest):
alias: Optional[str] = None # human-friendly alias alias: Optional[str] = None # human-friendly alias
blocked: bool = False # allow/disallow requests for this end-user blocked: bool = False # allow/disallow requests for this end-user
budget_id: Optional[str] = None # give either a budget_id or max_budget budget_id: Optional[str] = None # give either a budget_id or max_budget
allowed_model_region: Optional[AllowedModelRegion] = ( allowed_model_region: Optional[
None # require all user requests to use models in this specific region AllowedModelRegion
) ] = None # require all user requests to use models in this specific region
default_model: Optional[str] = ( default_model: Optional[
None # if no equivalent model in allowed region - default all requests to this model str
) ] = None # if no equivalent model in allowed region - default all requests to this model
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@ -935,12 +935,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase):
blocked: bool = False # allow/disallow requests for this end-user blocked: bool = False # allow/disallow requests for this end-user
max_budget: Optional[float] = None max_budget: Optional[float] = None
budget_id: Optional[str] = None # give either a budget_id or max_budget budget_id: Optional[str] = None # give either a budget_id or max_budget
allowed_model_region: Optional[AllowedModelRegion] = ( allowed_model_region: Optional[
None # require all user requests to use models in this specific region AllowedModelRegion
) ] = None # require all user requests to use models in this specific region
default_model: Optional[str] = ( default_model: Optional[
None # if no equivalent model in allowed region - default all requests to this model str
) ] = None # if no equivalent model in allowed region - default all requests to this model
class DeleteCustomerRequest(LiteLLMPydanticObjectBase): class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
@ -1076,9 +1076,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase):
class AddTeamCallback(LiteLLMPydanticObjectBase): class AddTeamCallback(LiteLLMPydanticObjectBase):
callback_name: str callback_name: str
callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = ( callback_type: Optional[
"success_and_failure" Literal["success", "failure", "success_and_failure"]
) ] = "success_and_failure"
callback_vars: Dict[str, str] callback_vars: Dict[str, str]
@model_validator(mode="before") @model_validator(mode="before")
@ -1335,9 +1335,9 @@ class ConfigList(LiteLLMPydanticObjectBase):
stored_in_db: Optional[bool] stored_in_db: Optional[bool]
field_default_value: Any field_default_value: Any
premium_field: bool = False premium_field: bool = False
nested_fields: Optional[List[FieldDetail]] = ( nested_fields: Optional[
None # For nested dictionary or Pydantic fields List[FieldDetail]
) ] = None # For nested dictionary or Pydantic fields
class ConfigGeneralSettings(LiteLLMPydanticObjectBase): class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
@ -1604,9 +1604,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase):
budget_id: Optional[str] = None budget_id: Optional[str] = None
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
user: Optional[Any] = ( user: Optional[
None # You might want to replace 'Any' with a more specific type if available Any
) ] = None # You might want to replace 'Any' with a more specific type if available
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
@ -2354,9 +2354,9 @@ class TeamModelDeleteRequest(BaseModel):
# Organization Member Requests # Organization Member Requests
class OrganizationMemberAddRequest(OrgMemberAddRequest): class OrganizationMemberAddRequest(OrgMemberAddRequest):
organization_id: str organization_id: str
max_budget_in_organization: Optional[float] = ( max_budget_in_organization: Optional[
None # Users max budget within the organization float
) ] = None # Users max budget within the organization
class OrganizationMemberDeleteRequest(MemberDeleteRequest): class OrganizationMemberDeleteRequest(MemberDeleteRequest):
@ -2545,9 +2545,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase):
Maps provider names to their budget configs. Maps provider names to their budget configs.
""" """
providers: Dict[str, ProviderBudgetResponseObject] = ( providers: Dict[
{} str, ProviderBudgetResponseObject
) # Dictionary mapping provider names to their budget configurations ] = {} # Dictionary mapping provider names to their budget configurations
class ProxyStateVariables(TypedDict): class ProxyStateVariables(TypedDict):
@ -2675,9 +2675,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
enforce_rbac: bool = False enforce_rbac: bool = False
roles_jwt_field: Optional[str] = None # v2 on role mappings roles_jwt_field: Optional[str] = None # v2 on role mappings
role_mappings: Optional[List[RoleMapping]] = None role_mappings: Optional[List[RoleMapping]] = None
object_id_jwt_field: Optional[str] = ( object_id_jwt_field: Optional[
None # can be either user / team, inferred from the role mapping str
) ] = None # can be either user / team, inferred from the role mapping
scope_mappings: Optional[List[ScopeMapping]] = None scope_mappings: Optional[List[ScopeMapping]] = None
enforce_scope_based_access: bool = False enforce_scope_based_access: bool = False
enforce_team_based_model_access: bool = False enforce_team_based_model_access: bool = False
@ -2799,6 +2799,10 @@ class DailyUserSpendTransaction(BaseDailySpendTransaction):
user_id: str user_id: str
class DailyTagSpendTransaction(BaseDailySpendTransaction):
    # Request tag this daily spend aggregate belongs to; one transaction is
    # emitted per tag on a request (see add_spend_log_transaction_to_daily_tag_transaction).
    tag: str
class DBSpendUpdateTransactions(TypedDict): class DBSpendUpdateTransactions(TypedDict):
""" """
Internal Data Structure for buffering spend updates in Redis or in memory before committing them to the database Internal Data Structure for buffering spend updates in Redis or in memory before committing them to the database

View file

@ -11,7 +11,7 @@ import os
import time import time
import traceback import traceback
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast, overload
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
@ -20,6 +20,7 @@ from litellm.constants import DB_SPEND_UPDATE_JOB_NAME
from litellm.proxy._types import ( from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES, DB_CONNECTION_ERROR_TYPES,
BaseDailySpendTransaction, BaseDailySpendTransaction,
DailyTagSpendTransaction,
DailyTeamSpendTransaction, DailyTeamSpendTransaction,
DailyUserSpendTransaction, DailyUserSpendTransaction,
DBSpendUpdateTransactions, DBSpendUpdateTransactions,
@ -61,6 +62,7 @@ class DBSpendUpdateWriter:
self.spend_update_queue = SpendUpdateQueue() self.spend_update_queue = SpendUpdateQueue()
self.daily_spend_update_queue = DailySpendUpdateQueue() self.daily_spend_update_queue = DailySpendUpdateQueue()
self.daily_team_spend_update_queue = DailySpendUpdateQueue() self.daily_team_spend_update_queue = DailySpendUpdateQueue()
self.daily_tag_spend_update_queue = DailySpendUpdateQueue()
async def update_database( async def update_database(
# LiteLLM management object fields # LiteLLM management object fields
@ -170,6 +172,13 @@ class DBSpendUpdateWriter:
) )
) )
asyncio.create_task(
self.add_spend_log_transaction_to_daily_tag_transaction(
payload=payload,
prisma_client=prisma_client,
)
)
verbose_proxy_logger.debug("Runs spend update on all tables") verbose_proxy_logger.debug("Runs spend update on all tables")
except Exception: except Exception:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
@ -394,6 +403,7 @@ class DBSpendUpdateWriter:
spend_update_queue=self.spend_update_queue, spend_update_queue=self.spend_update_queue,
daily_spend_update_queue=self.daily_spend_update_queue, daily_spend_update_queue=self.daily_spend_update_queue,
daily_team_spend_update_queue=self.daily_team_spend_update_queue, daily_team_spend_update_queue=self.daily_team_spend_update_queue,
daily_tag_spend_update_queue=self.daily_tag_spend_update_queue,
) )
# Only commit from redis to db if this pod is the leader # Only commit from redis to db if this pod is the leader
@ -495,6 +505,20 @@ class DBSpendUpdateWriter:
daily_spend_transactions=daily_team_spend_update_transactions, daily_spend_transactions=daily_team_spend_update_transactions,
) )
################## Daily Tag Spend Update Transactions ##################
# Aggregate all in memory daily tag spend transactions and commit to db
daily_tag_spend_update_transactions = cast(
Dict[str, DailyTagSpendTransaction],
await self.daily_tag_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions(),
)
await DBSpendUpdateWriter.update_daily_tag_spend(
n_retry_times=n_retry_times,
prisma_client=prisma_client,
proxy_logging_obj=proxy_logging_obj,
daily_spend_transactions=daily_tag_spend_update_transactions,
)
async def _commit_spend_updates_to_db( # noqa: PLR0915 async def _commit_spend_updates_to_db( # noqa: PLR0915
self, self,
prisma_client: PrismaClient, prisma_client: PrismaClient,
@ -740,6 +764,208 @@ class DBSpendUpdateWriter:
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
) )
@overload
@staticmethod
async def _update_daily_spend(
    n_retry_times: int,
    prisma_client: PrismaClient,
    proxy_logging_obj: ProxyLogging,
    daily_spend_transactions: Dict[str, DailyUserSpendTransaction],
    entity_type: Literal["user"],
    entity_id_field: str,
    table_name: str,
    unique_constraint_name: str,
) -> None:
    ...

@overload
@staticmethod
async def _update_daily_spend(
    n_retry_times: int,
    prisma_client: PrismaClient,
    proxy_logging_obj: ProxyLogging,
    daily_spend_transactions: Dict[str, DailyTeamSpendTransaction],
    entity_type: Literal["team"],
    entity_id_field: str,
    table_name: str,
    unique_constraint_name: str,
) -> None:
    ...

@overload
@staticmethod
async def _update_daily_spend(
    n_retry_times: int,
    prisma_client: PrismaClient,
    proxy_logging_obj: ProxyLogging,
    daily_spend_transactions: Dict[str, DailyTagSpendTransaction],
    entity_type: Literal["tag"],
    entity_id_field: str,
    table_name: str,
    unique_constraint_name: str,
) -> None:
    ...

@staticmethod
async def _update_daily_spend(
    n_retry_times: int,
    prisma_client: PrismaClient,
    proxy_logging_obj: ProxyLogging,
    daily_spend_transactions: Union[
        Dict[str, DailyUserSpendTransaction],
        Dict[str, DailyTeamSpendTransaction],
        Dict[str, DailyTagSpendTransaction],
    ],
    entity_type: Literal["user", "team", "tag"],
    entity_id_field: str,
    table_name: str,
    unique_constraint_name: str,
) -> None:
    """
    Generic function to update daily spend for any entity type (user, team, tag).

    Takes a batch (up to BATCH_SIZE) of aggregated in-memory transactions and
    upserts them into the entity's daily-spend table inside a single Prisma
    batch. Retries up to `n_retry_times` on DB connection errors with
    exponential backoff. Processed transactions are removed from
    `daily_spend_transactions` in place — including on non-connection errors,
    so a poison transaction is not retried forever.

    Args:
        n_retry_times: number of retries on DB_CONNECTION_ERROR_TYPES.
        prisma_client: active Prisma client.
        proxy_logging_obj: used for failure alerting via
            _raise_failed_update_spend_exception.
        daily_spend_transactions: map of aggregation-key -> transaction dict;
            mutated in place as batches complete.
        entity_type: which daily table family this call targets.
        entity_id_field: transaction key holding the entity id
            (e.g. "user_id", "team_id", "tag").
        table_name: Prisma client attribute name (e.g. "litellm_dailytagspend").
        unique_constraint_name: name of the composite unique constraint used
            as the upsert `where` key.
    """
    from litellm.proxy.utils import _raise_failed_update_spend_exception

    verbose_proxy_logger.debug(
        f"Daily {entity_type.capitalize()} Spend transactions: {len(daily_spend_transactions)}"
    )
    BATCH_SIZE = 100  # aggregated records written per DB batch operation
    start_time = time.time()
    try:
        for i in range(n_retry_times + 1):
            try:
                # Take (up to) the first BATCH_SIZE pending transactions.
                transactions_to_process = dict(
                    list(daily_spend_transactions.items())[:BATCH_SIZE]
                )

                if len(transactions_to_process) == 0:
                    verbose_proxy_logger.debug(
                        f"No new transactions to process for daily {entity_type} spend update"
                    )
                    break

                async with prisma_client.db.batch_() as batcher:
                    for _, transaction in transactions_to_process.items():
                        entity_id = transaction.get(entity_id_field)
                        if not entity_id:
                            # Skip rows missing the entity id — they cannot
                            # satisfy the unique constraint.
                            continue

                        # Construct the where clause dynamically from the
                        # composite unique constraint name.
                        where_clause = {
                            unique_constraint_name: {
                                entity_id_field: entity_id,
                                "date": transaction["date"],
                                "api_key": transaction["api_key"],
                                "model": transaction["model"],
                                "custom_llm_provider": transaction.get(
                                    "custom_llm_provider"
                                ),
                            }
                        }

                        # Get the table dynamically (per-entity daily table).
                        table = getattr(batcher, table_name)

                        # Common data structure for both create and update.
                        common_data = {
                            entity_id_field: entity_id,
                            "date": transaction["date"],
                            "api_key": transaction["api_key"],
                            "model": transaction["model"],
                            "model_group": transaction.get("model_group"),
                            "custom_llm_provider": transaction.get(
                                "custom_llm_provider"
                            ),
                            "prompt_tokens": transaction["prompt_tokens"],
                            "completion_tokens": transaction["completion_tokens"],
                            "spend": transaction["spend"],
                            "api_requests": transaction["api_requests"],
                            "successful_requests": transaction[
                                "successful_requests"
                            ],
                            "failed_requests": transaction["failed_requests"],
                        }

                        # Add cache-related fields only if present — not every
                        # transaction type carries them.
                        if "cache_read_input_tokens" in transaction:
                            common_data[
                                "cache_read_input_tokens"
                            ] = transaction.get("cache_read_input_tokens", 0)
                        if "cache_creation_input_tokens" in transaction:
                            common_data[
                                "cache_creation_input_tokens"
                            ] = transaction.get("cache_creation_input_tokens", 0)

                        # Update path increments existing counters rather than
                        # overwriting them.
                        update_data = {
                            "prompt_tokens": {
                                "increment": transaction["prompt_tokens"]
                            },
                            "completion_tokens": {
                                "increment": transaction["completion_tokens"]
                            },
                            "spend": {"increment": transaction["spend"]},
                            "api_requests": {
                                "increment": transaction["api_requests"]
                            },
                            "successful_requests": {
                                "increment": transaction["successful_requests"]
                            },
                            "failed_requests": {
                                "increment": transaction["failed_requests"]
                            },
                        }

                        # Add cache-related fields to the update if they exist.
                        if "cache_read_input_tokens" in transaction:
                            update_data["cache_read_input_tokens"] = {
                                "increment": transaction.get(
                                    "cache_read_input_tokens", 0
                                )
                            }
                        if "cache_creation_input_tokens" in transaction:
                            update_data["cache_creation_input_tokens"] = {
                                "increment": transaction.get(
                                    "cache_creation_input_tokens", 0
                                )
                            }

                        table.upsert(
                            where=where_clause,
                            data={
                                "create": common_data,
                                "update": update_data,
                            },
                        )

                verbose_proxy_logger.info(
                    f"Processed {len(transactions_to_process)} daily {entity_type} transactions in {time.time() - start_time:.2f}s"
                )

                # Remove processed transactions from the in-memory map.
                for key in transactions_to_process.keys():
                    daily_spend_transactions.pop(key, None)

                break
            except DB_CONNECTION_ERROR_TYPES as e:
                if i >= n_retry_times:
                    _raise_failed_update_spend_exception(
                        e=e,
                        start_time=start_time,
                        proxy_logging_obj=proxy_logging_obj,
                    )
                await asyncio.sleep(2**i)  # exponential backoff between retries

    except Exception as e:
        # Drop the batch even on failure so a bad transaction cannot wedge
        # the queue; the exception is still surfaced/alerted below.
        if "transactions_to_process" in locals():
            for key in transactions_to_process.keys():  # type: ignore
                daily_spend_transactions.pop(key, None)
        _raise_failed_update_spend_exception(
            e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
        )
@staticmethod @staticmethod
async def update_daily_user_spend( async def update_daily_user_spend(
n_retry_times: int, n_retry_times: int,
@ -750,144 +976,16 @@ class DBSpendUpdateWriter:
""" """
Batch job to update LiteLLM_DailyUserSpend table using in-memory daily_spend_transactions Batch job to update LiteLLM_DailyUserSpend table using in-memory daily_spend_transactions
""" """
from litellm.proxy.utils import _raise_failed_update_spend_exception await DBSpendUpdateWriter._update_daily_spend(
n_retry_times=n_retry_times,
### UPDATE DAILY USER SPEND ### prisma_client=prisma_client,
verbose_proxy_logger.debug( proxy_logging_obj=proxy_logging_obj,
"Daily User Spend transactions: {}".format(len(daily_spend_transactions)) daily_spend_transactions=daily_spend_transactions,
entity_type="user",
entity_id_field="user_id",
table_name="litellm_dailyuserspend",
unique_constraint_name="user_id_date_api_key_model_custom_llm_provider",
) )
BATCH_SIZE = (
100 # Number of aggregated records to update in each database operation
)
start_time = time.time()
try:
for i in range(n_retry_times + 1):
try:
# Get transactions to process
transactions_to_process = dict(
list(daily_spend_transactions.items())[:BATCH_SIZE]
)
if len(transactions_to_process) == 0:
verbose_proxy_logger.debug(
"No new transactions to process for daily spend update"
)
break
# Update DailyUserSpend table in batches
async with prisma_client.db.batch_() as batcher:
for _, transaction in transactions_to_process.items():
user_id = transaction.get("user_id")
if not user_id: # Skip if no user_id
continue
batcher.litellm_dailyuserspend.upsert(
where={
"user_id_date_api_key_model_custom_llm_provider": {
"user_id": user_id,
"date": transaction["date"],
"api_key": transaction["api_key"],
"model": transaction["model"],
"custom_llm_provider": transaction.get(
"custom_llm_provider"
),
}
},
data={
"create": {
"user_id": user_id,
"date": transaction["date"],
"api_key": transaction["api_key"],
"model": transaction["model"],
"model_group": transaction.get("model_group"),
"custom_llm_provider": transaction.get(
"custom_llm_provider"
),
"prompt_tokens": transaction["prompt_tokens"],
"completion_tokens": transaction[
"completion_tokens"
],
"cache_read_input_tokens": transaction.get(
"cache_read_input_tokens", 0
),
"cache_creation_input_tokens": transaction.get(
"cache_creation_input_tokens", 0
),
"spend": transaction["spend"],
"api_requests": transaction["api_requests"],
"successful_requests": transaction[
"successful_requests"
],
"failed_requests": transaction[
"failed_requests"
],
},
"update": {
"prompt_tokens": {
"increment": transaction["prompt_tokens"]
},
"completion_tokens": {
"increment": transaction[
"completion_tokens"
]
},
"cache_read_input_tokens": {
"increment": transaction.get(
"cache_read_input_tokens", 0
)
},
"cache_creation_input_tokens": {
"increment": transaction.get(
"cache_creation_input_tokens", 0
)
},
"spend": {"increment": transaction["spend"]},
"api_requests": {
"increment": transaction["api_requests"]
},
"successful_requests": {
"increment": transaction[
"successful_requests"
]
},
"failed_requests": {
"increment": transaction["failed_requests"]
},
},
},
)
verbose_proxy_logger.info(
f"Processed {len(transactions_to_process)} daily spend transactions in {time.time() - start_time:.2f}s"
)
# Remove processed transactions
for key in transactions_to_process.keys():
daily_spend_transactions.pop(key, None)
verbose_proxy_logger.debug(
f"Processed {len(transactions_to_process)} daily spend transactions in {time.time() - start_time:.2f}s"
)
break
except DB_CONNECTION_ERROR_TYPES as e:
if i >= n_retry_times:
_raise_failed_update_spend_exception(
e=e,
start_time=start_time,
proxy_logging_obj=proxy_logging_obj,
)
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
# Remove processed transactions even if there was an error
if "transactions_to_process" in locals():
for key in transactions_to_process.keys(): # type: ignore
daily_spend_transactions.pop(key, None)
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
@staticmethod @staticmethod
async def update_daily_team_spend( async def update_daily_team_spend(
@ -899,140 +997,53 @@ class DBSpendUpdateWriter:
""" """
Batch job to update LiteLLM_DailyTeamSpend table using in-memory daily_spend_transactions Batch job to update LiteLLM_DailyTeamSpend table using in-memory daily_spend_transactions
""" """
from litellm.proxy.utils import _raise_failed_update_spend_exception await DBSpendUpdateWriter._update_daily_spend(
n_retry_times=n_retry_times,
### UPDATE DAILY USER SPEND ### prisma_client=prisma_client,
verbose_proxy_logger.debug( proxy_logging_obj=proxy_logging_obj,
"Daily Team Spend transactions: {}".format(len(daily_spend_transactions)) daily_spend_transactions=daily_spend_transactions,
entity_type="team",
entity_id_field="team_id",
table_name="litellm_dailyteamspend",
unique_constraint_name="team_id_date_api_key_model_custom_llm_provider",
) )
BATCH_SIZE = (
100 # Number of aggregated records to update in each database operation @staticmethod
async def update_daily_tag_spend(
n_retry_times: int,
prisma_client: PrismaClient,
proxy_logging_obj: ProxyLogging,
daily_spend_transactions: Dict[str, DailyTagSpendTransaction],
):
"""
Batch job to update LiteLLM_DailyTagSpend table using in-memory daily_spend_transactions
"""
await DBSpendUpdateWriter._update_daily_spend(
n_retry_times=n_retry_times,
prisma_client=prisma_client,
proxy_logging_obj=proxy_logging_obj,
daily_spend_transactions=daily_spend_transactions,
entity_type="tag",
entity_id_field="tag",
table_name="litellm_dailytagspend",
unique_constraint_name="tag_date_api_key_model_custom_llm_provider",
) )
start_time = time.time()
try:
for i in range(n_retry_times + 1):
try:
# Get transactions to process
transactions_to_process = dict(
list(daily_spend_transactions.items())[:BATCH_SIZE]
)
if len(transactions_to_process) == 0:
verbose_proxy_logger.debug(
"No new transactions to process for daily spend update"
)
break
# Update DailyUserSpend table in batches
async with prisma_client.db.batch_() as batcher:
for _, transaction in transactions_to_process.items():
team_id = transaction.get("team_id")
if not team_id: # Skip if no team_id
continue
batcher.litellm_dailyteamspend.upsert(
where={
"team_id_date_api_key_model_custom_llm_provider": {
"team_id": team_id,
"date": transaction["date"],
"api_key": transaction["api_key"],
"model": transaction["model"],
"custom_llm_provider": transaction.get(
"custom_llm_provider"
),
}
},
data={
"create": {
"team_id": team_id,
"date": transaction["date"],
"api_key": transaction["api_key"],
"model": transaction["model"],
"model_group": transaction.get("model_group"),
"custom_llm_provider": transaction.get(
"custom_llm_provider"
),
"prompt_tokens": transaction["prompt_tokens"],
"completion_tokens": transaction[
"completion_tokens"
],
"spend": transaction["spend"],
"api_requests": transaction["api_requests"],
"successful_requests": transaction[
"successful_requests"
],
"failed_requests": transaction[
"failed_requests"
],
},
"update": {
"prompt_tokens": {
"increment": transaction["prompt_tokens"]
},
"completion_tokens": {
"increment": transaction[
"completion_tokens"
]
},
"spend": {"increment": transaction["spend"]},
"api_requests": {
"increment": transaction["api_requests"]
},
"successful_requests": {
"increment": transaction[
"successful_requests"
]
},
"failed_requests": {
"increment": transaction["failed_requests"]
},
},
},
)
verbose_proxy_logger.info(
f"Processed {len(transactions_to_process)} daily team transactions in {time.time() - start_time:.2f}s"
)
# Remove processed transactions
for key in transactions_to_process.keys():
daily_spend_transactions.pop(key, None)
verbose_proxy_logger.debug(
f"Processed {len(transactions_to_process)} daily spend transactions in {time.time() - start_time:.2f}s"
)
break
except DB_CONNECTION_ERROR_TYPES as e:
if i >= n_retry_times:
_raise_failed_update_spend_exception(
e=e,
start_time=start_time,
proxy_logging_obj=proxy_logging_obj,
)
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
# Remove processed transactions even if there was an error
if "transactions_to_process" in locals():
for key in transactions_to_process.keys(): # type: ignore
daily_spend_transactions.pop(key, None)
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
async def _common_add_spend_log_transaction_to_daily_transaction( async def _common_add_spend_log_transaction_to_daily_transaction(
self, self,
payload: Union[dict, SpendLogsPayload], payload: Union[dict, SpendLogsPayload],
prisma_client: PrismaClient, prisma_client: PrismaClient,
type: Literal["user", "team"] = "user", type: Literal["user", "team", "request_tags"] = "user",
) -> Optional[BaseDailySpendTransaction]: ) -> Optional[BaseDailySpendTransaction]:
common_expected_keys = ["startTime", "api_key", "model", "custom_llm_provider"] common_expected_keys = ["startTime", "api_key", "model", "custom_llm_provider"]
if type == "user": if type == "user":
expected_keys = ["user", *common_expected_keys] expected_keys = ["user", *common_expected_keys]
else: elif type == "team":
expected_keys = ["team_id", *common_expected_keys] expected_keys = ["team_id", *common_expected_keys]
elif type == "request_tags":
expected_keys = ["request_tags", *common_expected_keys]
else:
raise ValueError(f"Invalid type: {type}")
if not all(key in payload for key in expected_keys): if not all(key in payload for key in expected_keys):
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
@ -1143,3 +1154,44 @@ class DBSpendUpdateWriter:
await self.daily_team_spend_update_queue.add_update( await self.daily_team_spend_update_queue.add_update(
update={daily_transaction_key: daily_transaction} update={daily_transaction_key: daily_transaction}
) )
async def add_spend_log_transaction_to_daily_tag_transaction(
    self,
    payload: SpendLogsPayload,
    prisma_client: Optional[PrismaClient] = None,
) -> None:
    """
    Queue one daily-tag-spend transaction per tag on this request.

    Builds the common daily transaction from `payload`, then fans it out to
    `self.daily_tag_spend_update_queue` once per tag in
    `payload["request_tags"]` (which may arrive as a JSON-encoded string or
    as a list). No-ops when there is no prisma client, the payload is missing
    required keys, or request_tags is None.

    Raises:
        ValueError: if request_tags is neither a str nor a list.
    """
    if prisma_client is None:
        verbose_proxy_logger.debug(
            "prisma_client is None. Skipping writing spend logs to db."
        )
        return

    # Validates required keys and builds the date/token/spend base fields.
    base_daily_transaction = (
        await self._common_add_spend_log_transaction_to_daily_transaction(
            payload, prisma_client, "request_tags"
        )
    )
    if base_daily_transaction is None:
        return
    if payload["request_tags"] is None:
        verbose_proxy_logger.debug(
            "request_tags is None for request. Skipping incrementing tag spend."
        )
        return

    # request_tags may be serialized JSON (str) or already a list.
    request_tags = []
    if isinstance(payload["request_tags"], str):
        request_tags = json.loads(payload["request_tags"])
    elif isinstance(payload["request_tags"], list):
        request_tags = payload["request_tags"]
    else:
        raise ValueError(f"Invalid request_tags: {payload['request_tags']}")

    # One queued transaction per tag; the key mirrors the DB's composite
    # unique constraint so in-memory aggregation matches the upsert target.
    for tag in request_tags:
        daily_transaction_key = f"{tag}_{base_daily_transaction['date']}_{payload['api_key']}_{payload['model']}_{payload['custom_llm_provider']}"
        daily_transaction = DailyTagSpendTransaction(
            tag=tag, **base_daily_transaction
        )
        await self.daily_tag_spend_update_queue.add_update(
            update={daily_transaction_key: daily_transaction}
        )

View file

@ -13,6 +13,7 @@ from litellm.caching import RedisCache
from litellm.constants import ( from litellm.constants import (
MAX_REDIS_BUFFER_DEQUEUE_COUNT, MAX_REDIS_BUFFER_DEQUEUE_COUNT,
REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY, REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY, REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
REDIS_UPDATE_BUFFER_KEY, REDIS_UPDATE_BUFFER_KEY,
) )
@ -68,11 +69,41 @@ class RedisUpdateBuffer:
return False return False
return _use_redis_transaction_buffer return _use_redis_transaction_buffer
async def _store_transactions_in_redis(
    self,
    transactions: Any,
    redis_key: str,
    service_type: ServiceTypes,
) -> None:
    """
    Push a serialized batch of transactions onto a Redis list and emit a
    buffer-size event for it.

    Args:
        transactions: Aggregated transactions to store; skipped when
            None or empty.
        redis_key: Redis list key the serialized batch is RPUSHed onto.
        service_type: Service type tagged on the emitted buffer event.
    """
    # Check the Redis connection first so we never pay for safe_dumps()
    # serialization on a batch we are about to drop (the original code
    # serialized before checking redis_cache).
    if self.redis_cache is None:
        return
    if transactions is None or len(transactions) == 0:
        return

    current_redis_buffer_size = await self.redis_cache.async_rpush(
        key=redis_key,
        values=[safe_dumps(transactions)],
    )
    await self._emit_new_item_added_to_redis_buffer_event(
        queue_size=current_redis_buffer_size,
        service=service_type,
    )
async def store_in_memory_spend_updates_in_redis( async def store_in_memory_spend_updates_in_redis(
self, self,
spend_update_queue: SpendUpdateQueue, spend_update_queue: SpendUpdateQueue,
daily_spend_update_queue: DailySpendUpdateQueue, daily_spend_update_queue: DailySpendUpdateQueue,
daily_team_spend_update_queue: DailySpendUpdateQueue, daily_team_spend_update_queue: DailySpendUpdateQueue,
daily_tag_spend_update_queue: DailySpendUpdateQueue,
): ):
""" """
Stores the in-memory spend updates to Redis Stores the in-memory spend updates to Redis
@ -124,18 +155,23 @@ class RedisUpdateBuffer:
) )
return return
# Get all transactions
db_spend_update_transactions = ( db_spend_update_transactions = (
await spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions() await spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions()
) )
verbose_proxy_logger.debug(
"ALL DB SPEND UPDATE TRANSACTIONS: %s", db_spend_update_transactions
)
daily_spend_update_transactions = ( daily_spend_update_transactions = (
await daily_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions() await daily_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
) )
daily_team_spend_update_transactions = ( daily_team_spend_update_transactions = (
await daily_team_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions() await daily_team_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
) )
daily_tag_spend_update_transactions = (
await daily_tag_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
)
verbose_proxy_logger.debug(
"ALL DB SPEND UPDATE TRANSACTIONS: %s", db_spend_update_transactions
)
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"ALL DAILY SPEND UPDATE TRANSACTIONS: %s", daily_spend_update_transactions "ALL DAILY SPEND UPDATE TRANSACTIONS: %s", daily_spend_update_transactions
) )
@ -147,40 +183,29 @@ class RedisUpdateBuffer:
): ):
return return
list_of_transactions = [safe_dumps(db_spend_update_transactions)] # Store all transaction types using the helper method
current_redis_buffer_size = await self.redis_cache.async_rpush( await self._store_transactions_in_redis(
key=REDIS_UPDATE_BUFFER_KEY, transactions=db_spend_update_transactions,
values=list_of_transactions, redis_key=REDIS_UPDATE_BUFFER_KEY,
) service_type=ServiceTypes.REDIS_SPEND_UPDATE_QUEUE,
await self._emit_new_item_added_to_redis_buffer_event(
queue_size=current_redis_buffer_size,
service=ServiceTypes.REDIS_SPEND_UPDATE_QUEUE,
) )
list_of_daily_spend_update_transactions = [ await self._store_transactions_in_redis(
safe_dumps(daily_spend_update_transactions) transactions=daily_spend_update_transactions,
] redis_key=REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
service_type=ServiceTypes.REDIS_DAILY_SPEND_UPDATE_QUEUE,
current_redis_buffer_size = await self.redis_cache.async_rpush(
key=REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
values=list_of_daily_spend_update_transactions,
)
await self._emit_new_item_added_to_redis_buffer_event(
queue_size=current_redis_buffer_size,
service=ServiceTypes.REDIS_DAILY_SPEND_UPDATE_QUEUE,
) )
list_of_daily_team_spend_update_transactions = [ await self._store_transactions_in_redis(
safe_dumps(daily_team_spend_update_transactions) transactions=daily_team_spend_update_transactions,
] redis_key=REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
service_type=ServiceTypes.REDIS_DAILY_TEAM_SPEND_UPDATE_QUEUE,
current_redis_buffer_size = await self.redis_cache.async_rpush(
key=REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
values=list_of_daily_team_spend_update_transactions,
) )
await self._emit_new_item_added_to_redis_buffer_event(
queue_size=current_redis_buffer_size, await self._store_transactions_in_redis(
service=ServiceTypes.REDIS_DAILY_TEAM_SPEND_UPDATE_QUEUE, transactions=daily_tag_spend_update_transactions,
redis_key=REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
service_type=ServiceTypes.REDIS_DAILY_TAG_SPEND_UPDATE_QUEUE,
) )
@staticmethod @staticmethod

View file

@ -353,6 +353,8 @@ model LiteLLM_DailyTeamSpend {
custom_llm_provider String? custom_llm_provider String?
prompt_tokens Int @default(0) prompt_tokens Int @default(0)
completion_tokens Int @default(0) completion_tokens Int @default(0)
cache_read_input_tokens Int @default(0)
cache_creation_input_tokens Int @default(0)
spend Float @default(0.0) spend Float @default(0.0)
api_requests Int @default(0) api_requests Int @default(0)
successful_requests Int @default(0) successful_requests Int @default(0)
@ -367,6 +369,33 @@ model LiteLLM_DailyTeamSpend {
@@index([model]) @@index([model])
} }
// Track daily tag spend metrics per model and key
model LiteLLM_DailyTagSpend {
  id                          String   @id @default(uuid())
  // NOTE(review): `tag` must NOT be @unique on its own — a single-column
  // unique index would allow only one row per tag EVER, breaking the
  // per-(tag, date, api_key, model, provider) daily aggregation that the
  // composite @@unique below (and the upsert key) depend on.
  tag                         String
  date                        String
  api_key                     String
  model                       String
  model_group                 String?
  custom_llm_provider         String?
  prompt_tokens               Int      @default(0)
  completion_tokens           Int      @default(0)
  cache_read_input_tokens     Int      @default(0)
  cache_creation_input_tokens Int      @default(0)
  spend                       Float    @default(0.0)
  api_requests                Int      @default(0)
  successful_requests         Int      @default(0)
  failed_requests             Int      @default(0)
  created_at                  DateTime @default(now())
  updated_at                  DateTime @updatedAt

  @@unique([tag, date, api_key, model, custom_llm_provider])
  @@index([date])
  @@index([tag])
  @@index([api_key])
  @@index([model])
}
// Track the status of cron jobs running. Only allow one pod to run the job at a time // Track the status of cron jobs running. Only allow one pod to run the job at a time
model LiteLLM_CronJob { model LiteLLM_CronJob {

View file

@ -2796,50 +2796,3 @@ def _premium_user_check():
"error": f"This feature is only available for LiteLLM Enterprise users. {CommonProxyErrors.not_premium_user.value}" "error": f"This feature is only available for LiteLLM Enterprise users. {CommonProxyErrors.not_premium_user.value}"
}, },
) )
async def _update_daily_spend_batch(prisma_client, spend_aggregates):
    """Helper function to update daily spend in batches"""
    async with prisma_client.db.batch_() as batcher:
        for aggregate_key, metrics in spend_aggregates.items():
            user_id, date, api_key, model, model_group, provider = aggregate_key
            # Rows without a user_id cannot be attributed; skip them.
            if not user_id:
                continue

            where_clause = {
                "user_id_date_api_key_model_custom_llm_provider": {
                    "user_id": user_id,
                    "date": date,
                    "api_key": api_key,
                    "model": model,
                    "custom_llm_provider": provider,
                }
            }
            # First sighting of this (user, day, key, model, provider): insert.
            create_data = {
                "user_id": user_id,
                "date": date,
                "api_key": api_key,
                "model": model,
                "model_group": model_group,
                "custom_llm_provider": provider,
                "prompt_tokens": metrics["prompt_tokens"],
                "completion_tokens": metrics["completion_tokens"],
                "spend": metrics["spend"],
            }
            # Existing row: atomically increment the running totals.
            update_data = {
                "prompt_tokens": {"increment": metrics["prompt_tokens"]},
                "completion_tokens": {"increment": metrics["completion_tokens"]},
                "spend": {"increment": metrics["spend"]},
            }
            batcher.litellm_dailyuserspend.upsert(
                where=where_clause,
                data={"create": create_data, "update": update_data},
            )

View file

@ -34,6 +34,7 @@ class ServiceTypes(str, enum.Enum):
IN_MEMORY_DAILY_SPEND_UPDATE_QUEUE = "in_memory_daily_spend_update_queue" IN_MEMORY_DAILY_SPEND_UPDATE_QUEUE = "in_memory_daily_spend_update_queue"
REDIS_DAILY_SPEND_UPDATE_QUEUE = "redis_daily_spend_update_queue" REDIS_DAILY_SPEND_UPDATE_QUEUE = "redis_daily_spend_update_queue"
REDIS_DAILY_TEAM_SPEND_UPDATE_QUEUE = "redis_daily_team_spend_update_queue" REDIS_DAILY_TEAM_SPEND_UPDATE_QUEUE = "redis_daily_team_spend_update_queue"
REDIS_DAILY_TAG_SPEND_UPDATE_QUEUE = "redis_daily_tag_spend_update_queue"
# spend update queue - current spend of key, user, team # spend update queue - current spend of key, user, team
IN_MEMORY_SPEND_UPDATE_QUEUE = "in_memory_spend_update_queue" IN_MEMORY_SPEND_UPDATE_QUEUE = "in_memory_spend_update_queue"
REDIS_SPEND_UPDATE_QUEUE = "redis_spend_update_queue" REDIS_SPEND_UPDATE_QUEUE = "redis_spend_update_queue"

View file

@ -353,6 +353,8 @@ model LiteLLM_DailyTeamSpend {
custom_llm_provider String? custom_llm_provider String?
prompt_tokens Int @default(0) prompt_tokens Int @default(0)
completion_tokens Int @default(0) completion_tokens Int @default(0)
cache_read_input_tokens Int @default(0)
cache_creation_input_tokens Int @default(0)
spend Float @default(0.0) spend Float @default(0.0)
api_requests Int @default(0) api_requests Int @default(0)
successful_requests Int @default(0) successful_requests Int @default(0)
@ -367,6 +369,33 @@ model LiteLLM_DailyTeamSpend {
@@index([model]) @@index([model])
} }
// Track daily tag spend metrics per model and key
model LiteLLM_DailyTagSpend {
  id                          String   @id @default(uuid())
  // NOTE(review): `tag` must NOT be @unique on its own — a single-column
  // unique index would allow only one row per tag EVER, breaking the
  // per-(tag, date, api_key, model, provider) daily aggregation that the
  // composite @@unique below (and the upsert key) depend on.
  tag                         String
  date                        String
  api_key                     String
  model                       String
  model_group                 String?
  custom_llm_provider         String?
  prompt_tokens               Int      @default(0)
  completion_tokens           Int      @default(0)
  cache_read_input_tokens     Int      @default(0)
  cache_creation_input_tokens Int      @default(0)
  spend                       Float    @default(0.0)
  api_requests                Int      @default(0)
  successful_requests         Int      @default(0)
  failed_requests             Int      @default(0)
  created_at                  DateTime @default(now())
  updated_at                  DateTime @updatedAt

  @@unique([tag, date, api_key, model, custom_llm_provider])
  @@index([date])
  @@index([tag])
  @@index([api_key])
  @@index([model])
}
// Track the status of cron jobs running. Only allow one pod to run the job at a time // Track the status of cron jobs running. Only allow one pod to run the job at a time
model LiteLLM_CronJob { model LiteLLM_CronJob {