diff --git a/litellm/proxy/spend_tracking/spend_management_endpoints.py b/litellm/proxy/spend_tracking/spend_management_endpoints.py index f6d36daaf..e0fa1e092 100644 --- a/litellm/proxy/spend_tracking/spend_management_endpoints.py +++ b/litellm/proxy/spend_tracking/spend_management_endpoints.py @@ -9,6 +9,9 @@ import litellm from litellm._logging import verbose_proxy_logger from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.spend_tracking.spend_tracking_utils import ( + get_spend_by_team_and_customer, +) router = APIRouter() @@ -932,6 +935,14 @@ async def get_global_spend_report( default=None, description="View spend for a specific internal_user_id. Example internal_user_id='1234", ), + team_id: Optional[str] = fastapi.Query( + default=None, + description="View spend for a specific team_id. Example team_id='1234", + ), + customer_id: Optional[str] = fastapi.Query( + default=None, + description="View spend for a specific customer_id. Example customer_id='1234. Can be used in conjunction with team_id as well.", + ), ): """ Get Daily Spend per Team, based on specific startTime and endTime. Per team, view usage by each key, model @@ -1074,8 +1085,12 @@ async def get_global_spend_report( return [] return db_response - + elif team_id is not None and customer_id is not None: + return await get_spend_by_team_and_customer( + start_date_obj, end_date_obj, team_id, customer_id, prisma_client + ) if group_by == "team": + # first get data from spend logs -> SpendByModelApiKey # then read data from "SpendByModelApiKey" to format the response obj sql_query = """ @@ -1305,7 +1320,6 @@ async def global_get_all_tag_names(): "/global/spend/tags", tags=["Budget & Spend Tracking"], dependencies=[Depends(user_api_key_auth)], - include_in_schema=False, responses={ 200: {"model": List[LiteLLM_SpendLogs]}, }, diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py index 30e3ae5cd..48924d521 100644 --- a/litellm/proxy/spend_tracking/spend_tracking_utils.py +++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py @@ -1,7 +1,9 @@ +import datetime import json import os import secrets import traceback +from datetime import datetime as dt from typing import Optional from pydantic import BaseModel @@ -9,7 +11,7 @@ from pydantic import BaseModel import litellm from litellm._logging import verbose_proxy_logger from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload -from litellm.proxy.utils import hash_token +from litellm.proxy.utils import PrismaClient, hash_token def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool: @@ -163,3 +165,79 @@ def get_logging_payload( "Error creating spendlogs object - {}".format(str(e)) ) raise e + + +async def get_spend_by_team_and_customer( + start_date: dt, + end_date: dt, + team_id: str, + customer_id: str, + prisma_client: PrismaClient, +): + sql_query = """ + WITH SpendByModelApiKey AS ( + SELECT + date_trunc('day', sl."startTime") AS group_by_day, + COALESCE(tt.team_alias, 'Unassigned Team') AS team_name, + sl.end_user AS customer, + sl.model, + sl.api_key, + SUM(sl.spend) AS model_api_spend, + SUM(sl.total_tokens) AS model_api_tokens + FROM + "LiteLLM_SpendLogs" sl + LEFT JOIN + "LiteLLM_TeamTable" tt + ON + sl.team_id = tt.team_id + WHERE + sl."startTime" BETWEEN $1::date AND $2::date + AND sl.team_id = $3 + AND sl.end_user = $4 + GROUP BY + date_trunc('day', sl."startTime"), + tt.team_alias, + sl.end_user, + sl.model, + sl.api_key + ) + SELECT + group_by_day, + jsonb_agg(jsonb_build_object( + 'team_name', team_name, + 'customer', customer, + 'total_spend', total_spend, + 'metadata', metadata + )) AS teams_customers + FROM ( + SELECT + group_by_day, + team_name, + customer, + SUM(model_api_spend) AS total_spend, + jsonb_agg(jsonb_build_object( + 'model', model, + 'api_key', api_key, + 'spend', model_api_spend, + 'total_tokens', model_api_tokens + )) AS metadata + FROM + SpendByModelApiKey + GROUP BY + group_by_day, + team_name, + customer + ) AS aggregated + GROUP BY + group_by_day + ORDER BY + group_by_day; + """ + + db_response = await prisma_client.db.query_raw( + sql_query, start_date, end_date, team_id, customer_id + ) + if db_response is None: + return [] + + return db_response