diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 6196f18a2..ac30977b3 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -251,6 +251,7 @@ class Member(LiteLLMBase): class NewTeamRequest(LiteLLMBase): team_alias: Optional[str] = None team_id: Optional[str] = None + organization_id: Optional[str] = None admins: list = [] members: list = [] members_with_roles: List[Member] = [] @@ -327,19 +328,46 @@ class TeamRequest(LiteLLMBase): class LiteLLM_BudgetTable(LiteLLMBase): """Represents user-controllable params for a LiteLLM_BudgetTable record""" - - max_budget: Optional[float] = None soft_budget: Optional[float] = None + max_budget: Optional[float] = None max_parallel_requests: Optional[int] = None tpm_limit: Optional[int] = None rpm_limit: Optional[int] = None - model_max_budget: dict + model_max_budget: Optional[dict] = None budget_duration: Optional[str] = None - budget_reset_at: Optional[datetime] = None + + +class NewOrganizationRequest(LiteLLM_BudgetTable): + organization_alias: str + models: List = [] + budget_id: Optional[str] = None + + +class LiteLLM_OrganizationTable(LiteLLMBase): + """Represents user-controllable params for a LiteLLM_OrganizationTable record""" + + organization_alias: Optional[str] = None + budget_id: str + metadata: Optional[dict] = None + models: List[str] created_by: str updated_by: str +class NewOrganizationResponse(LiteLLM_OrganizationTable): + organization_id: str + created_at: datetime + updated_at: datetime + + +class OrganizationRequest(LiteLLMBase): + organizations: List[str] + + +class BudgetRequest(LiteLLMBase): + budgets: List[str] + + class KeyManagementSystem(enum.Enum): GOOGLE_KMS = "google_kms" AZURE_KEY_VAULT = "azure_key_vault" diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index db5caa9ae..ddd702c22 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -239,6 +239,7 @@ health_check_interval = None health_check_results = {} 
queue: List = [] litellm_proxy_budget_name = "litellm-proxy-budget" +litellm_proxy_admin_name = "default_user_id" ui_access_mode: Literal["admin", "all"] = "all" proxy_budget_rescheduler_min_time = 597 proxy_budget_rescheduler_max_time = 605 @@ -335,7 +336,11 @@ async def user_api_key_auth( # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead is_master_key_valid = secrets.compare_digest(api_key, master_key) if is_master_key_valid: - return UserAPIKeyAuth(api_key=master_key, user_role="proxy_admin") + return UserAPIKeyAuth( + api_key=master_key, + user_role="proxy_admin", + user_id=litellm_proxy_admin_name, + ) if isinstance( api_key, str ): # if generated token, make sure it starts with sk-. @@ -360,7 +365,6 @@ async def user_api_key_auth( valid_token = await prisma_client.get_data( token=api_key, table_name="combined_view" ) - elif custom_db_client is not None: try: valid_token = await custom_db_client.get_data( @@ -2229,7 +2233,7 @@ def parse_cache_control(cache_control): @router.on_event("startup") async def startup_event(): - global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time + global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name import json ### LOAD MASTER KEY ### @@ -2276,9 +2280,8 @@ async def startup_event(): if prisma_client is not None and master_key is not None: # add master key to db - user_id = "default_user_id" if os.getenv("PROXY_ADMIN_ID", None) is not None: - user_id = os.getenv("PROXY_ADMIN_ID") + litellm_proxy_admin_name = os.getenv("PROXY_ADMIN_ID") asyncio.create_task( generate_key_helper_fn( @@ -2288,7 +2291,7 @@ async def startup_event(): config={}, spend=0, token=master_key, - user_id=user_id, + 
user_id=litellm_proxy_admin_name, user_role="proxy_admin", query_type="update_data", update_key_values={ @@ -5396,6 +5399,226 @@ async def team_info( ) +#### ORGANIZATION MANAGEMENT #### + + +@router.post( + "/organization/new", + tags=["organization management"], + dependencies=[Depends(user_api_key_auth)], + response_model=NewOrganizationResponse, +) +async def new_organization( + data: NewOrganizationRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Allow orgs to own teams + + Set org level budgets + model access. + + Only admins can create orgs. + + # Parameters + + - `organization_alias`: *str* = The name of the organization. + - `models`: *List* = The models the organization has access to. + - `budget_id`: *Optional[str]* = The id for a budget (tpm/rpm/max budget) for the organization. + ### IF NO BUDGET - CREATE ONE WITH THESE PARAMS ### + - `max_budget`: *Optional[float]* = Max budget for org + - `tpm_limit`: *Optional[int]* = Max tpm limit for org + - `rpm_limit`: *Optional[int]* = Max rpm limit for org + - `model_max_budget`: *Optional[dict]* = Max budget for a specific model + - `budget_duration`: *Optional[str]* = Frequency of resetting org budget + + Case 1: Create new org **without** a budget_id + + ```bash + curl --location 'http://0.0.0.0:4000/organization/new' \ + + --header 'Authorization: Bearer sk-1234' \ + + --header 'Content-Type: application/json' \ + + --data '{ + "organization_alias": "my-secret-org", + "models": ["model1", "model2"], + "max_budget": 100 + }' + + + ``` + + Case 2: Create new org **with** a budget_id + + ```bash + curl --location 'http://0.0.0.0:4000/organization/new' \ + + --header 'Authorization: Bearer sk-1234' \ + + --header 'Content-Type: application/json' \ + + --data '{ + "organization_alias": "my-secret-org", + "models": ["model1", "model2"], + "budget_id": "428eeaa8-f3ac-4e85-a8fb-7dc8d7aa8689" + }' + ``` + """ + global prisma_client + + if prisma_client is None: + raise
HTTPException(status_code=500, detail={"error": "No db connected"}) + + if ( + user_api_key_dict.user_role is None + or user_api_key_dict.user_role != "proxy_admin" + ): + raise HTTPException( + status_code=401, + detail={ + "error": f"Only admins can create orgs. Your role is = {user_api_key_dict.user_role}" + }, + ) + + if data.budget_id is None: + """ + Every organization needs a budget attached. + + If none provided, create one based on provided values + """ + budget_row = LiteLLM_BudgetTable(**data.json(exclude_none=True)) + + new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True)) + + _budget = await prisma_client.db.litellm_budgettable.create( + data={ + **new_budget, # type: ignore + "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + } + ) # type: ignore + + data.budget_id = _budget.budget_id + + """ + Ensure only models that user has access to, are given to org + """ + if len(user_api_key_dict.models) == 0: # user has access to all models + pass + else: + if len(data.models) == 0: + raise HTTPException( + status_code=400, + detail={ + "error": f"User not allowed to give access to all models. Select models you want org to have access to." + }, + ) + for m in data.models: + if m not in user_api_key_dict.models: + raise HTTPException( + status_code=400, + detail={ + "error": f"User not allowed to give access to model={m}. 
Models you have access to = {user_api_key_dict.models}" + }, + ) + organization_row = LiteLLM_OrganizationTable( + **data.json(exclude_none=True), + created_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + ) + new_organization_row = prisma_client.jsonify_object( + organization_row.json(exclude_none=True) + ) + response = await prisma_client.db.litellm_organizationtable.create( + data={ + **new_organization_row, # type: ignore + } + ) + + return response + + +@router.post( + "/organization/update", + tags=["organization management"], + dependencies=[Depends(user_api_key_auth)], +) +async def update_organization(): + """[TODO] Not Implemented yet. Let us know if you need this - https://github.com/BerriAI/litellm/issues""" + pass + + +@router.post( + "/organization/delete", + tags=["organization management"], + dependencies=[Depends(user_api_key_auth)], +) +async def delete_organization(): + """[TODO] Not Implemented yet. Let us know if you need this - https://github.com/BerriAI/litellm/issues""" + pass + + +@router.post( + "/organization/info", + tags=["organization management"], + dependencies=[Depends(user_api_key_auth)], +) +async def info_organization(data: OrganizationRequest): + """ + Get the org specific information + """ + global prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if len(data.organizations) == 0: + raise HTTPException( + status_code=400, + detail={ + "error": f"Specify list of organization id's to query. 
Passed in={data.organizations}" + }, + ) + response = await prisma_client.db.litellm_organizationtable.find_many( + where={"organization_id": {"in": data.organizations}}, + include={"litellm_budget_table": True}, + ) + + return response + + +#### BUDGET TABLE MANAGEMENT #### + + +@router.post( + "/budget/info", + tags=["organization management"], + dependencies=[Depends(user_api_key_auth)], +) +async def info_budget(data: BudgetRequest): + """ + Get the budget id specific information + """ + global prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if len(data.budgets) == 0: + raise HTTPException( + status_code=400, + detail={ + "error": f"Specify list of budget id's to query. Passed in={data.budgets}" + }, + ) + response = await prisma_client.db.litellm_budgettable.find_many( + where={"budget_id": {"in": data.budgets}}, + ) + + return response + + #### MODEL MANAGEMENT #### diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 1fe55f24e..93a9b3123 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -15,7 +15,7 @@ model LiteLLM_BudgetTable { max_parallel_requests Int? tpm_limit BigInt? rpm_limit BigInt? - model_max_budget Json @default("{}") + model_max_budget Json? budget_duration String? budget_reset_at DateTime? created_at DateTime @default(now()) @map("created_at") diff --git a/schema.prisma b/schema.prisma index 1fe55f24e..93a9b3123 100644 --- a/schema.prisma +++ b/schema.prisma @@ -15,7 +15,7 @@ model LiteLLM_BudgetTable { max_parallel_requests Int? tpm_limit BigInt? rpm_limit BigInt? - model_max_budget Json @default("{}") + model_max_budget Json? budget_duration String? budget_reset_at DateTime? 
created_at DateTime @default(now()) @map("created_at") diff --git a/tests/test_organizations.py b/tests/test_organizations.py new file mode 100644 index 000000000..00e99cb66 --- /dev/null +++ b/tests/test_organizations.py @@ -0,0 +1,46 @@ +# What this tests ? +## Tests /organization endpoints. +import pytest +import asyncio +import aiohttp +import time, uuid +from openai import AsyncOpenAI + + +async def new_organization(session, i, organization_alias, max_budget=None): + url = "http://0.0.0.0:4000/organization/new" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "organization_alias": organization_alias, + "models": ["azure-models"], + "max_budget": max_budget, + } + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(f"Response {i} (Status code: {status}):") + print(response_text) + print() + + if status != 200: + raise Exception(f"Request {i} did not return a 200 status code: {status}") + + return await response.json() + + +@pytest.mark.asyncio +async def test_organization_new(): + """ + Make 19 parallel calls to /organization/new. Assert all worked. + """ + organization_alias = f"Organization: {uuid.uuid4()}" + async with aiohttp.ClientSession() as session: + tasks = [ + new_organization( + session=session, i=0, organization_alias=organization_alias + ) + for i in range(1, 20) + ] + await asyncio.gather(*tasks)