diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index f41a0bdcd..b1fb9081d 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -46,4 +46,5 @@ general_settings: litellm_jwtauth: admin_jwt_scope: "litellm_proxy_admin" public_key_ttl: os.environ/LITELLM_PUBLIC_KEY_TTL - user_id_jwt_field: "sub" \ No newline at end of file + user_id_jwt_field: "sub" + org_id_jwt_field: "azp" \ No newline at end of file diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 904f930c3..00256ed87 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -140,6 +140,7 @@ class LiteLLM_JWTAuth(LiteLLMBase): team_allowed_routes: List[ Literal["openai_routes", "info_routes", "management_routes"] ] = ["openai_routes", "info_routes"] + org_id_jwt_field: Optional[str] = None user_id_jwt_field: Optional[str] = None end_user_id_jwt_field: Optional[str] = None public_key_ttl: float = 600 @@ -514,6 +515,7 @@ class LiteLLM_BudgetTable(LiteLLMBase): class NewOrganizationRequest(LiteLLM_BudgetTable): + organization_id: Optional[str] = None organization_alias: str models: List = [] budget_id: Optional[str] = None @@ -522,6 +524,7 @@ class NewOrganizationRequest(LiteLLM_BudgetTable): class LiteLLM_OrganizationTable(LiteLLMBase): """Represents user-controllable params for a LiteLLM_OrganizationTable record""" + organization_id: Optional[str] = None organization_alias: Optional[str] = None budget_id: str metadata: Optional[dict] = None @@ -706,6 +709,8 @@ class LiteLLM_VerificationToken(LiteLLMBase): soft_budget_cooldown: bool = False litellm_budget_table: Optional[dict] = None + org_id: Optional[str] = None # org id for a given key + # hidden params used for parallel request limiting, not required to create a token user_id_rate_limits: Optional[dict] = None team_id_rate_limits: Optional[dict] = None diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 
async def get_org_object(
    org_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
):
    """
    Resolve an organization id to its record, checking the cache before the DB.

    - Check if org id in proxy Org Table
    - if valid, return LiteLLM_OrganizationTable object
    - if not, then raise an error

    Args:
        org_id: Organization id to look up (matched against the
            `organization_id` column).
        prisma_client: Connected DB client. Must not be None.
        user_api_key_cache: Cache that is read under the key
            "org_id:<org_id>". This function only reads the cache; it does
            not write to it.

    Returns:
        The cached entry when it is a dict or a LiteLLM_OrganizationTable,
        otherwise the row returned by `litellm_organizationtable.find_unique`.

    Raises:
        Exception: if no DB is connected, if the DB lookup fails, or if no
            organization row exists for `org_id`.
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache
    # The lookup has to be awaited: without `await` the call yields a coroutine
    # object, which is never None / dict / LiteLLM_OrganizationTable, so the
    # cached value could never be returned and every request hit the DB.
    cached_org_obj = await user_api_key_cache.async_get_cache(
        key="org_id:{}".format(org_id)
    )
    if cached_org_obj is not None:
        if isinstance(cached_org_obj, dict):
            return cached_org_obj
        elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
            return cached_org_obj
        # any other cached type is ignored and we fall through to the DB

    # Same error is surfaced for "lookup failed" and "row not found", so callers
    # see one consistent message.
    not_found_msg = f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."

    # else, check db
    try:
        response = await prisma_client.db.litellm_organizationtable.find_unique(
            where={"organization_id": org_id}
        )
    except Exception as e:
        # keep the underlying DB error attached for debugging
        raise Exception(not_found_msg) from e

    if response is None:
        raise Exception(not_found_msg)

    return response
async def _update_org_db():
    """
    Queue this request's cost against its organization.

    Adds `response_cost` to the running total stored for `org_id` in
    `prisma_client.org_list_transactons`. Nothing is queued when `org_id` is
    None or when no prisma client is set. Any failure is logged and re-raised.
    """
    try:
        verbose_proxy_logger.debug(
            "adding spend to org db. Response cost: {}. org_id: {}.".format(
                response_cost, org_id
            )
        )

        # Guard: requests without an org are not tracked.
        if org_id is None:
            verbose_proxy_logger.debug(
                "track_cost_callback: org_id is None. Not tracking spend for org"
            )
            return

        # Guard: nothing to accumulate into without a DB client.
        if prisma_client is None:
            return

        pending_org_spend = prisma_client.org_list_transactons
        already_queued = pending_org_spend.get(org_id, 0)
        pending_org_spend[org_id] = response_cost + already_queued
    except Exception as e:
        verbose_proxy_logger.info(
            f"Update Org DB failed to execute - {str(e)}\n{traceback.format_exc()}"
        )
        raise e
100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -567,6 +567,7 @@ class PrismaClient: end_user_list_transactons: dict = {} key_list_transactons: dict = {} team_list_transactons: dict = {} + org_list_transactons: dict = {} spend_log_transactions: List = [] def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging): @@ -2150,6 +2151,46 @@ async def update_spend( ) raise e + ### UPDATE ORG TABLE ### + if len(prisma_client.org_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + try: + async with prisma_client.db.tx( + timeout=timedelta(seconds=60) + ) as transaction: + async with transaction.batch_() as batcher: + for ( + org_id, + response_cost, + ) in prisma_client.org_list_transactons.items(): + batcher.litellm_organizationtable.update_many( # 'update_many' prevents error from being raised if no row exists + where={"organization_id": org_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.org_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. + break + except httpx.ReadTimeout: + if i >= n_retry_times: # If we've reached the maximum number of retries + raise # Re-raise the last exception + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + import traceback + + error_msg = ( + f"LiteLLM Prisma Client Exception - update org spend: {str(e)}" + ) + print_verbose(error_msg) + error_traceback = error_msg + "\n" + traceback.format_exc() + asyncio.create_task( + proxy_logging_obj.failure_handler( + original_exception=e, traceback_str=error_traceback + ) + ) + raise e + ### UPDATE SPEND LOGS ### verbose_proxy_logger.debug( "Spend Logs transactions: {}".format(len(prisma_client.spend_log_transactions))