fix(proxy_server.py): support tracking org spend

currently works when org set for jwt auth
This commit is contained in:
Krrish Dholakia 2024-04-11 23:01:21 -07:00
parent de9258e700
commit aa5da4346a
6 changed files with 136 additions and 1 deletions

View file

@ -116,6 +116,7 @@ from litellm.proxy.hooks.prompt_injection_detection import (
from litellm.proxy.auth.auth_checks import (
common_checks,
get_end_user_object,
get_org_object,
get_team_object,
get_user_object,
allowed_routes_check,
@ -422,6 +423,14 @@ async def user_api_key_auth(
user_api_key_cache=user_api_key_cache,
)
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
org_id = jwt_handler.get_org_id(token=valid_token, default_value=None)
if org_id is not None:
_ = await get_org_object(
org_id=org_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
user_object = None
user_id = jwt_handler.get_user_id(token=valid_token, default_value=None)
@ -515,6 +524,7 @@ async def user_api_key_auth(
team_models=team_object.models,
user_role="app_owner",
user_id=user_id,
org_id=org_id,
)
#### ELSE ####
if master_key is None:
@ -1233,6 +1243,7 @@ async def _PROXY_track_cost_callback(
end_user_id = proxy_server_request.get("body", {}).get("user", None)
user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
org_id = kwargs["litellm_params"]["metadata"].get("user_api_key_org_id", None)
if kwargs.get("response_cost", None) is not None:
response_cost = kwargs["response_cost"]
user_api_key = kwargs["litellm_params"]["metadata"].get(
@ -1260,6 +1271,7 @@ async def _PROXY_track_cost_callback(
completion_response=completion_response,
start_time=start_time,
end_time=end_time,
org_id=org_id,
)
await update_cache(
@ -1321,6 +1333,7 @@ async def update_database(
completion_response=None,
start_time=None,
end_time=None,
org_id=None,
):
try:
global prisma_client
@ -1551,9 +1564,34 @@ async def update_database(
)
raise e
### UPDATE ORG SPEND ###
async def _update_org_db():
try:
verbose_proxy_logger.debug(
"adding spend to org db. Response cost: {}. org_id: {}.".format(
response_cost, org_id
)
)
if org_id is None:
verbose_proxy_logger.debug(
"track_cost_callback: org_id is None. Not tracking spend for org"
)
return
if prisma_client is not None:
prisma_client.org_list_transactons[org_id] = (
response_cost
+ prisma_client.org_list_transactons.get(org_id, 0)
)
except Exception as e:
verbose_proxy_logger.info(
f"Update Org DB failed to execute - {str(e)}\n{traceback.format_exc()}"
)
raise e
asyncio.create_task(_update_user_db())
asyncio.create_task(_update_key_db())
asyncio.create_task(_update_team_db())
asyncio.create_task(_update_org_db())
# asyncio.create_task(_insert_spend_log_to_db())
if disable_spend_logs == False:
await _insert_spend_log_to_db()
@ -3432,6 +3470,7 @@ async def chat_completion(
user_api_key_dict, "key_alias", None
)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)