From f2570fdf00b40065a6dc74f5dca58d6eaf050e22 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 15 Jun 2024 11:40:36 -0700 Subject: [PATCH] feat - refactor team endpoints --- litellm/proxy/_types.py | 7 + litellm/proxy/caching_routes.py | 194 ++++ .../common_utils/management_endpoint_utils.py | 113 -- litellm/proxy/management_helpers/utils.py | 113 ++ litellm/proxy/proxy_server.py | 1016 +---------------- litellm/proxy/team_endpoints.py | 902 +++++++++++++++ litellm/tests/test_key_generate_prisma.py | 5 +- 7 files changed, 1224 insertions(+), 1126 deletions(-) create mode 100644 litellm/proxy/caching_routes.py delete mode 100644 litellm/proxy/common_utils/management_endpoint_utils.py create mode 100644 litellm/proxy/team_endpoints.py diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index fdbd9e22b..9a2847839 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1609,3 +1609,10 @@ class ProxyException(Exception): "param": self.param, "code": self.code, } + + +class CommonProxyErrors(enum.Enum): + db_not_connected_error = "DB not connected" + no_llm_router = "No models configured on proxy" + not_allowed_access = "Admin-only endpoint. Not allowed to access this." + not_premium_user = "You must be a LiteLLM Enterprise user to use this feature. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat" diff --git a/litellm/proxy/caching_routes.py b/litellm/proxy/caching_routes.py new file mode 100644 index 000000000..bad747793 --- /dev/null +++ b/litellm/proxy/caching_routes.py @@ -0,0 +1,194 @@ +from typing import Optional +from fastapi import Depends, Request, APIRouter +from fastapi import HTTPException +import copy +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth + + +router = APIRouter( + prefix="/cache", + tags=["caching"], +) + + +@router.get( + "/ping", + dependencies=[Depends(user_api_key_auth)], +) +async def cache_ping(): + """ + Endpoint for checking if cache can be pinged + """ + try: + litellm_cache_params = {} + specific_cache_params = {} + + if litellm.cache is None: + raise HTTPException( + status_code=503, detail="Cache not initialized. litellm.cache is None" + ) + + for k, v in vars(litellm.cache).items(): + try: + if k == "cache": + continue + litellm_cache_params[k] = str(copy.deepcopy(v)) + except Exception: + litellm_cache_params[k] = "" + for k, v in vars(litellm.cache.cache).items(): + try: + specific_cache_params[k] = str(v) + except Exception: + specific_cache_params[k] = "" + if litellm.cache.type == "redis": + # ping the redis cache + ping_response = await litellm.cache.ping() + verbose_proxy_logger.debug( + "/cache/ping: ping_response: " + str(ping_response) + ) + # making a set cache call + # add cache does not return anything + await litellm.cache.async_add_cache( + result="test_key", + model="test-model", + messages=[{"role": "user", "content": "test from litellm"}], + ) + verbose_proxy_logger.debug("/cache/ping: done with set_cache()") + return { + "status": "healthy", + "cache_type": litellm.cache.type, + "ping_response": True, + "set_cache_response": "success", + "litellm_cache_params": litellm_cache_params, + "redis_cache_params": specific_cache_params, + } + else: + return { + "status": "healthy", + "cache_type": litellm.cache.type, + "litellm_cache_params": litellm_cache_params, + } + except Exception as e: + raise HTTPException( + status_code=503, + detail=f"Service Unhealthy ({str(e)}).Cache parameters: {litellm_cache_params}.specific_cache_params: {specific_cache_params}", + ) + + +@router.post( + "/delete", + tags=["caching"], + dependencies=[Depends(user_api_key_auth)], +) +async def cache_delete(request: Request): + """ + Endpoint for deleting a key from the cache. All responses from litellm proxy have `x-litellm-cache-key` in the headers + + Parameters: + - **keys**: *Optional[List[str]]* - A list of keys to delete from the cache. Example {"keys": ["key1", "key2"]} + + ```shell + curl -X POST "http://0.0.0.0:4000/cache/delete" \ + -H "Authorization: Bearer sk-1234" \ + -d '{"keys": ["key1", "key2"]}' + ``` + + """ + try: + if litellm.cache is None: + raise HTTPException( + status_code=503, detail="Cache not initialized. litellm.cache is None" + ) + + request_data = await request.json() + keys = request_data.get("keys", None) + + if litellm.cache.type == "redis": + await litellm.cache.delete_cache_keys(keys=keys) + return { + "status": "success", + } + else: + raise HTTPException( + status_code=500, + detail=f"Cache type {litellm.cache.type} does not support deleting a key. only `redis` is supported", + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Cache Delete Failed({str(e)})", + ) + + +@router.get( + "/redis/info", + dependencies=[Depends(user_api_key_auth)], +) +async def cache_redis_info(): + """ + Endpoint for getting /redis/info + """ + try: + if litellm.cache is None: + raise HTTPException( + status_code=503, detail="Cache not initialized. litellm.cache is None" + ) + if litellm.cache.type == "redis": + client_list = litellm.cache.cache.client_list() + redis_info = litellm.cache.cache.info() + num_clients = len(client_list) + return { + "num_clients": num_clients, + "clients": client_list, + "info": redis_info, + } + else: + raise HTTPException( + status_code=500, + detail=f"Cache type {litellm.cache.type} does not support flushing", + ) + except Exception as e: + raise HTTPException( + status_code=503, + detail=f"Service Unhealthy ({str(e)})", + ) + + +@router.post( + "/flushall", + tags=["caching"], + dependencies=[Depends(user_api_key_auth)], +) +async def cache_flushall(): + """ + A function to flush all items from the cache. (All items will be deleted from the cache with this) + Raises HTTPException if the cache is not initialized or if the cache type does not support flushing. + Returns a dictionary with the status of the operation. + + Usage: + ``` + curl -X POST http://0.0.0.0:4000/cache/flushall -H "Authorization: Bearer sk-1234" + ``` + """ + try: + if litellm.cache is None: + raise HTTPException( + status_code=503, detail="Cache not initialized. litellm.cache is None" + ) + if litellm.cache.type == "redis": + litellm.cache.cache.flushall() + return { + "status": "success", + } + else: + raise HTTPException( + status_code=500, + detail=f"Cache type {litellm.cache.type} does not support flushing", + ) + except Exception as e: + raise HTTPException( + status_code=503, + detail=f"Service Unhealthy ({str(e)})", + ) diff --git a/litellm/proxy/common_utils/management_endpoint_utils.py b/litellm/proxy/common_utils/management_endpoint_utils.py deleted file mode 100644 index 2b4465caf..000000000 --- a/litellm/proxy/common_utils/management_endpoint_utils.py +++ /dev/null @@ -1,113 +0,0 @@ -from datetime import datetime -from functools import wraps -from litellm.proxy._types import UserAPIKeyAuth, ManagementEndpointLoggingPayload -from litellm.proxy.common_utils.http_parsing_utils import _read_request_body -from litellm._logging import verbose_logger -from fastapi import Request - - -def management_endpoint_wrapper(func): - """ - This wrapper does the following: - - 1. Log I/O, Exceptions to OTEL - 2. Create an Audit log for success calls - """ - - @wraps(func) - async def wrapper(*args, **kwargs): - start_time = datetime.now() - - try: - result = await func(*args, **kwargs) - end_time = datetime.now() - try: - if kwargs is None: - kwargs = {} - user_api_key_dict: UserAPIKeyAuth = ( - kwargs.get("user_api_key_dict") or UserAPIKeyAuth() - ) - _http_request: Request = kwargs.get("http_request") - parent_otel_span = user_api_key_dict.parent_otel_span - if parent_otel_span is not None: - from litellm.proxy.proxy_server import open_telemetry_logger - - if open_telemetry_logger is not None: - if _http_request: - _route = _http_request.url.path - _request_body: dict = await _read_request_body( - request=_http_request - ) - _response = dict(result) if result is not None else None - - logging_payload = ManagementEndpointLoggingPayload( - route=_route, - request_data=_request_body, - response=_response, - start_time=start_time, - end_time=end_time, - ) - - await open_telemetry_logger.async_management_endpoint_success_hook( - logging_payload=logging_payload, - parent_otel_span=parent_otel_span, - ) - - if _http_request: - _route = _http_request.url.path - # Flush user_api_key cache if this was an update/delete call to /key, /team, or /user - if _route in [ - "/key/update", - "/key/delete", - "/team/update", - "/team/delete", - "/user/update", - "/user/delete", - "/customer/update", - "/customer/delete", - ]: - from litellm.proxy.proxy_server import user_api_key_cache - - user_api_key_cache.flush_cache() - except Exception as e: - # Non-Blocking Exception - verbose_logger.debug("Error in management endpoint wrapper: %s", str(e)) - pass - - return result - except Exception as e: - end_time = datetime.now() - - if kwargs is None: - kwargs = {} - user_api_key_dict: UserAPIKeyAuth = ( - kwargs.get("user_api_key_dict") or UserAPIKeyAuth() - ) - parent_otel_span = user_api_key_dict.parent_otel_span - if parent_otel_span is not None: - from litellm.proxy.proxy_server import open_telemetry_logger - - if open_telemetry_logger is not None: - _http_request: Request = kwargs.get("http_request") - if _http_request: - _route = _http_request.url.path - _request_body: dict = await _read_request_body( - request=_http_request - ) - logging_payload = ManagementEndpointLoggingPayload( - route=_route, - request_data=_request_body, - response=None, - start_time=start_time, - end_time=end_time, - exception=e, - ) - - await open_telemetry_logger.async_management_endpoint_failure_hook( - logging_payload=logging_payload, - parent_otel_span=parent_otel_span, - ) - - raise e - - return wrapper diff --git a/litellm/proxy/management_helpers/utils.py b/litellm/proxy/management_helpers/utils.py index 6c035d3ef..1cf22480b 100644 --- a/litellm/proxy/management_helpers/utils.py +++ b/litellm/proxy/management_helpers/utils.py @@ -1,5 +1,11 @@ # What is this? ## Helper utils for the management endpoints (keys/users/teams) +from datetime import datetime +from functools import wraps +from litellm.proxy._types import UserAPIKeyAuth, ManagementEndpointLoggingPayload +from litellm.proxy.common_utils.http_parsing_utils import _read_request_body +from litellm._logging import verbose_logger +from fastapi import Request from litellm.proxy._types import LiteLLM_TeamTable, Member, UserAPIKeyAuth from litellm.proxy.utils import PrismaClient @@ -61,3 +67,110 @@ async def add_new_member( "budget_id": _budget_id, } ) + + +def management_endpoint_wrapper(func): + """ + This wrapper does the following: + + 1. Log I/O, Exceptions to OTEL + 2. Create an Audit log for success calls + """ + + @wraps(func) + async def wrapper(*args, **kwargs): + start_time = datetime.now() + + try: + result = await func(*args, **kwargs) + end_time = datetime.now() + try: + if kwargs is None: + kwargs = {} + user_api_key_dict: UserAPIKeyAuth = ( + kwargs.get("user_api_key_dict") or UserAPIKeyAuth() + ) + _http_request: Request = kwargs.get("http_request") + parent_otel_span = user_api_key_dict.parent_otel_span + if parent_otel_span is not None: + from litellm.proxy.proxy_server import open_telemetry_logger + + if open_telemetry_logger is not None: + if _http_request: + _route = _http_request.url.path + _request_body: dict = await _read_request_body( + request=_http_request + ) + _response = dict(result) if result is not None else None + + logging_payload = ManagementEndpointLoggingPayload( + route=_route, + request_data=_request_body, + response=_response, + start_time=start_time, + end_time=end_time, + ) + + await open_telemetry_logger.async_management_endpoint_success_hook( + logging_payload=logging_payload, + parent_otel_span=parent_otel_span, + ) + + if _http_request: + _route = _http_request.url.path + # Flush user_api_key cache if this was an update/delete call to /key, /team, or /user + if _route in [ + "/key/update", + "/key/delete", + "/team/update", + "/team/delete", + "/user/update", + "/user/delete", + "/customer/update", + "/customer/delete", + ]: + from litellm.proxy.proxy_server import user_api_key_cache + + user_api_key_cache.flush_cache() + except Exception as e: + # Non-Blocking Exception + verbose_logger.debug("Error in management endpoint wrapper: %s", str(e)) + pass + + return result + except Exception as e: + end_time = datetime.now() + + if kwargs is None: + kwargs = {} + user_api_key_dict: UserAPIKeyAuth = ( + kwargs.get("user_api_key_dict") or UserAPIKeyAuth() + ) + parent_otel_span = user_api_key_dict.parent_otel_span + if parent_otel_span is not None: + from litellm.proxy.proxy_server import open_telemetry_logger + + if open_telemetry_logger is not None: + _http_request: Request = kwargs.get("http_request") + if _http_request: + _route = _http_request.url.path + _request_body: dict = await _read_request_body( + request=_http_request + ) + logging_payload = ManagementEndpointLoggingPayload( + route=_route, + request_data=_request_body, + response=None, + start_time=start_time, + end_time=end_time, + exception=e, + ) + + await open_telemetry_logger.async_management_endpoint_failure_hook( + logging_payload=logging_payload, + parent_otel_span=parent_otel_span, + ) + + raise e + + return wrapper diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 9a9f976be..1e202807d 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -90,7 +90,6 @@ from litellm.types.llms.openai import ( HttpxBinaryResponseContent, ) from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request -from litellm.proxy.management_helpers.utils import add_new_member from litellm.proxy.utils import ( PrismaClient, DBClient, @@ -162,14 +161,14 @@ from litellm.proxy.auth.auth_checks import ( get_actual_routes, log_to_opentelemetry, ) -from litellm.proxy.common_utils.management_endpoint_utils import ( - management_endpoint_wrapper, -) from litellm.llms.custom_httpx.httpx_handler import HTTPHandler from litellm.exceptions import RejectedRequestError from litellm.integrations.slack_alerting import SlackAlertingArgs, SlackAlerting from litellm.scheduler import Scheduler, FlowItem, DefaultPriorities +## Import All Misc routes here ## +from caching_routes import router as caching_router + try: from litellm._version import version except: @@ -277,13 +276,6 @@ class UserAPIKeyCacheTTLEnum(enum.Enum): in_memory_cache_ttl = 60 # 1 min ttl ## configure via `general_settings::user_api_key_cache_ttl: ` -class CommonProxyErrors(enum.Enum): - db_not_connected_error = "DB not connected" - no_llm_router = "No models configured on proxy" - not_allowed_access = "Admin-only endpoint. Not allowed to access this." - not_premium_user = "You must be a LiteLLM Enterprise user to use this feature. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat" - - @app.exception_handler(ProxyException) async def openai_exception_handler(request: Request, exc: ProxyException): # NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions @@ -7781,6 +7773,8 @@ async def new_user(data: NewUserRequest): # Admin UI Logic # if team_id passed add this user to the team if data_json.get("team_id", None) is not None: + from litellm.proxy.team_endpoints import team_member_add + await team_member_add( data=TeamMemberAddRequest( team_id=data_json.get("team_id", None), @@ -8816,235 +8810,6 @@ async def delete_end_user( pass -#### TEAM MANAGEMENT #### - - -@router.post( - "/team/new", - tags=["team management"], - dependencies=[Depends(user_api_key_auth)], - response_model=LiteLLM_TeamTable, -) -@management_endpoint_wrapper -async def new_team( - data: NewTeamRequest, - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - litellm_changed_by: Optional[str] = Header( - None, - description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", - ), -): - """ - Allow users to create a new team. Apply user permissions to their team. - - 👉 [Detailed Doc on setting team budgets](https://docs.litellm.ai/docs/proxy/team_budgets) - - - Parameters: - - team_alias: Optional[str] - User defined team alias - - team_id: Optional[str] - The team id of the user. If none passed, we'll generate it. - - members_with_roles: List[{"role": "admin" or "user", "user_id": ""}] - A list of users and their roles in the team. Get user_id when making a new user via `/user/new`. - - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"extra_info": "some info"} - - tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for this team - all keys with this team_id will have at max this TPM limit - - rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for this team - all keys associated with this team_id will have at max this RPM limit - - max_budget: Optional[float] - The maximum budget allocated to the team - all keys for this team_id will have at max this max_budget - - budget_duration: Optional[str] - The duration of the budget for the team. Doc [here](https://docs.litellm.ai/docs/proxy/team_budgets) - - models: Optional[list] - A list of models associated with the team - all keys for this team_id will have at most, these models. If empty, assumes all models are allowed. - - blocked: bool - Flag indicating if the team is blocked or not - will stop all calls from keys with this team_id. - - Returns: - - team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id. - - _deprecated_params: - - admins: list - A list of user_id's for the admin role - - users: list - A list of user_id's for the user role - - Example Request: - ``` - curl --location 'http://0.0.0.0:4000/team/new' \ - - --header 'Authorization: Bearer sk-1234' \ - - --header 'Content-Type: application/json' \ - - --data '{ - "team_alias": "my-new-team_2", - "members_with_roles": [{"role": "admin", "user_id": "user-1234"}, - {"role": "user", "user_id": "user-2434"}] - }' - - ``` - - ``` - curl --location 'http://0.0.0.0:4000/team/new' \ - - --header 'Authorization: Bearer sk-1234' \ - - --header 'Content-Type: application/json' \ - - --data '{ - "team_alias": "QA Prod Bot", - "max_budget": 0.000000001, - "budget_duration": "1d" - }' - - ``` - """ - global prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - - if data.team_id is None: - data.team_id = str(uuid.uuid4()) - else: - # Check if team_id exists already - _existing_team_id = await prisma_client.get_data( - team_id=data.team_id, table_name="team", query_type="find_unique" - ) - if _existing_team_id is not None: - raise HTTPException( - status_code=400, - detail={ - "error": f"Team id = {data.team_id} already exists. Please use a different team id." - }, - ) - - if ( - user_api_key_dict.user_role is None - or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN - ): # don't restrict proxy admin - if ( - data.tpm_limit is not None - and user_api_key_dict.tpm_limit is not None - and data.tpm_limit > user_api_key_dict.tpm_limit - ): - raise HTTPException( - status_code=400, - detail={ - "error": f"tpm limit higher than user max. User tpm limit={user_api_key_dict.tpm_limit}. User role={user_api_key_dict.user_role}" - }, - ) - - if ( - data.rpm_limit is not None - and user_api_key_dict.rpm_limit is not None - and data.rpm_limit > user_api_key_dict.rpm_limit - ): - raise HTTPException( - status_code=400, - detail={ - "error": f"rpm limit higher than user max. User rpm limit={user_api_key_dict.rpm_limit}. User role={user_api_key_dict.user_role}" - }, - ) - - if ( - data.max_budget is not None - and user_api_key_dict.max_budget is not None - and data.max_budget > user_api_key_dict.max_budget - ): - raise HTTPException( - status_code=400, - detail={ - "error": f"max budget higher than user max. User max budget={user_api_key_dict.max_budget}. User role={user_api_key_dict.user_role}" - }, - ) - - if data.models is not None and len(user_api_key_dict.models) > 0: - for m in data.models: - if m not in user_api_key_dict.models: - raise HTTPException( - status_code=400, - detail={ - "error": f"Model not in allowed user models. User allowed models={user_api_key_dict.models}. User id={user_api_key_dict.user_id}" - }, - ) - - if user_api_key_dict.user_id is not None: - creating_user_in_list = False - for member in data.members_with_roles: - if member.user_id == user_api_key_dict.user_id: - creating_user_in_list = True - - if creating_user_in_list == False: - data.members_with_roles.append( - Member(role="admin", user_id=user_api_key_dict.user_id) - ) - - ## ADD TO MODEL TABLE - _model_id = None - if data.model_aliases is not None and isinstance(data.model_aliases, dict): - litellm_modeltable = LiteLLM_ModelTable( - model_aliases=json.dumps(data.model_aliases), - created_by=user_api_key_dict.user_id or litellm_proxy_admin_name, - updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name, - ) - model_dict = await prisma_client.db.litellm_modeltable.create( - {**litellm_modeltable.json(exclude_none=True)} # type: ignore - ) # type: ignore - - _model_id = model_dict.id - - ## ADD TO TEAM TABLE - complete_team_data = LiteLLM_TeamTable( - **data.json(), - model_id=_model_id, - ) - - # If budget_duration is set, set `budget_reset_at` - if complete_team_data.budget_duration is not None: - duration_s = _duration_in_seconds(duration=complete_team_data.budget_duration) - reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) - complete_team_data.budget_reset_at = reset_at - - team_row = await prisma_client.insert_data( - data=complete_team_data.json(exclude_none=True), table_name="team" - ) - - ## ADD TEAM ID TO USER TABLE ## - for user in complete_team_data.members_with_roles: - ## add team id to user row ## - await prisma_client.update_data( - user_id=user.user_id, - data={"user_id": user.user_id, "teams": [team_row.team_id]}, - update_key_values_custom_query={ - "teams": { - "push ": [team_row.team_id], - } - }, - ) - - # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True - if litellm.store_audit_logs is True: - _updated_values = complete_team_data.json(exclude_none=True) - - _updated_values = json.dumps(_updated_values, default=str) - - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.TEAM_TABLE_NAME, - object_id=data.team_id, - action="created", - updated_values=_updated_values, - before_value=None, - ) - ) - ) - - try: - return team_row.model_dump() - except Exception as e: - return team_row.dict() - - async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs): if premium_user is not True: return @@ -9077,593 +8842,6 @@ async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs): return -@router.post( - "/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)] -) -@management_endpoint_wrapper -async def update_team( - data: UpdateTeamRequest, - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - litellm_changed_by: Optional[str] = Header( - None, - description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", - ), -): - """ - Use `/team/member_add` AND `/team/member/delete` to add/remove new team members - - You can now update team budget / rate limits via /team/update - - Parameters: - - team_id: str - The team id of the user. Required param. - - team_alias: Optional[str] - User defined team alias - - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } - - tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for this team - all keys with this team_id will have at max this TPM limit - - rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for this team - all keys associated with this team_id will have at max this RPM limit - - max_budget: Optional[float] - The maximum budget allocated to the team - all keys for this team_id will have at max this max_budget - - budget_duration: Optional[str] - The duration of the budget for the team. Doc [here](https://docs.litellm.ai/docs/proxy/team_budgets) - - models: Optional[list] - A list of models associated with the team - all keys for this team_id will have at most, these models. If empty, assumes all models are allowed. - - blocked: bool - Flag indicating if the team is blocked or not - will stop all calls from keys with this team_id. - - Example - update team TPM Limit - - ``` - curl --location 'http://0.0.0.0:8000/team/update' \ - - --header 'Authorization: Bearer sk-1234' \ - - --header 'Content-Type: application/json' \ - - --data-raw '{ - "team_id": "litellm-test-client-id-new", - "tpm_limit": 100 - }' - ``` - - Example - Update Team `max_budget` budget - ``` - curl --location 'http://0.0.0.0:8000/team/update' \ - - --header 'Authorization: Bearer sk-1234' \ - - --header 'Content-Type: application/json' \ - - --data-raw '{ - "team_id": "litellm-test-client-id-new", - "max_budget": 10 - }' - ``` - """ - global prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - - if data.team_id is None: - raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) - verbose_proxy_logger.debug("/team/update - %s", data) - - existing_team_row = await prisma_client.get_data( - team_id=data.team_id, table_name="team", query_type="find_unique" - ) - if existing_team_row is None: - raise HTTPException( - status_code=404, - detail={"error": f"Team not found, passed team_id={data.team_id}"}, - ) - - updated_kv = data.json(exclude_none=True) - - # Check budget_duration and budget_reset_at - if data.budget_duration is not None: - duration_s = _duration_in_seconds(duration=data.budget_duration) - reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) - - # set the budget_reset_at in DB - updated_kv["budget_reset_at"] = reset_at - - team_row = await prisma_client.update_data( - update_key_values=updated_kv, - data=updated_kv, - table_name="team", - team_id=data.team_id, - ) - - # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True - if litellm.store_audit_logs is True: - _before_value = existing_team_row.json(exclude_none=True) - _before_value = json.dumps(_before_value, default=str) - _after_value: str = json.dumps(updated_kv, default=str) - - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.TEAM_TABLE_NAME, - object_id=data.team_id, - action="updated", - updated_values=_after_value, - before_value=_before_value, - ) - ) - ) - - return team_row - - -@router.post( - "/team/member_add", - tags=["team management"], - dependencies=[Depends(user_api_key_auth)], -) -@management_endpoint_wrapper -async def team_member_add( - data: TeamMemberAddRequest, - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - [BETA] - - Add new members (either via user_email or user_id) to a team - - If user doesn't exist, new user row will also be added to User Table - - ``` - - curl -X POST 'http://0.0.0.0:4000/team/member_add' \ - -H 'Authorization: Bearer sk-1234' \ - -H 'Content-Type: application/json' \ - -d '{"team_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849", "member": {"role": "user", "user_id": "krrish247652@berri.ai"}}' - - ``` - """ - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - - if data.team_id is None: - raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) - - if data.member is None: - raise HTTPException( - status_code=400, detail={"error": "No member/members passed in"} - ) - - existing_team_row = await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": data.team_id} - ) - if existing_team_row is None: - raise HTTPException( - status_code=404, - detail={ - "error": f"Team not found for team_id={getattr(data, 'team_id', None)}" - }, - ) - - complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump()) - - if isinstance(data.member, Member): - # add to team db - new_member = data.member - - complete_team_data.members_with_roles.append(new_member) - - elif isinstance(data.member, List): - # add to team db - new_members = data.member - - complete_team_data.members_with_roles.extend(new_members) - - # ADD MEMBER TO TEAM - _db_team_members = [m.model_dump() for m in complete_team_data.members_with_roles] - updated_team = await prisma_client.db.litellm_teamtable.update( - where={"team_id": data.team_id}, - data={"members_with_roles": json.dumps(_db_team_members)}, # type: ignore - ) - - if isinstance(data.member, Member): - await add_new_member( - new_member=data.member, - max_budget_in_team=data.max_budget_in_team, - prisma_client=prisma_client, - user_api_key_dict=user_api_key_dict, - litellm_proxy_admin_name=litellm_proxy_admin_name, - team_id=data.team_id, - ) - elif isinstance(data.member, List): - tasks: List = [] - for m in data.member: - await add_new_member( - new_member=m, - max_budget_in_team=data.max_budget_in_team, - prisma_client=prisma_client, - user_api_key_dict=user_api_key_dict, - litellm_proxy_admin_name=litellm_proxy_admin_name, - team_id=data.team_id, - ) - await asyncio.gather(*tasks) - - return updated_team - - -@router.post( - "/team/member_delete", - tags=["team management"], - dependencies=[Depends(user_api_key_auth)], -) -@management_endpoint_wrapper -async def team_member_delete( - data: TeamMemberDeleteRequest, - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - [BETA] - - delete members (either via user_email or user_id) from a team - - If user doesn't exist, an exception will be raised - ``` - curl -X POST 'http://0.0.0.0:8000/team/update' \ - - -H 'Authorization: Bearer sk-1234' \ - - -H 'Content-Type: application/json' \ - - -D '{ - "team_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849", - "user_id": "krrish247652@berri.ai" - }' - ``` - """ - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - - if data.team_id is None: - raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) - - if data.user_id is None and data.user_email is None: - raise HTTPException( - status_code=400, - detail={"error": "Either user_id or user_email needs to be passed in"}, - ) - - _existing_team_row = await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": data.team_id} - ) - - if _existing_team_row is None: - raise HTTPException( - status_code=400, - detail={"error": "Team id={} does not exist in db".format(data.team_id)}, - ) - existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump()) - - ## DELETE MEMBER FROM TEAM - new_team_members: List[Member] = [] - for m in existing_team_row.members_with_roles: - if ( - data.user_id is not None - and m.user_id is not None - and data.user_id == m.user_id - ): - continue - elif ( - data.user_email is not None - and m.user_email is not None - and data.user_email == m.user_email - ): - continue - new_team_members.append(m) - existing_team_row.members_with_roles = new_team_members - - _db_new_team_members: List[dict] = [m.model_dump() for m in new_team_members] - - _ = await prisma_client.db.litellm_teamtable.update( - where={ - "team_id": data.team_id, - }, - data={"members_with_roles": json.dumps(_db_new_team_members)}, # type: ignore - ) - - ## DELETE TEAM ID from USER ROW, IF EXISTS ## - # get user row - key_val = {} - if data.user_id is not None: - key_val["user_id"] = data.user_id - elif data.user_email is not None: - key_val["user_email"] = data.user_email - existing_user_rows = await prisma_client.db.litellm_usertable.find_many( - where=key_val # type: ignore - ) - - if existing_user_rows is not None and ( - isinstance(existing_user_rows, list) and len(existing_user_rows) > 0 - ): - for existing_user in existing_user_rows: - team_list = [] - if data.team_id in existing_user.teams: - team_list = existing_user.teams - team_list.remove(data.team_id) - await prisma_client.db.litellm_usertable.update( - where={ - "user_id": existing_user.user_id, - }, - data={"teams": {"set": team_list}}, - ) - - return existing_team_row - - -@router.post( - "/team/delete", tags=["team management"], dependencies=[Depends(user_api_key_auth)] -) -@management_endpoint_wrapper -async def delete_team( - data: DeleteTeamRequest, - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - litellm_changed_by: Optional[str] = Header( - None, - description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", - ), -): - """ - delete team and associated team keys - - ``` - curl --location 'http://0.0.0.0:8000/team/delete' \ - - --header 'Authorization: Bearer sk-1234' \ - - --header 'Content-Type: application/json' \ - - --data-raw '{ - "team_ids": ["45e3e396-ee08-4a61-a88e-16b3ce7e0849"] - }' - ``` - """ - global prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - - if data.team_ids is None: - raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) - - # check that all teams passed exist - for team_id in data.team_ids: - team_row = await prisma_client.get_data( # type: ignore - team_id=team_id, table_name="team", query_type="find_unique" - ) - if team_row is None: - raise HTTPException( - status_code=404, - detail={"error": f"Team not found, passed team_id={team_id}"}, - ) - - # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True - # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes - if litellm.store_audit_logs is True: - # make an audit log for each team deleted - for team_id in data.team_ids: - team_row = await prisma_client.get_data( # type: ignore - team_id=team_id, table_name="team", query_type="find_unique" - ) - - _team_row = team_row.json(exclude_none=True) - - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.TEAM_TABLE_NAME, - object_id=team_id, - action="deleted", - updated_values="{}", - before_value=_team_row, - ) - ) - ) - - # End of Audit logging - - ## DELETE ASSOCIATED KEYS - await prisma_client.delete_data(team_id_list=data.team_ids, table_name="key") - ## DELETE TEAMS - deleted_teams = await prisma_client.delete_data( - team_id_list=data.team_ids, table_name="team" - ) - return deleted_teams - - -@router.get( - "/team/info", tags=["team management"], dependencies=[Depends(user_api_key_auth)] -) -@management_endpoint_wrapper -async def team_info( - http_request: Request, - team_id: str = fastapi.Query( - default=None, description="Team ID in the request parameters" - ), -): - """ - get info on team + related keys - - ``` - curl --location 'http://localhost:4000/team/info' \ - --header 'Authorization: Bearer sk-1234' \ - --header 'Content-Type: application/json' \ - --data '{ - "teams": ["",..] - }' - ``` - """ - global prisma_client - try: - if prisma_client is None: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail={ - "error": f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - }, - ) - if team_id is None: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail={"message": "Malformed request. No team id passed in."}, - ) - - team_info = await prisma_client.get_data( - team_id=team_id, table_name="team", query_type="find_unique" - ) - if team_info is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail={"message": f"Team not found, passed team id: {team_id}."}, - ) - - ## GET ALL KEYS ## - keys = await prisma_client.get_data( - team_id=team_id, - table_name="key", - query_type="find_all", - expires=datetime.now(), - ) - - if team_info is None: - ## make sure we still return a total spend ## - spend = 0 - for k in keys: - spend += getattr(k, "spend", 0) - team_info = {"spend": spend} - - ## REMOVE HASHED TOKEN INFO before returning ## - for key in keys: - try: - key = key.model_dump() # noqa - except: - # if using pydantic v1 - key = key.dict() - key.pop("token", None) - return {"team_id": team_id, "team_info": team_info, "keys": keys} - - except Exception as e: - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "detail", f"Authentication Error({str(e)})"), - type="auth_error", - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), - ) - elif isinstance(e, ProxyException): - raise e - raise ProxyException( - message="Authentication Error, " + str(e), - type="auth_error", - param=getattr(e, "param", "None"), - code=status.HTTP_400_BAD_REQUEST, - ) - - -@router.post( - "/team/block", tags=["team management"], dependencies=[Depends(user_api_key_auth)] -) -@management_endpoint_wrapper -async def block_team( - data: BlockTeamRequest, - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Blocks all calls from keys with this team id. - """ - global prisma_client - - if prisma_client is None: - raise Exception("No DB Connected.") - - record = await prisma_client.db.litellm_teamtable.update( - where={"team_id": data.team_id}, data={"blocked": True} # type: ignore - ) - - return record - - -@router.post( - "/team/unblock", tags=["team management"], dependencies=[Depends(user_api_key_auth)] -) -@management_endpoint_wrapper -async def unblock_team( - data: BlockTeamRequest, - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Blocks all calls from keys with this team id. - """ - global prisma_client - - if prisma_client is None: - raise Exception("No DB Connected.") - - record = await prisma_client.db.litellm_teamtable.update( - where={"team_id": data.team_id}, data={"blocked": False} # type: ignore - ) - - return record - - -@router.get( - "/team/list", tags=["team management"], dependencies=[Depends(user_api_key_auth)] -) -@management_endpoint_wrapper -async def list_team( - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - [Admin-only] List all available teams - - ``` - curl --location --request GET 'http://0.0.0.0:4000/team/list' \ - --header 'Authorization: Bearer sk-1234' - ``` - """ - global prisma_client - - if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: - raise HTTPException( - status_code=401, - detail={ - "error": "Admin-only endpoint. Your user role={}".format( - user_api_key_dict.user_role - ) - }, - ) - - if prisma_client is None: - raise HTTPException( - status_code=400, - detail={"error": CommonProxyErrors.db_not_connected_error.value}, - ) - - response = await prisma_client.db.litellm_teamtable.find_many() - - return response - - #### ORGANIZATION MANAGEMENT #### @@ -13492,189 +12670,6 @@ async def health_liveliness(): return "I'm alive!" -@router.get( - "/cache/ping", - tags=["caching"], - dependencies=[Depends(user_api_key_auth)], -) -async def cache_ping(): - """ - Endpoint for checking if cache can be pinged - """ - try: - litellm_cache_params = {} - specific_cache_params = {} - - if litellm.cache is None: - raise HTTPException( - status_code=503, detail="Cache not initialized. litellm.cache is None" - ) - - for k, v in vars(litellm.cache).items(): - try: - if k == "cache": - continue - litellm_cache_params[k] = str(copy.deepcopy(v)) - except Exception: - litellm_cache_params[k] = "" - for k, v in vars(litellm.cache.cache).items(): - try: - specific_cache_params[k] = str(v) - except Exception: - specific_cache_params[k] = "" - if litellm.cache.type == "redis": - # ping the redis cache - ping_response = await litellm.cache.ping() - verbose_proxy_logger.debug( - "/cache/ping: ping_response: " + str(ping_response) - ) - # making a set cache call - # add cache does not return anything - await litellm.cache.async_add_cache( - result="test_key", - model="test-model", - messages=[{"role": "user", "content": "test from litellm"}], - ) - verbose_proxy_logger.debug("/cache/ping: done with set_cache()") - return { - "status": "healthy", - "cache_type": litellm.cache.type, - "ping_response": True, - "set_cache_response": "success", - "litellm_cache_params": litellm_cache_params, - "redis_cache_params": specific_cache_params, - } - else: - return { - "status": "healthy", - "cache_type": litellm.cache.type, - "litellm_cache_params": litellm_cache_params, - } - except Exception as e: - raise HTTPException( - status_code=503, - detail=f"Service Unhealthy ({str(e)}).Cache parameters: {litellm_cache_params}.specific_cache_params: {specific_cache_params}", - ) - - -@router.post( - "/cache/delete", - tags=["caching"], - dependencies=[Depends(user_api_key_auth)], -) -async def cache_delete(request: Request): - """ - Endpoint for deleting a key from the cache. All responses from litellm proxy have `x-litellm-cache-key` in the headers - - Parameters: - - **keys**: *Optional[List[str]]* - A list of keys to delete from the cache. Example {"keys": ["key1", "key2"]} - - ```shell - curl -X POST "http://0.0.0.0:4000/cache/delete" \ - -H "Authorization: Bearer sk-1234" \ - -d '{"keys": ["key1", "key2"]}' - ``` - - """ - try: - if litellm.cache is None: - raise HTTPException( - status_code=503, detail="Cache not initialized. litellm.cache is None" - ) - - request_data = await request.json() - keys = request_data.get("keys", None) - - if litellm.cache.type == "redis": - await litellm.cache.delete_cache_keys(keys=keys) - return { - "status": "success", - } - else: - raise HTTPException( - status_code=500, - detail=f"Cache type {litellm.cache.type} does not support deleting a key. only `redis` is supported", - ) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Cache Delete Failed({str(e)})", - ) - - -@router.get( - "/cache/redis/info", - tags=["caching"], - dependencies=[Depends(user_api_key_auth)], -) -async def cache_redis_info(): - """ - Endpoint for getting /redis/info - """ - try: - if litellm.cache is None: - raise HTTPException( - status_code=503, detail="Cache not initialized. litellm.cache is None" - ) - if litellm.cache.type == "redis": - client_list = litellm.cache.cache.client_list() - redis_info = litellm.cache.cache.info() - num_clients = len(client_list) - return { - "num_clients": num_clients, - "clients": client_list, - "info": redis_info, - } - else: - raise HTTPException( - status_code=500, - detail=f"Cache type {litellm.cache.type} does not support flushing", - ) - except Exception as e: - raise HTTPException( - status_code=503, - detail=f"Service Unhealthy ({str(e)})", - ) - - -@router.post( - "/cache/flushall", - tags=["caching"], - dependencies=[Depends(user_api_key_auth)], -) -async def cache_flushall(): - """ - A function to flush all items from the cache. (All items will be deleted from the cache with this) - Raises HTTPException if the cache is not initialized or if the cache type does not support flushing. - Returns a dictionary with the status of the operation. - - Usage: - ``` - curl -X POST http://0.0.0.0:4000/cache/flushall -H "Authorization: Bearer sk-1234" - ``` - """ - try: - if litellm.cache is None: - raise HTTPException( - status_code=503, detail="Cache not initialized. litellm.cache is None" - ) - if litellm.cache.type == "redis": - litellm.cache.cache.flushall() - return { - "status": "success", - } - else: - raise HTTPException( - status_code=500, - detail=f"Cache type {litellm.cache.type} does not support flushing", - ) - except Exception as e: - raise HTTPException( - status_code=503, - detail=f"Service Unhealthy ({str(e)})", - ) - - @router.get( "/get/litellm_model_cost_map", include_in_schema=False, @@ -13787,3 +12782,4 @@ def cleanup_router_config_variables(): app.include_router(router) +app.include_router(caching_router) diff --git a/litellm/proxy/team_endpoints.py b/litellm/proxy/team_endpoints.py new file mode 100644 index 000000000..dcc8e9c44 --- /dev/null +++ b/litellm/proxy/team_endpoints.py @@ -0,0 +1,902 @@ +from typing import Optional, List +import fastapi +from fastapi import Depends, Request, APIRouter, Header, status +from fastapi import HTTPException +import copy +import json +import uuid +import litellm +import asyncio +from datetime import datetime, timedelta, timezone +from litellm._logging import verbose_proxy_logger +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy._types import ( + UserAPIKeyAuth, + LiteLLM_TeamTable, + LiteLLM_ModelTable, + LitellmUserRoles, + NewTeamRequest, + TeamMemberAddRequest, + UpdateTeamRequest, + BlockTeamRequest, + DeleteTeamRequest, + Member, + LitellmTableNames, + LiteLLM_AuditLogs, + TeamMemberDeleteRequest, + ProxyException, + CommonProxyErrors, +) +from litellm.proxy.management_helpers.utils import ( + add_new_member, + management_endpoint_wrapper, +) + +router = APIRouter( + prefix="/team", + tags=["team management"], +) + + +#### TEAM MANAGEMENT #### +@router.post( + "/team/new", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], + response_model=LiteLLM_TeamTable, +) +@management_endpoint_wrapper +async def new_team( + data: NewTeamRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): + """ + Allow users to create a new team. Apply user permissions to their team. + + 👉 [Detailed Doc on setting team budgets](https://docs.litellm.ai/docs/proxy/team_budgets) + + + Parameters: + - team_alias: Optional[str] - User defined team alias + - team_id: Optional[str] - The team id of the user. If none passed, we'll generate it. + - members_with_roles: List[{"role": "admin" or "user", "user_id": ""}] - A list of users and their roles in the team. Get user_id when making a new user via `/user/new`. + - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"extra_info": "some info"} + - tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for this team - all keys with this team_id will have at max this TPM limit + - rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for this team - all keys associated with this team_id will have at max this RPM limit + - max_budget: Optional[float] - The maximum budget allocated to the team - all keys for this team_id will have at max this max_budget + - budget_duration: Optional[str] - The duration of the budget for the team. Doc [here](https://docs.litellm.ai/docs/proxy/team_budgets) + - models: Optional[list] - A list of models associated with the team - all keys for this team_id will have at most, these models. If empty, assumes all models are allowed. + - blocked: bool - Flag indicating if the team is blocked or not - will stop all calls from keys with this team_id. + + Returns: + - team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id. + + _deprecated_params: + - admins: list - A list of user_id's for the admin role + - users: list - A list of user_id's for the user role + + Example Request: + ``` + curl --location 'http://0.0.0.0:4000/team/new' \ + + --header 'Authorization: Bearer sk-1234' \ + + --header 'Content-Type: application/json' \ + + --data '{ + "team_alias": "my-new-team_2", + "members_with_roles": [{"role": "admin", "user_id": "user-1234"}, + {"role": "user", "user_id": "user-2434"}] + }' + + ``` + + ``` + curl --location 'http://0.0.0.0:4000/team/new' \ + + --header 'Authorization: Bearer sk-1234' \ + + --header 'Content-Type: application/json' \ + + --data '{ + "team_alias": "QA Prod Bot", + "max_budget": 0.000000001, + "budget_duration": "1d" + }' + + ``` + """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_id is None: + data.team_id = str(uuid.uuid4()) + else: + # Check if team_id exists already + _existing_team_id = await prisma_client.get_data( + team_id=data.team_id, table_name="team", query_type="find_unique" + ) + if _existing_team_id is not None: + raise HTTPException( + status_code=400, + detail={ + "error": f"Team id = {data.team_id} already exists. Please use a different team id." + }, + ) + + if ( + user_api_key_dict.user_role is None + or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN + ): # don't restrict proxy admin + if ( + data.tpm_limit is not None + and user_api_key_dict.tpm_limit is not None + and data.tpm_limit > user_api_key_dict.tpm_limit + ): + raise HTTPException( + status_code=400, + detail={ + "error": f"tpm limit higher than user max. User tpm limit={user_api_key_dict.tpm_limit}. User role={user_api_key_dict.user_role}" + }, + ) + + if ( + data.rpm_limit is not None + and user_api_key_dict.rpm_limit is not None + and data.rpm_limit > user_api_key_dict.rpm_limit + ): + raise HTTPException( + status_code=400, + detail={ + "error": f"rpm limit higher than user max. User rpm limit={user_api_key_dict.rpm_limit}. User role={user_api_key_dict.user_role}" + }, + ) + + if ( + data.max_budget is not None + and user_api_key_dict.max_budget is not None + and data.max_budget > user_api_key_dict.max_budget + ): + raise HTTPException( + status_code=400, + detail={ + "error": f"max budget higher than user max. User max budget={user_api_key_dict.max_budget}. User role={user_api_key_dict.user_role}" + }, + ) + + if data.models is not None and len(user_api_key_dict.models) > 0: + for m in data.models: + if m not in user_api_key_dict.models: + raise HTTPException( + status_code=400, + detail={ + "error": f"Model not in allowed user models. User allowed models={user_api_key_dict.models}. User id={user_api_key_dict.user_id}" + }, + ) + + if user_api_key_dict.user_id is not None: + creating_user_in_list = False + for member in data.members_with_roles: + if member.user_id == user_api_key_dict.user_id: + creating_user_in_list = True + + if creating_user_in_list == False: + data.members_with_roles.append( + Member(role="admin", user_id=user_api_key_dict.user_id) + ) + + ## ADD TO MODEL TABLE + _model_id = None + if data.model_aliases is not None and isinstance(data.model_aliases, dict): + litellm_modeltable = LiteLLM_ModelTable( + model_aliases=json.dumps(data.model_aliases), + created_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + ) + model_dict = await prisma_client.db.litellm_modeltable.create( + {**litellm_modeltable.json(exclude_none=True)} # type: ignore + ) # type: ignore + + _model_id = model_dict.id + + ## ADD TO TEAM TABLE + complete_team_data = LiteLLM_TeamTable( + **data.json(), + model_id=_model_id, + ) + + # If budget_duration is set, set `budget_reset_at` + if complete_team_data.budget_duration is not None: + duration_s = _duration_in_seconds(duration=complete_team_data.budget_duration) + reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + complete_team_data.budget_reset_at = reset_at + + team_row = await prisma_client.insert_data( + data=complete_team_data.json(exclude_none=True), table_name="team" + ) + + ## ADD TEAM ID TO USER TABLE ## + for user in complete_team_data.members_with_roles: + ## add team id to user row ## + await prisma_client.update_data( + user_id=user.user_id, + data={"user_id": user.user_id, "teams": [team_row.team_id]}, + update_key_values_custom_query={ + "teams": { + "push ": [team_row.team_id], + } + }, + ) + + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + if litellm.store_audit_logs is True: + _updated_values = complete_team_data.json(exclude_none=True) + + _updated_values = json.dumps(_updated_values, default=str) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.TEAM_TABLE_NAME, + object_id=data.team_id, + action="created", + updated_values=_updated_values, + before_value=None, + ) + ) + ) + + try: + return team_row.model_dump() + except Exception as e: + return team_row.dict() + + +@router.post( + "/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +@management_endpoint_wrapper +async def update_team( + data: UpdateTeamRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): + """ + Use `/team/member_add` AND `/team/member/delete` to add/remove new team members + + You can now update team budget / rate limits via /team/update + + Parameters: + - team_id: str - The team id of the user. Required param. + - team_alias: Optional[str] - User defined team alias + - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } + - tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for this team - all keys with this team_id will have at max this TPM limit + - rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for this team - all keys associated with this team_id will have at max this RPM limit + - max_budget: Optional[float] - The maximum budget allocated to the team - all keys for this team_id will have at max this max_budget + - budget_duration: Optional[str] - The duration of the budget for the team. Doc [here](https://docs.litellm.ai/docs/proxy/team_budgets) + - models: Optional[list] - A list of models associated with the team - all keys for this team_id will have at most, these models. If empty, assumes all models are allowed. + - blocked: bool - Flag indicating if the team is blocked or not - will stop all calls from keys with this team_id. + + Example - update team TPM Limit + + ``` + curl --location 'http://0.0.0.0:8000/team/update' \ + + --header 'Authorization: Bearer sk-1234' \ + + --header 'Content-Type: application/json' \ + + --data-raw '{ + "team_id": "litellm-test-client-id-new", + "tpm_limit": 100 + }' + ``` + + Example - Update Team `max_budget` budget + ``` + curl --location 'http://0.0.0.0:8000/team/update' \ + + --header 'Authorization: Bearer sk-1234' \ + + --header 'Content-Type: application/json' \ + + --data-raw '{ + "team_id": "litellm-test-client-id-new", + "max_budget": 10 + }' + ``` + """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_id is None: + raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) + verbose_proxy_logger.debug("/team/update - %s", data) + + existing_team_row = await prisma_client.get_data( + team_id=data.team_id, table_name="team", query_type="find_unique" + ) + if existing_team_row is None: + raise HTTPException( + status_code=404, + detail={"error": f"Team not found, passed team_id={data.team_id}"}, + ) + + updated_kv = data.json(exclude_none=True) + + # Check budget_duration and budget_reset_at + if data.budget_duration is not None: + duration_s = _duration_in_seconds(duration=data.budget_duration) + reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + + # set the budget_reset_at in DB + updated_kv["budget_reset_at"] = reset_at + + team_row = await prisma_client.update_data( + update_key_values=updated_kv, + data=updated_kv, + table_name="team", + team_id=data.team_id, + ) + + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + if litellm.store_audit_logs is True: + _before_value = existing_team_row.json(exclude_none=True) + _before_value = json.dumps(_before_value, default=str) + _after_value: str = json.dumps(updated_kv, default=str) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.TEAM_TABLE_NAME, + object_id=data.team_id, + action="updated", + updated_values=_after_value, + before_value=_before_value, + ) + ) + ) + + return team_row + + +@router.post( + "/team/member_add", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], +) +@management_endpoint_wrapper +async def team_member_add( + data: TeamMemberAddRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + [BETA] + + Add new members (either via user_email or user_id) to a team + + If user doesn't exist, new user row will also be added to User Table + + ``` + + curl -X POST 'http://0.0.0.0:4000/team/member_add' \ + -H 'Authorization: Bearer sk-1234' \ + -H 'Content-Type: application/json' \ + -d '{"team_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849", "member": {"role": "user", "user_id": "krrish247652@berri.ai"}}' + + ``` + """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_id is None: + raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) + + if data.member is None: + raise HTTPException( + status_code=400, detail={"error": "No member/members passed in"} + ) + + existing_team_row = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": data.team_id} + ) + if existing_team_row is None: + raise HTTPException( + status_code=404, + detail={ + "error": f"Team not found for team_id={getattr(data, 'team_id', None)}" + }, + ) + + complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump()) + + if isinstance(data.member, Member): + # add to team db + new_member = data.member + + complete_team_data.members_with_roles.append(new_member) + + elif isinstance(data.member, List): + # add to team db + new_members = data.member + + complete_team_data.members_with_roles.extend(new_members) + + # ADD MEMBER TO TEAM + _db_team_members = [m.model_dump() for m in complete_team_data.members_with_roles] + updated_team = await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, + data={"members_with_roles": json.dumps(_db_team_members)}, # type: ignore + ) + + if isinstance(data.member, Member): + await add_new_member( + new_member=data.member, + max_budget_in_team=data.max_budget_in_team, + prisma_client=prisma_client, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + team_id=data.team_id, + ) + elif isinstance(data.member, List): + tasks: List = [] + for m in data.member: + await add_new_member( + new_member=m, + max_budget_in_team=data.max_budget_in_team, + prisma_client=prisma_client, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + team_id=data.team_id, + ) + await asyncio.gather(*tasks) + + return updated_team + + +@router.post( + "/team/member_delete", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], +) +@management_endpoint_wrapper +async def team_member_delete( + data: TeamMemberDeleteRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + [BETA] + + delete members (either via user_email or user_id) from a team + + If user doesn't exist, an exception will be raised + ``` + curl -X POST 'http://0.0.0.0:8000/team/update' \ + + -H 'Authorization: Bearer sk-1234' \ + + -H 'Content-Type: application/json' \ + + -D '{ + "team_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849", + "user_id": "krrish247652@berri.ai" + }' + ``` + """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_id is None: + raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) + + if data.user_id is None and data.user_email is None: + raise HTTPException( + status_code=400, + detail={"error": "Either user_id or user_email needs to be passed in"}, + ) + + _existing_team_row = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": data.team_id} + ) + + if _existing_team_row is None: + raise HTTPException( + status_code=400, + detail={"error": "Team id={} does not exist in db".format(data.team_id)}, + ) + existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump()) + + ## DELETE MEMBER FROM TEAM + new_team_members: List[Member] = [] + for m in existing_team_row.members_with_roles: + if ( + data.user_id is not None + and m.user_id is not None + and data.user_id == m.user_id + ): + continue + elif ( + data.user_email is not None + and m.user_email is not None + and data.user_email == m.user_email + ): + continue + new_team_members.append(m) + existing_team_row.members_with_roles = new_team_members + + _db_new_team_members: List[dict] = [m.model_dump() for m in new_team_members] + + _ = await prisma_client.db.litellm_teamtable.update( + where={ + "team_id": data.team_id, + }, + data={"members_with_roles": json.dumps(_db_new_team_members)}, # type: ignore + ) + + ## DELETE TEAM ID from USER ROW, IF EXISTS ## + # get user row + key_val = {} + if data.user_id is not None: + key_val["user_id"] = data.user_id + elif data.user_email is not None: + key_val["user_email"] = data.user_email + existing_user_rows = await prisma_client.db.litellm_usertable.find_many( + where=key_val # type: ignore + ) + + if existing_user_rows is not None and ( + isinstance(existing_user_rows, list) and len(existing_user_rows) > 0 + ): + for existing_user in existing_user_rows: + team_list = [] + if data.team_id in existing_user.teams: + team_list = existing_user.teams + team_list.remove(data.team_id) + await prisma_client.db.litellm_usertable.update( + where={ + "user_id": existing_user.user_id, + }, + data={"teams": {"set": team_list}}, + ) + + return existing_team_row + + +@router.post( + "/team/delete", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +@management_endpoint_wrapper +async def delete_team( + data: DeleteTeamRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): + """ + delete team and associated team keys + + ``` + curl --location 'http://0.0.0.0:8000/team/delete' \ + + --header 'Authorization: Bearer sk-1234' \ + + --header 'Content-Type: application/json' \ + + --data-raw '{ + "team_ids": ["45e3e396-ee08-4a61-a88e-16b3ce7e0849"] + }' + ``` + """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_ids is None: + raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) + + # check that all teams passed exist + for team_id in data.team_ids: + team_row = await prisma_client.get_data( # type: ignore + team_id=team_id, table_name="team", query_type="find_unique" + ) + if team_row is None: + raise HTTPException( + status_code=404, + detail={"error": f"Team not found, passed team_id={team_id}"}, + ) + + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes + if litellm.store_audit_logs is True: + # make an audit log for each team deleted + for team_id in data.team_ids: + team_row = await prisma_client.get_data( # type: ignore + team_id=team_id, table_name="team", query_type="find_unique" + ) + + _team_row = team_row.json(exclude_none=True) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.TEAM_TABLE_NAME, + object_id=team_id, + action="deleted", + updated_values="{}", + before_value=_team_row, + ) + ) + ) + + # End of Audit logging + + ## DELETE ASSOCIATED KEYS + await prisma_client.delete_data(team_id_list=data.team_ids, table_name="key") + ## DELETE TEAMS + deleted_teams = await prisma_client.delete_data( + team_id_list=data.team_ids, table_name="team" + ) + return deleted_teams + + +@router.get( + "/team/info", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +@management_endpoint_wrapper +async def team_info( + http_request: Request, + team_id: str = fastapi.Query( + default=None, description="Team ID in the request parameters" + ), +): + """ + get info on team + related keys + + ``` + curl --location 'http://localhost:4000/team/info' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "teams": ["",..] + }' + ``` + """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + try: + if prisma_client is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "error": f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + }, + ) + if team_id is None: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail={"message": "Malformed request. No team id passed in."}, + ) + + team_info = await prisma_client.get_data( + team_id=team_id, table_name="team", query_type="find_unique" + ) + if team_info is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"message": f"Team not found, passed team id: {team_id}."}, + ) + + ## GET ALL KEYS ## + keys = await prisma_client.get_data( + team_id=team_id, + table_name="key", + query_type="find_all", + expires=datetime.now(), + ) + + if team_info is None: + ## make sure we still return a total spend ## + spend = 0 + for k in keys: + spend += getattr(k, "spend", 0) + team_info = {"spend": spend} + + ## REMOVE HASHED TOKEN INFO before returning ## + for key in keys: + try: + key = key.model_dump() # noqa + except: + # if using pydantic v1 + key = key.dict() + key.pop("token", None) + return {"team_id": team_id, "team_info": team_info, "keys": keys} + + except Exception as e: + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, + ) + + +@router.post( + "/team/block", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +@management_endpoint_wrapper +async def block_team( + data: BlockTeamRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Blocks all calls from keys with this team id. + """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + if prisma_client is None: + raise Exception("No DB Connected.") + + record = await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, data={"blocked": True} # type: ignore + ) + + return record + + +@router.post( + "/team/unblock", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +@management_endpoint_wrapper +async def unblock_team( + data: BlockTeamRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Blocks all calls from keys with this team id. + """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + if prisma_client is None: + raise Exception("No DB Connected.") + + record = await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, data={"blocked": False} # type: ignore + ) + + return record + + +@router.get( + "/team/list", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +@management_endpoint_wrapper +async def list_team( + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + [Admin-only] List all available teams + + ``` + curl --location --request GET 'http://0.0.0.0:4000/team/list' \ + --header 'Authorization: Bearer sk-1234' + ``` + """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: + raise HTTPException( + status_code=401, + detail={ + "error": "Admin-only endpoint. Your user role={}".format( + user_api_key_dict.user_role + ) + }, + ) + + if prisma_client is None: + raise HTTPException( + status_code=400, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + + response = await prisma_client.db.litellm_teamtable.find_many() + + return response diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index f53e241ce..cf2e62e8d 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -50,10 +50,7 @@ from litellm.proxy.proxy_server import ( spend_key_fn, view_spend_logs, user_info, - team_info, info_key_fn, - new_team, - update_team, chat_completion, completion, embeddings, @@ -63,6 +60,8 @@ from litellm.proxy.proxy_server import ( model_list, LitellmUserRoles, ) + +from litellm.proxy.team_endpoints import team_info, new_team, update_team from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend from litellm._logging import verbose_proxy_logger