forked from phoenix/litellm-mirror
Merge pull request #2978 from BerriAI/litellm_org_spend_tracking
fix(proxy_server.py): support tracking org spend
This change is contained in commit a1cb9a51b9.
6 changed files with 136 additions and 1 deletions
|
@ -46,4 +46,5 @@ general_settings:
|
|||
litellm_jwtauth:
|
||||
admin_jwt_scope: "litellm_proxy_admin"
|
||||
public_key_ttl: os.environ/LITELLM_PUBLIC_KEY_TTL
|
||||
user_id_jwt_field: "sub"
|
||||
user_id_jwt_field: "sub"
|
||||
org_id_jwt_field: "azp"
|
|
@ -140,6 +140,7 @@ class LiteLLM_JWTAuth(LiteLLMBase):
|
|||
team_allowed_routes: List[
|
||||
Literal["openai_routes", "info_routes", "management_routes"]
|
||||
] = ["openai_routes", "info_routes"]
|
||||
org_id_jwt_field: Optional[str] = None
|
||||
user_id_jwt_field: Optional[str] = None
|
||||
end_user_id_jwt_field: Optional[str] = None
|
||||
public_key_ttl: float = 600
|
||||
|
@ -514,6 +515,7 @@ class LiteLLM_BudgetTable(LiteLLMBase):
|
|||
|
||||
|
||||
class NewOrganizationRequest(LiteLLM_BudgetTable):
    # Request payload for creating a new organization (the
    # `/organization/new` call referenced in get_org_object's error message).
    # Inherits budget-related fields from LiteLLM_BudgetTable.
    organization_id: Optional[str] = None  # omit to have an id generated server-side -- TODO confirm
    organization_alias: str  # required human-readable name for the org
    models: List = []  # models this org is allowed to use; presumably empty = inherit/all -- verify against caller
    budget_id: Optional[str] = None  # link to an existing LiteLLM_BudgetTable row, if any
|
||||
|
@ -522,6 +524,7 @@ class NewOrganizationRequest(LiteLLM_BudgetTable):
|
|||
class LiteLLM_OrganizationTable(LiteLLMBase):
    """Represents user-controllable params for a LiteLLM_OrganizationTable record"""

    organization_id: Optional[str] = None  # primary key used by get_org_object lookups
    organization_alias: Optional[str] = None  # human-readable name
    budget_id: str  # required: every org record is tied to a budget
    metadata: Optional[dict] = None  # free-form extra data
|
||||
|
@ -706,6 +709,8 @@ class LiteLLM_VerificationToken(LiteLLMBase):
|
|||
soft_budget_cooldown: bool = False
|
||||
litellm_budget_table: Optional[dict] = None
|
||||
|
||||
org_id: Optional[str] = None # org id for a given key
|
||||
|
||||
# hidden params used for parallel request limiting, not required to create a token
|
||||
user_id_rate_limits: Optional[dict] = None
|
||||
team_id_rate_limits: Optional[dict] = None
|
||||
|
|
|
@ -14,6 +14,7 @@ from litellm.proxy._types import (
|
|||
LiteLLM_JWTAuth,
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLMRoutes,
|
||||
LiteLLM_OrganizationTable,
|
||||
)
|
||||
from typing import Optional, Literal, Union
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
@ -287,3 +288,41 @@ async def get_team_object(
|
|||
raise Exception(
|
||||
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
|
||||
)
|
||||
|
||||
|
||||
async def get_org_object(
    org_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
):
    """
    Look up an organization by id.

    - Check if org id is in the in-memory cache (key "org_id:<org_id>")
    - Otherwise query the LiteLLM_OrganizationTable in the DB
    - Raises Exception if no DB is connected or the org does not exist

    Returns the cached object (dict or LiteLLM_OrganizationTable) or the
    prisma row from the DB.
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache
    # FIX: async_get_cache is a coroutine function; the previous code called it
    # without `await`, so `cached_org_obj` was always a coroutine object --
    # never a dict / LiteLLM_OrganizationTable -- and the cache never hit.
    cached_org_obj = await user_api_key_cache.async_get_cache(
        key="org_id:{}".format(org_id)
    )
    if cached_org_obj is not None:
        if isinstance(cached_org_obj, dict):
            return cached_org_obj
        elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
            return cached_org_obj
    # else, check db
    try:
        response = await prisma_client.db.litellm_organizationtable.find_unique(
            where={"organization_id": org_id}
        )

        if response is None:
            raise Exception

        return response
    except Exception as e:
        # chain the original cause so DB errors are distinguishable from
        # a genuinely missing org when debugging
        raise Exception(
            f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."
        ) from e
|
||||
|
|
|
@ -84,6 +84,16 @@ class JWTHandler:
|
|||
user_id = default_value
|
||||
return user_id
|
||||
|
||||
def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||
try:
|
||||
if self.litellm_jwtauth.org_id_jwt_field is not None:
|
||||
org_id = token[self.litellm_jwtauth.org_id_jwt_field]
|
||||
else:
|
||||
org_id = None
|
||||
except KeyError:
|
||||
org_id = default_value
|
||||
return org_id
|
||||
|
||||
def get_scopes(self, token: dict) -> list:
|
||||
try:
|
||||
if isinstance(token["scope"], str):
|
||||
|
|
|
@ -116,6 +116,7 @@ from litellm.proxy.hooks.prompt_injection_detection import (
|
|||
from litellm.proxy.auth.auth_checks import (
|
||||
common_checks,
|
||||
get_end_user_object,
|
||||
get_org_object,
|
||||
get_team_object,
|
||||
get_user_object,
|
||||
allowed_routes_check,
|
||||
|
@ -422,6 +423,14 @@ async def user_api_key_auth(
|
|||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
|
||||
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
|
||||
org_id = jwt_handler.get_org_id(token=valid_token, default_value=None)
|
||||
if org_id is not None:
|
||||
_ = await get_org_object(
|
||||
org_id=org_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
|
||||
user_object = None
|
||||
user_id = jwt_handler.get_user_id(token=valid_token, default_value=None)
|
||||
|
@ -515,6 +524,7 @@ async def user_api_key_auth(
|
|||
team_models=team_object.models,
|
||||
user_role="app_owner",
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
)
|
||||
#### ELSE ####
|
||||
if master_key is None:
|
||||
|
@ -1233,6 +1243,7 @@ async def _PROXY_track_cost_callback(
|
|||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
|
||||
team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
|
||||
org_id = kwargs["litellm_params"]["metadata"].get("user_api_key_org_id", None)
|
||||
if kwargs.get("response_cost", None) is not None:
|
||||
response_cost = kwargs["response_cost"]
|
||||
user_api_key = kwargs["litellm_params"]["metadata"].get(
|
||||
|
@ -1260,6 +1271,7 @@ async def _PROXY_track_cost_callback(
|
|||
completion_response=completion_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
await update_cache(
|
||||
|
@ -1321,6 +1333,7 @@ async def update_database(
|
|||
completion_response=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
org_id=None,
|
||||
):
|
||||
try:
|
||||
global prisma_client
|
||||
|
@ -1551,9 +1564,34 @@ async def update_database(
|
|||
)
|
||||
raise e
|
||||
|
||||
    ### UPDATE ORG SPEND ###
    async def _update_org_db():
        # Buffer this request's cost against the org in memory; the batched
        # DB write is performed later by update_spend() from
        # prisma_client.org_list_transactons.
        # NOTE(review): "org_list_transactons" (sic, missing 'i') is the
        # attribute name declared on PrismaClient -- keep in sync if renamed.
        try:
            verbose_proxy_logger.debug(
                "adding spend to org db. Response cost: {}. org_id: {}.".format(
                    response_cost, org_id
                )
            )
            if org_id is None:
                # org tracking is optional -- no org id on this request, skip
                verbose_proxy_logger.debug(
                    "track_cost_callback: org_id is None. Not tracking spend for org"
                )
                return
            if prisma_client is not None:
                # accumulate: new cost + whatever is already buffered for this org
                prisma_client.org_list_transactons[org_id] = (
                    response_cost
                    + prisma_client.org_list_transactons.get(org_id, 0)
                )
        except Exception as e:
            verbose_proxy_logger.info(
                f"Update Org DB failed to execute - {str(e)}\n{traceback.format_exc()}"
            )
            raise e
|
||||
|
||||
asyncio.create_task(_update_user_db())
|
||||
asyncio.create_task(_update_key_db())
|
||||
asyncio.create_task(_update_team_db())
|
||||
asyncio.create_task(_update_org_db())
|
||||
# asyncio.create_task(_insert_spend_log_to_db())
|
||||
if disable_spend_logs == False:
|
||||
await _insert_spend_log_to_db()
|
||||
|
@ -3432,6 +3470,7 @@ async def chat_completion(
|
|||
user_api_key_dict, "key_alias", None
|
||||
)
|
||||
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
||||
data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
|
||||
data["metadata"]["user_api_key_team_id"] = getattr(
|
||||
user_api_key_dict, "team_id", None
|
||||
)
|
||||
|
|
|
@ -567,6 +567,7 @@ class PrismaClient:
|
|||
end_user_list_transactons: dict = {}
|
||||
key_list_transactons: dict = {}
|
||||
team_list_transactons: dict = {}
|
||||
org_list_transactons: dict = {}
|
||||
spend_log_transactions: List = []
|
||||
|
||||
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
|
||||
|
@ -2150,6 +2151,46 @@ async def update_spend(
|
|||
)
|
||||
raise e
|
||||
|
||||
### UPDATE ORG TABLE ###
|
||||
if len(prisma_client.org_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
org_id,
|
||||
response_cost,
|
||||
) in prisma_client.org_list_transactons.items():
|
||||
batcher.litellm_organizationtable.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||
where={"organization_id": org_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.org_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except httpx.ReadTimeout:
|
||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||
raise # Re-raise the last exception
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
error_msg = (
|
||||
f"LiteLLM Prisma Client Exception - update org spend: {str(e)}"
|
||||
)
|
||||
print_verbose(error_msg)
|
||||
error_traceback = error_msg + "\n" + traceback.format_exc()
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.failure_handler(
|
||||
original_exception=e, traceback_str=error_traceback
|
||||
)
|
||||
)
|
||||
raise e
|
||||
|
||||
### UPDATE SPEND LOGS ###
|
||||
verbose_proxy_logger.debug(
|
||||
"Spend Logs transactions: {}".format(len(prisma_client.spend_log_transactions))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue