Merge pull request #2978 from BerriAI/litellm_org_spend_tracking

fix(proxy_server.py): support tracking org spend
This commit is contained in:
Krish Dholakia 2024-04-11 23:19:33 -07:00 committed by GitHub
commit a1cb9a51b9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 136 additions and 1 deletions

View file

@ -46,4 +46,5 @@ general_settings:
litellm_jwtauth:
admin_jwt_scope: "litellm_proxy_admin"
public_key_ttl: os.environ/LITELLM_PUBLIC_KEY_TTL
user_id_jwt_field: "sub"
user_id_jwt_field: "sub"
org_id_jwt_field: "azp"

View file

@ -140,6 +140,7 @@ class LiteLLM_JWTAuth(LiteLLMBase):
team_allowed_routes: List[
Literal["openai_routes", "info_routes", "management_routes"]
] = ["openai_routes", "info_routes"]
org_id_jwt_field: Optional[str] = None
user_id_jwt_field: Optional[str] = None
end_user_id_jwt_field: Optional[str] = None
public_key_ttl: float = 600
@ -514,6 +515,7 @@ class LiteLLM_BudgetTable(LiteLLMBase):
class NewOrganizationRequest(LiteLLM_BudgetTable):
organization_id: Optional[str] = None
organization_alias: str
models: List = []
budget_id: Optional[str] = None
@ -522,6 +524,7 @@ class NewOrganizationRequest(LiteLLM_BudgetTable):
class LiteLLM_OrganizationTable(LiteLLMBase):
"""Represents user-controllable params for a LiteLLM_OrganizationTable record"""
organization_id: Optional[str] = None
organization_alias: Optional[str] = None
budget_id: str
metadata: Optional[dict] = None
@ -706,6 +709,8 @@ class LiteLLM_VerificationToken(LiteLLMBase):
soft_budget_cooldown: bool = False
litellm_budget_table: Optional[dict] = None
org_id: Optional[str] = None # org id for a given key
# hidden params used for parallel request limiting, not required to create a token
user_id_rate_limits: Optional[dict] = None
team_id_rate_limits: Optional[dict] = None

View file

@ -14,6 +14,7 @@ from litellm.proxy._types import (
LiteLLM_JWTAuth,
LiteLLM_TeamTable,
LiteLLMRoutes,
LiteLLM_OrganizationTable,
)
from typing import Optional, Literal, Union
from litellm.proxy.utils import PrismaClient
@ -287,3 +288,41 @@ async def get_team_object(
raise Exception(
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
)
async def get_org_object(
    org_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
):
    """
    Resolve an organization by id, consulting the cache before the database.

    - Check if org id in proxy Org Table
    - if valid, return LiteLLM_OrganizationTable object
    - if not, then raise an error

    Args:
        org_id: Organization id to look up.
        prisma_client: Connected DB client; ``None`` means no DB is configured.
        user_api_key_cache: Cache queried under the key ``"org_id:<org_id>"``.

    Returns:
        The cached entry (a dict or ``LiteLLM_OrganizationTable``) when one is
        present, otherwise the row returned by the database lookup.

    Raises:
        Exception: if no DB is connected, if the lookup fails, or if no
            organization with this id exists.
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache
    # The cache getter is asynchronous and must be awaited. Without the await
    # the result is a coroutine object that never matches the type checks
    # below, so the cache would be bypassed on every call.
    cached_org_obj = await user_api_key_cache.async_get_cache(
        key="org_id:{}".format(org_id)
    )
    if cached_org_obj is not None:
        if isinstance(cached_org_obj, dict):
            return cached_org_obj
        elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
            return cached_org_obj

    # else, check db
    not_found_msg = f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."
    try:
        response = await prisma_client.db.litellm_organizationtable.find_unique(
            where={"organization_id": org_id}
        )
    except Exception as e:
        # Callers see the same message as for a missing row; the original
        # error is kept as the cause for debugging.
        raise Exception(not_found_msg) from e
    if response is None:
        raise Exception(not_found_msg)
    return response

View file

@ -84,6 +84,16 @@ class JWTHandler:
user_id = default_value
return user_id
def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
    """
    Extract the organization id from a decoded JWT.

    The claim name comes from ``self.litellm_jwtauth.org_id_jwt_field``.

    Args:
        token: Decoded JWT payload.
        default_value: Value returned when the configured claim is missing
            from ``token``.

    Returns:
        ``None`` when no org-id claim is configured, the claim's value when
        present, otherwise ``default_value``.
    """
    claim_name = self.litellm_jwtauth.org_id_jwt_field
    # No claim configured: org tracking is off, and the default is NOT used.
    if claim_name is None:
        return None
    try:
        return token[claim_name]
    except KeyError:
        return default_value
def get_scopes(self, token: dict) -> list:
try:
if isinstance(token["scope"], str):

View file

@ -116,6 +116,7 @@ from litellm.proxy.hooks.prompt_injection_detection import (
from litellm.proxy.auth.auth_checks import (
common_checks,
get_end_user_object,
get_org_object,
get_team_object,
get_user_object,
allowed_routes_check,
@ -422,6 +423,14 @@ async def user_api_key_auth(
user_api_key_cache=user_api_key_cache,
)
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
org_id = jwt_handler.get_org_id(token=valid_token, default_value=None)
if org_id is not None:
_ = await get_org_object(
org_id=org_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
user_object = None
user_id = jwt_handler.get_user_id(token=valid_token, default_value=None)
@ -515,6 +524,7 @@ async def user_api_key_auth(
team_models=team_object.models,
user_role="app_owner",
user_id=user_id,
org_id=org_id,
)
#### ELSE ####
if master_key is None:
@ -1233,6 +1243,7 @@ async def _PROXY_track_cost_callback(
end_user_id = proxy_server_request.get("body", {}).get("user", None)
user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
org_id = kwargs["litellm_params"]["metadata"].get("user_api_key_org_id", None)
if kwargs.get("response_cost", None) is not None:
response_cost = kwargs["response_cost"]
user_api_key = kwargs["litellm_params"]["metadata"].get(
@ -1260,6 +1271,7 @@ async def _PROXY_track_cost_callback(
completion_response=completion_response,
start_time=start_time,
end_time=end_time,
org_id=org_id,
)
await update_cache(
@ -1321,6 +1333,7 @@ async def update_database(
completion_response=None,
start_time=None,
end_time=None,
org_id=None,
):
try:
global prisma_client
@ -1551,9 +1564,34 @@ async def update_database(
)
raise e
### UPDATE ORG SPEND ###
async def _update_org_db():
    """
    Add this request's cost to the pending spend total for its organization.

    Reads ``org_id``, ``response_cost`` and ``prisma_client`` from the
    enclosing scope. Nothing is written to the database here; the cost is only
    accumulated in ``prisma_client.org_list_transactons``.

    Raises:
        Exception: any error raised while accumulating is logged and re-raised.
    """
    try:
        debug_msg = "adding spend to org db. Response cost: {}. org_id: {}.".format(
            response_cost, org_id
        )
        verbose_proxy_logger.debug(debug_msg)

        # Requests without an org are not tracked.
        if org_id is None:
            verbose_proxy_logger.debug(
                "track_cost_callback: org_id is None. Not tracking spend for org"
            )
            return

        # Without a DB client there is nowhere to accumulate the spend.
        if prisma_client is None:
            return

        already_pending = prisma_client.org_list_transactons.get(org_id, 0)
        prisma_client.org_list_transactons[org_id] = response_cost + already_pending
    except Exception as e:
        verbose_proxy_logger.info(
            f"Update Org DB failed to execute - {str(e)}\n{traceback.format_exc()}"
        )
        raise e
asyncio.create_task(_update_user_db())
asyncio.create_task(_update_key_db())
asyncio.create_task(_update_team_db())
asyncio.create_task(_update_org_db())
# asyncio.create_task(_insert_spend_log_to_db())
if disable_spend_logs == False:
await _insert_spend_log_to_db()
@ -3432,6 +3470,7 @@ async def chat_completion(
user_api_key_dict, "key_alias", None
)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)

View file

@ -567,6 +567,7 @@ class PrismaClient:
end_user_list_transactons: dict = {}
key_list_transactons: dict = {}
team_list_transactons: dict = {}
org_list_transactons: dict = {}
spend_log_transactions: List = []
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
@ -2150,6 +2151,46 @@ async def update_spend(
)
raise e
### UPDATE ORG TABLE ###
# Flush the per-organization spend accumulated in
# `prisma_client.org_list_transactons` (org_id -> summed response cost).
# Skipped entirely when nothing is pending.
if len(prisma_client.org_list_transactons.keys()) > 0:
# Up to `n_retry_times + 1` attempts; only `httpx.ReadTimeout` is retried.
for i in range(n_retry_times + 1):
try:
# All increments are queued on one batcher inside a single transaction
# that is given a 60 second timeout.
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
) as transaction:
async with transaction.batch_() as batcher:
for (
org_id,
response_cost,
) in prisma_client.org_list_transactons.items():
batcher.litellm_organizationtable.update_many( # 'update_many' prevents error from being raised if no row exists
where={"organization_id": org_id},
data={"spend": {"increment": response_cost}},
)
# NOTE(review): the pending dict is replaced with a new empty dict here;
# whether this runs inside or after the transaction context cannot be
# determined from this view (indentation lost) - confirm.
prisma_client.org_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
break
except httpx.ReadTimeout:
if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
# Any other failure is not retried: it is logged, handed to the proxy's
# failure handler as a background task, and then re-raised to the caller.
except Exception as e:
import traceback
error_msg = (
f"LiteLLM Prisma Client Exception - update org spend: {str(e)}"
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
)
)
raise e
### UPDATE SPEND LOGS ###
verbose_proxy_logger.debug(
"Spend Logs transactions: {}".format(len(prisma_client.spend_log_transactions))