mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix(proxy_server.py): support tracking org spend
currently works when org set for jwt auth
This commit is contained in:
parent
de9258e700
commit
aa5da4346a
6 changed files with 136 additions and 1 deletions
|
@ -116,6 +116,7 @@ from litellm.proxy.hooks.prompt_injection_detection import (
|
|||
from litellm.proxy.auth.auth_checks import (
|
||||
common_checks,
|
||||
get_end_user_object,
|
||||
get_org_object,
|
||||
get_team_object,
|
||||
get_user_object,
|
||||
allowed_routes_check,
|
||||
|
@ -422,6 +423,14 @@ async def user_api_key_auth(
|
|||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
|
||||
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
|
||||
org_id = jwt_handler.get_org_id(token=valid_token, default_value=None)
|
||||
if org_id is not None:
|
||||
_ = await get_org_object(
|
||||
org_id=org_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
|
||||
user_object = None
|
||||
user_id = jwt_handler.get_user_id(token=valid_token, default_value=None)
|
||||
|
@ -515,6 +524,7 @@ async def user_api_key_auth(
|
|||
team_models=team_object.models,
|
||||
user_role="app_owner",
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
)
|
||||
#### ELSE ####
|
||||
if master_key is None:
|
||||
|
@ -1233,6 +1243,7 @@ async def _PROXY_track_cost_callback(
|
|||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
|
||||
team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
|
||||
org_id = kwargs["litellm_params"]["metadata"].get("user_api_key_org_id", None)
|
||||
if kwargs.get("response_cost", None) is not None:
|
||||
response_cost = kwargs["response_cost"]
|
||||
user_api_key = kwargs["litellm_params"]["metadata"].get(
|
||||
|
@ -1260,6 +1271,7 @@ async def _PROXY_track_cost_callback(
|
|||
completion_response=completion_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
await update_cache(
|
||||
|
@ -1321,6 +1333,7 @@ async def update_database(
|
|||
completion_response=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
org_id=None,
|
||||
):
|
||||
try:
|
||||
global prisma_client
|
||||
|
@ -1551,9 +1564,34 @@ async def update_database(
|
|||
)
|
||||
raise e
|
||||
|
||||
### UPDATE ORG SPEND ###
|
||||
async def _update_org_db():
|
||||
try:
|
||||
verbose_proxy_logger.debug(
|
||||
"adding spend to org db. Response cost: {}. org_id: {}.".format(
|
||||
response_cost, org_id
|
||||
)
|
||||
)
|
||||
if org_id is None:
|
||||
verbose_proxy_logger.debug(
|
||||
"track_cost_callback: org_id is None. Not tracking spend for org"
|
||||
)
|
||||
return
|
||||
if prisma_client is not None:
|
||||
prisma_client.org_list_transactons[org_id] = (
|
||||
response_cost
|
||||
+ prisma_client.org_list_transactons.get(org_id, 0)
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.info(
|
||||
f"Update Org DB failed to execute - {str(e)}\n{traceback.format_exc()}"
|
||||
)
|
||||
raise e
|
||||
|
||||
asyncio.create_task(_update_user_db())
|
||||
asyncio.create_task(_update_key_db())
|
||||
asyncio.create_task(_update_team_db())
|
||||
asyncio.create_task(_update_org_db())
|
||||
# asyncio.create_task(_insert_spend_log_to_db())
|
||||
if disable_spend_logs == False:
|
||||
await _insert_spend_log_to_db()
|
||||
|
@ -3432,6 +3470,7 @@ async def chat_completion(
|
|||
user_api_key_dict, "key_alias", None
|
||||
)
|
||||
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
||||
data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
|
||||
data["metadata"]["user_api_key_team_id"] = getattr(
|
||||
user_api_key_dict, "team_id", None
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue