forked from phoenix/litellm-mirror
Merge pull request #2978 from BerriAI/litellm_org_spend_tracking
fix(proxy_server.py): support tracking org spend
This commit is contained in:
commit
a1cb9a51b9
6 changed files with 136 additions and 1 deletions
|
@ -46,4 +46,5 @@ general_settings:
|
||||||
litellm_jwtauth:
|
litellm_jwtauth:
|
||||||
admin_jwt_scope: "litellm_proxy_admin"
|
admin_jwt_scope: "litellm_proxy_admin"
|
||||||
public_key_ttl: os.environ/LITELLM_PUBLIC_KEY_TTL
|
public_key_ttl: os.environ/LITELLM_PUBLIC_KEY_TTL
|
||||||
user_id_jwt_field: "sub"
|
user_id_jwt_field: "sub"
|
||||||
|
org_id_jwt_field: "azp"
|
|
@ -140,6 +140,7 @@ class LiteLLM_JWTAuth(LiteLLMBase):
|
||||||
team_allowed_routes: List[
|
team_allowed_routes: List[
|
||||||
Literal["openai_routes", "info_routes", "management_routes"]
|
Literal["openai_routes", "info_routes", "management_routes"]
|
||||||
] = ["openai_routes", "info_routes"]
|
] = ["openai_routes", "info_routes"]
|
||||||
|
org_id_jwt_field: Optional[str] = None
|
||||||
user_id_jwt_field: Optional[str] = None
|
user_id_jwt_field: Optional[str] = None
|
||||||
end_user_id_jwt_field: Optional[str] = None
|
end_user_id_jwt_field: Optional[str] = None
|
||||||
public_key_ttl: float = 600
|
public_key_ttl: float = 600
|
||||||
|
@ -514,6 +515,7 @@ class LiteLLM_BudgetTable(LiteLLMBase):
|
||||||
|
|
||||||
|
|
||||||
class NewOrganizationRequest(LiteLLM_BudgetTable):
|
class NewOrganizationRequest(LiteLLM_BudgetTable):
|
||||||
|
organization_id: Optional[str] = None
|
||||||
organization_alias: str
|
organization_alias: str
|
||||||
models: List = []
|
models: List = []
|
||||||
budget_id: Optional[str] = None
|
budget_id: Optional[str] = None
|
||||||
|
@ -522,6 +524,7 @@ class NewOrganizationRequest(LiteLLM_BudgetTable):
|
||||||
class LiteLLM_OrganizationTable(LiteLLMBase):
|
class LiteLLM_OrganizationTable(LiteLLMBase):
|
||||||
"""Represents user-controllable params for a LiteLLM_OrganizationTable record"""
|
"""Represents user-controllable params for a LiteLLM_OrganizationTable record"""
|
||||||
|
|
||||||
|
organization_id: Optional[str] = None
|
||||||
organization_alias: Optional[str] = None
|
organization_alias: Optional[str] = None
|
||||||
budget_id: str
|
budget_id: str
|
||||||
metadata: Optional[dict] = None
|
metadata: Optional[dict] = None
|
||||||
|
@ -706,6 +709,8 @@ class LiteLLM_VerificationToken(LiteLLMBase):
|
||||||
soft_budget_cooldown: bool = False
|
soft_budget_cooldown: bool = False
|
||||||
litellm_budget_table: Optional[dict] = None
|
litellm_budget_table: Optional[dict] = None
|
||||||
|
|
||||||
|
org_id: Optional[str] = None # org id for a given key
|
||||||
|
|
||||||
# hidden params used for parallel request limiting, not required to create a token
|
# hidden params used for parallel request limiting, not required to create a token
|
||||||
user_id_rate_limits: Optional[dict] = None
|
user_id_rate_limits: Optional[dict] = None
|
||||||
team_id_rate_limits: Optional[dict] = None
|
team_id_rate_limits: Optional[dict] = None
|
||||||
|
|
|
@ -14,6 +14,7 @@ from litellm.proxy._types import (
|
||||||
LiteLLM_JWTAuth,
|
LiteLLM_JWTAuth,
|
||||||
LiteLLM_TeamTable,
|
LiteLLM_TeamTable,
|
||||||
LiteLLMRoutes,
|
LiteLLMRoutes,
|
||||||
|
LiteLLM_OrganizationTable,
|
||||||
)
|
)
|
||||||
from typing import Optional, Literal, Union
|
from typing import Optional, Literal, Union
|
||||||
from litellm.proxy.utils import PrismaClient
|
from litellm.proxy.utils import PrismaClient
|
||||||
|
@ -287,3 +288,41 @@ async def get_team_object(
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
|
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_org_object(
    org_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
):
    """
    Look up an organization, consulting the cache before the DB.

    - Check if org id in proxy Org Table
    - if valid, return LiteLLM_OrganizationTable object
    - if not, then raise an error

    Args:
        org_id: Id of the organization to look up.
        prisma_client: DB client; a DB connection is required.
        user_api_key_cache: Cache queried under the key ``org_id:<org_id>``.

    Returns:
        The cached entry (a dict or a LiteLLM_OrganizationTable) on a cache
        hit, otherwise the row returned by the DB lookup.

    Raises:
        Exception: If no DB is connected, if the DB lookup fails, or if no
            row exists for ``org_id``.
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache
    # `async_get_cache` is an async method, so it must be awaited. Called
    # without `await` it yields a coroutine object, which is never None and
    # matches neither isinstance check below - the cache was silently skipped.
    cached_org_obj = await user_api_key_cache.async_get_cache(
        key="org_id:{}".format(org_id)
    )
    if cached_org_obj is not None:
        if isinstance(cached_org_obj, dict):
            return cached_org_obj
        elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
            return cached_org_obj
        # any other cached type is ignored and we fall through to the DB

    # else, check db
    not_found_msg = f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."
    try:
        response = await prisma_client.db.litellm_organizationtable.find_unique(
            where={"organization_id": org_id}
        )
    except Exception as e:
        # Callers see the same error for a failed lookup as for a missing row;
        # the original cause stays attached for debugging.
        raise Exception(not_found_msg) from e

    if response is None:
        raise Exception(not_found_msg)

    return response
|
||||||
|
|
|
@ -84,6 +84,16 @@ class JWTHandler:
|
||||||
user_id = default_value
|
user_id = default_value
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
|
def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
    """
    Read the organization id from a decoded JWT.

    Args:
        token: Decoded JWT claims.
        default_value: Returned when the configured claim is absent from `token`.

    Returns:
        The value of the claim named by `litellm_jwtauth.org_id_jwt_field`;
        `default_value` if that claim is missing from the token; None when
        no org id field is configured.
    """
    try:
        claim_name = self.litellm_jwtauth.org_id_jwt_field
        # No configured field -> None (not default_value).
        return None if claim_name is None else token[claim_name]
    except KeyError:
        return default_value
|
||||||
|
|
||||||
def get_scopes(self, token: dict) -> list:
|
def get_scopes(self, token: dict) -> list:
|
||||||
try:
|
try:
|
||||||
if isinstance(token["scope"], str):
|
if isinstance(token["scope"], str):
|
||||||
|
|
|
@ -116,6 +116,7 @@ from litellm.proxy.hooks.prompt_injection_detection import (
|
||||||
from litellm.proxy.auth.auth_checks import (
|
from litellm.proxy.auth.auth_checks import (
|
||||||
common_checks,
|
common_checks,
|
||||||
get_end_user_object,
|
get_end_user_object,
|
||||||
|
get_org_object,
|
||||||
get_team_object,
|
get_team_object,
|
||||||
get_user_object,
|
get_user_object,
|
||||||
allowed_routes_check,
|
allowed_routes_check,
|
||||||
|
@ -422,6 +423,14 @@ async def user_api_key_auth(
|
||||||
user_api_key_cache=user_api_key_cache,
|
user_api_key_cache=user_api_key_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
|
||||||
|
org_id = jwt_handler.get_org_id(token=valid_token, default_value=None)
|
||||||
|
if org_id is not None:
|
||||||
|
_ = await get_org_object(
|
||||||
|
org_id=org_id,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
)
|
||||||
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
|
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
|
||||||
user_object = None
|
user_object = None
|
||||||
user_id = jwt_handler.get_user_id(token=valid_token, default_value=None)
|
user_id = jwt_handler.get_user_id(token=valid_token, default_value=None)
|
||||||
|
@ -515,6 +524,7 @@ async def user_api_key_auth(
|
||||||
team_models=team_object.models,
|
team_models=team_object.models,
|
||||||
user_role="app_owner",
|
user_role="app_owner",
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
org_id=org_id,
|
||||||
)
|
)
|
||||||
#### ELSE ####
|
#### ELSE ####
|
||||||
if master_key is None:
|
if master_key is None:
|
||||||
|
@ -1233,6 +1243,7 @@ async def _PROXY_track_cost_callback(
|
||||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||||
user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
|
user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
|
||||||
team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
|
team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
|
||||||
|
org_id = kwargs["litellm_params"]["metadata"].get("user_api_key_org_id", None)
|
||||||
if kwargs.get("response_cost", None) is not None:
|
if kwargs.get("response_cost", None) is not None:
|
||||||
response_cost = kwargs["response_cost"]
|
response_cost = kwargs["response_cost"]
|
||||||
user_api_key = kwargs["litellm_params"]["metadata"].get(
|
user_api_key = kwargs["litellm_params"]["metadata"].get(
|
||||||
|
@ -1260,6 +1271,7 @@ async def _PROXY_track_cost_callback(
|
||||||
completion_response=completion_response,
|
completion_response=completion_response,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
|
org_id=org_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
await update_cache(
|
await update_cache(
|
||||||
|
@ -1321,6 +1333,7 @@ async def update_database(
|
||||||
completion_response=None,
|
completion_response=None,
|
||||||
start_time=None,
|
start_time=None,
|
||||||
end_time=None,
|
end_time=None,
|
||||||
|
org_id=None,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
global prisma_client
|
global prisma_client
|
||||||
|
@ -1551,9 +1564,34 @@ async def update_database(
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
### UPDATE ORG SPEND ###
|
||||||
|
async def _update_org_db():
    """
    Add this request's cost to the pending spend total for its org.

    Reads `response_cost`, `org_id` and `prisma_client` from the enclosing
    scope. Nothing is written to the DB here: the cost is accumulated in
    `prisma_client.org_list_transactons`, keyed by org id.

    Does nothing when `org_id` is None or no prisma client is set. Any
    exception is logged and re-raised.
    """
    try:
        debug_msg = "adding spend to org db. Response cost: {}. org_id: {}.".format(
            response_cost, org_id
        )
        verbose_proxy_logger.debug(debug_msg)

        if org_id is None:
            verbose_proxy_logger.debug(
                "track_cost_callback: org_id is None. Not tracking spend for org"
            )
            return

        if prisma_client is None:
            return

        # Accumulate on top of whatever is already pending for this org.
        already_pending = prisma_client.org_list_transactons.get(org_id, 0)
        prisma_client.org_list_transactons[org_id] = response_cost + already_pending
    except Exception as e:
        verbose_proxy_logger.info(
            f"Update Org DB failed to execute - {str(e)}\n{traceback.format_exc()}"
        )
        raise e
|
||||||
|
|
||||||
asyncio.create_task(_update_user_db())
|
asyncio.create_task(_update_user_db())
|
||||||
asyncio.create_task(_update_key_db())
|
asyncio.create_task(_update_key_db())
|
||||||
asyncio.create_task(_update_team_db())
|
asyncio.create_task(_update_team_db())
|
||||||
|
asyncio.create_task(_update_org_db())
|
||||||
# asyncio.create_task(_insert_spend_log_to_db())
|
# asyncio.create_task(_insert_spend_log_to_db())
|
||||||
if disable_spend_logs == False:
|
if disable_spend_logs == False:
|
||||||
await _insert_spend_log_to_db()
|
await _insert_spend_log_to_db()
|
||||||
|
@ -3432,6 +3470,7 @@ async def chat_completion(
|
||||||
user_api_key_dict, "key_alias", None
|
user_api_key_dict, "key_alias", None
|
||||||
)
|
)
|
||||||
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
||||||
|
data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
|
||||||
data["metadata"]["user_api_key_team_id"] = getattr(
|
data["metadata"]["user_api_key_team_id"] = getattr(
|
||||||
user_api_key_dict, "team_id", None
|
user_api_key_dict, "team_id", None
|
||||||
)
|
)
|
||||||
|
|
|
@ -567,6 +567,7 @@ class PrismaClient:
|
||||||
end_user_list_transactons: dict = {}
|
end_user_list_transactons: dict = {}
|
||||||
key_list_transactons: dict = {}
|
key_list_transactons: dict = {}
|
||||||
team_list_transactons: dict = {}
|
team_list_transactons: dict = {}
|
||||||
|
org_list_transactons: dict = {}
|
||||||
spend_log_transactions: List = []
|
spend_log_transactions: List = []
|
||||||
|
|
||||||
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
|
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
|
||||||
|
@ -2150,6 +2151,46 @@ async def update_spend(
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
### UPDATE ORG TABLE ###
|
||||||
|
if len(prisma_client.org_list_transactons.keys()) > 0:
|
||||||
|
for i in range(n_retry_times + 1):
|
||||||
|
try:
|
||||||
|
async with prisma_client.db.tx(
|
||||||
|
timeout=timedelta(seconds=60)
|
||||||
|
) as transaction:
|
||||||
|
async with transaction.batch_() as batcher:
|
||||||
|
for (
|
||||||
|
org_id,
|
||||||
|
response_cost,
|
||||||
|
) in prisma_client.org_list_transactons.items():
|
||||||
|
batcher.litellm_organizationtable.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||||
|
where={"organization_id": org_id},
|
||||||
|
data={"spend": {"increment": response_cost}},
|
||||||
|
)
|
||||||
|
prisma_client.org_list_transactons = (
|
||||||
|
{}
|
||||||
|
) # Clear the remaining transactions after processing all batches in the loop.
|
||||||
|
break
|
||||||
|
except httpx.ReadTimeout:
|
||||||
|
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||||
|
raise # Re-raise the last exception
|
||||||
|
# Optionally, sleep for a bit before retrying
|
||||||
|
await asyncio.sleep(2**i) # Exponential backoff
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
error_msg = (
|
||||||
|
f"LiteLLM Prisma Client Exception - update org spend: {str(e)}"
|
||||||
|
)
|
||||||
|
print_verbose(error_msg)
|
||||||
|
error_traceback = error_msg + "\n" + traceback.format_exc()
|
||||||
|
asyncio.create_task(
|
||||||
|
proxy_logging_obj.failure_handler(
|
||||||
|
original_exception=e, traceback_str=error_traceback
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
### UPDATE SPEND LOGS ###
|
### UPDATE SPEND LOGS ###
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
"Spend Logs transactions: {}".format(len(prisma_client.spend_log_transactions))
|
"Spend Logs transactions: {}".format(len(prisma_client.spend_log_transactions))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue