diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 94ff51fd2a..2702da79fe 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -691,20 +691,24 @@ def _enforced_params_check( return True -def move_guardrails_to_metadata( +def _add_guardrails_from_key_or_team_metadata( + key_metadata: Optional[dict], + team_metadata: Optional[dict], data: dict, - _metadata_variable_name: str, - user_api_key_dict: UserAPIKeyAuth, -): + metadata_variable_name: str, +) -> None: """ - Heper to add guardrails from request to metadata + Helper to add guardrails from key or team metadata to request data - - If guardrails set on API Key metadata then sets guardrails on request metadata - - If guardrails not set on API key, then checks request metadata + Args: + key_metadata: The key metadata dictionary to check for guardrails + team_metadata: The team metadata dictionary to check for guardrails + data: The request data to update + metadata_variable_name: The name of the metadata field in data """ - if user_api_key_dict.metadata: - if "guardrails" in user_api_key_dict.metadata: + for _management_object_metadata in [key_metadata, team_metadata]: + if _management_object_metadata and "guardrails" in _management_object_metadata: from litellm.proxy.proxy_server import premium_user if premium_user is not True: @@ -712,11 +716,31 @@ def move_guardrails_to_metadata( f"Using Guardrails on API Key {CommonProxyErrors.not_premium_user}" ) - data[_metadata_variable_name]["guardrails"] = user_api_key_dict.metadata[ + data[metadata_variable_name]["guardrails"] = _management_object_metadata[ "guardrails" ] - return + +def move_guardrails_to_metadata( + data: dict, + _metadata_variable_name: str, + user_api_key_dict: UserAPIKeyAuth, +): + """ + Helper to add guardrails from request to metadata + + - If guardrails set on API Key metadata then sets guardrails on request metadata + - If guardrails not set on API key, 
then checks request metadata + """ + # Check key-level guardrails + _add_guardrails_from_key_or_team_metadata( + key_metadata=user_api_key_dict.metadata, + team_metadata=user_api_key_dict.team_metadata, + data=data, + metadata_variable_name=_metadata_variable_name, + ) + + # Check request-level guardrails if "guardrails" in data: data[_metadata_variable_name]["guardrails"] = data["guardrails"] del data["guardrails"] diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index c20bbcb959..f7ffd8fecc 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -58,7 +58,8 @@ from litellm.proxy.management_helpers.utils import ( add_new_member, management_endpoint_wrapper, ) -from litellm.proxy.utils import PrismaClient, _premium_user_check +from litellm.proxy.utils import PrismaClient, handle_exception_on_proxy, _premium_user_check + router = APIRouter() @@ -175,177 +176,182 @@ async def new_team( # noqa: PLR0915 }' ``` """ - from litellm.proxy.proxy_server import ( - create_audit_log_for_update, - duration_in_seconds, - litellm_proxy_admin_name, - prisma_client, - ) - - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - - if data.team_id is None: - data.team_id = str(uuid.uuid4()) - else: - # Check if team_id exists already - _existing_team_id = await prisma_client.get_data( - team_id=data.team_id, table_name="team", query_type="find_unique" + try: + from litellm.proxy.proxy_server import ( + create_audit_log_for_update, + duration_in_seconds, + litellm_proxy_admin_name, + prisma_client, ) - if _existing_team_id is not None: - raise HTTPException( - status_code=400, - detail={ - "error": f"Team id = {data.team_id} already exists. Please use a different team id." 
- }, - ) - if ( - user_api_key_dict.user_role is None - or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN - ): # don't restrict proxy admin - if ( - data.tpm_limit is not None - and user_api_key_dict.tpm_limit is not None - and data.tpm_limit > user_api_key_dict.tpm_limit - ): - raise HTTPException( - status_code=400, - detail={ - "error": f"tpm limit higher than user max. User tpm limit={user_api_key_dict.tpm_limit}. User role={user_api_key_dict.user_role}" - }, + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_id is None: + data.team_id = str(uuid.uuid4()) + else: + # Check if team_id exists already + _existing_team_id = await prisma_client.get_data( + team_id=data.team_id, table_name="team", query_type="find_unique" ) + if _existing_team_id is not None: + raise HTTPException( + status_code=400, + detail={ + "error": f"Team id = {data.team_id} already exists. Please use a different team id." + }, + ) if ( - data.rpm_limit is not None - and user_api_key_dict.rpm_limit is not None - and data.rpm_limit > user_api_key_dict.rpm_limit - ): - raise HTTPException( - status_code=400, - detail={ - "error": f"rpm limit higher than user max. User rpm limit={user_api_key_dict.rpm_limit}. User role={user_api_key_dict.user_role}" + user_api_key_dict.user_role is None + or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN + ): # don't restrict proxy admin + if ( + data.tpm_limit is not None + and user_api_key_dict.tpm_limit is not None + and data.tpm_limit > user_api_key_dict.tpm_limit + ): + raise HTTPException( + status_code=400, + detail={ + "error": f"tpm limit higher than user max. User tpm limit={user_api_key_dict.tpm_limit}. 
User role={user_api_key_dict.user_role}" + }, + ) + + if ( + data.rpm_limit is not None + and user_api_key_dict.rpm_limit is not None + and data.rpm_limit > user_api_key_dict.rpm_limit + ): + raise HTTPException( + status_code=400, + detail={ + "error": f"rpm limit higher than user max. User rpm limit={user_api_key_dict.rpm_limit}. User role={user_api_key_dict.user_role}" + }, + ) + + if ( + data.max_budget is not None + and user_api_key_dict.max_budget is not None + and data.max_budget > user_api_key_dict.max_budget + ): + raise HTTPException( + status_code=400, + detail={ + "error": f"max budget higher than user max. User max budget={user_api_key_dict.max_budget}. User role={user_api_key_dict.user_role}" + }, + ) + + if data.models is not None and len(user_api_key_dict.models) > 0: + for m in data.models: + if m not in user_api_key_dict.models: + raise HTTPException( + status_code=400, + detail={ + "error": f"Model not in allowed user models. User allowed models={user_api_key_dict.models}. 
User id={user_api_key_dict.user_id}" + }, + ) + + if user_api_key_dict.user_id is not None: + creating_user_in_list = False + for member in data.members_with_roles: + if member.user_id == user_api_key_dict.user_id: + creating_user_in_list = True + + if creating_user_in_list is False: + data.members_with_roles.append( + Member(role="admin", user_id=user_api_key_dict.user_id) + ) + + ## ADD TO MODEL TABLE + _model_id = None + if data.model_aliases is not None and isinstance(data.model_aliases, dict): + litellm_modeltable = LiteLLM_ModelTable( + model_aliases=json.dumps(data.model_aliases), + created_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + ) + model_dict = await prisma_client.db.litellm_modeltable.create( + {**litellm_modeltable.json(exclude_none=True)} # type: ignore + ) # type: ignore + + _model_id = model_dict.id + + ## ADD TO TEAM TABLE + complete_team_data = LiteLLM_TeamTable( + **data.json(), + model_id=_model_id, + ) + + # Set Management Endpoint Metadata Fields + for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium: + if getattr(data, field) is not None: + _set_team_metadata_field( + team_data=complete_team_data, + field_name=field, + value=getattr(data, field), + ) + + # If budget_duration is set, set `budget_reset_at` + if complete_team_data.budget_duration is not None: + duration_s = duration_in_seconds( + duration=complete_team_data.budget_duration + ) + reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + complete_team_data.budget_reset_at = reset_at + + complete_team_data_dict = complete_team_data.model_dump(exclude_none=True) + complete_team_data_dict = prisma_client.jsonify_team_object( + db_data=complete_team_data_dict + ) + team_row: LiteLLM_TeamTable = await prisma_client.db.litellm_teamtable.create( + data=complete_team_data_dict, + include={"litellm_model_table": True}, # type: ignore + ) + + ## ADD TEAM ID TO USER TABLE ## + for 
user in complete_team_data.members_with_roles: + ## add team id to user row ## + await prisma_client.update_data( + user_id=user.user_id, + data={"user_id": user.user_id, "teams": [team_row.team_id]}, + update_key_values_custom_query={ + "teams": { + "push ": [team_row.team_id], + } }, ) - if ( - data.max_budget is not None - and user_api_key_dict.max_budget is not None - and data.max_budget > user_api_key_dict.max_budget - ): - raise HTTPException( - status_code=400, - detail={ - "error": f"max budget higher than user max. User max budget={user_api_key_dict.max_budget}. User role={user_api_key_dict.user_role}" - }, - ) + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + if litellm.store_audit_logs is True: + _updated_values = complete_team_data.json(exclude_none=True) - if data.models is not None and len(user_api_key_dict.models) > 0: - for m in data.models: - if m not in user_api_key_dict.models: - raise HTTPException( - status_code=400, - detail={ - "error": f"Model not in allowed user models. User allowed models={user_api_key_dict.models}. 
User id={user_api_key_dict.user_id}" - }, + _updated_values = json.dumps(_updated_values, default=str) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.TEAM_TABLE_NAME, + object_id=data.team_id, + action="created", + updated_values=_updated_values, + before_value=None, ) - - if user_api_key_dict.user_id is not None: - creating_user_in_list = False - for member in data.members_with_roles: - if member.user_id == user_api_key_dict.user_id: - creating_user_in_list = True - - if creating_user_in_list is False: - data.members_with_roles.append( - Member(role="admin", user_id=user_api_key_dict.user_id) - ) - - ## ADD TO MODEL TABLE - _model_id = None - if data.model_aliases is not None and isinstance(data.model_aliases, dict): - litellm_modeltable = LiteLLM_ModelTable( - model_aliases=json.dumps(data.model_aliases), - created_by=user_api_key_dict.user_id or litellm_proxy_admin_name, - updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name, - ) - model_dict = await prisma_client.db.litellm_modeltable.create( - {**litellm_modeltable.json(exclude_none=True)} # type: ignore - ) # type: ignore - - _model_id = model_dict.id - - ## ADD TO TEAM TABLE - complete_team_data = LiteLLM_TeamTable( - **data.json(), - model_id=_model_id, - ) - - # Set Management Endpoint Metadata Fields - for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium: - if getattr(data, field) is not None: - _set_team_metadata_field( - team_data=complete_team_data, - field_name=field, - value=getattr(data, field), - ) - - # If budget_duration is set, set `budget_reset_at` - if complete_team_data.budget_duration is not None: - duration_s = duration_in_seconds(duration=complete_team_data.budget_duration) - reset_at = 
datetime.now(timezone.utc) + timedelta(seconds=duration_s) - complete_team_data.budget_reset_at = reset_at - - complete_team_data_dict = complete_team_data.model_dump(exclude_none=True) - complete_team_data_dict = prisma_client.jsonify_team_object( - db_data=complete_team_data_dict - ) - team_row: LiteLLM_TeamTable = await prisma_client.db.litellm_teamtable.create( - data=complete_team_data_dict, - include={"litellm_model_table": True}, # type: ignore - ) - - ## ADD TEAM ID TO USER TABLE ## - for user in complete_team_data.members_with_roles: - ## add team id to user row ## - await prisma_client.update_data( - user_id=user.user_id, - data={"user_id": user.user_id, "teams": [team_row.team_id]}, - update_key_values_custom_query={ - "teams": { - "push ": [team_row.team_id], - } - }, - ) - - # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True - if litellm.store_audit_logs is True: - _updated_values = complete_team_data.json(exclude_none=True) - - _updated_values = json.dumps(_updated_values, default=str) - - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.TEAM_TABLE_NAME, - object_id=data.team_id, - action="created", - updated_values=_updated_values, - before_value=None, ) ) - ) - try: - return team_row.model_dump() - except Exception: - return team_row.dict() + try: + return team_row.model_dump() + except Exception: + return team_row.dict() + except Exception as e: + raise handle_exception_on_proxy(e) async def _update_model_table( diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index cf93bd679b..791dac4e8c 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -24,5 +24,3 @@ guardrails: mode: "during_call" 
guardrailIdentifier: gf3sc1mzinjw guardrailVersion: "DRAFT" - default_on: true - diff --git a/tests/otel_tests/test_guardrails.py b/tests/otel_tests/test_guardrails.py index 12d9d1c384..e386d5151e 100644 --- a/tests/otel_tests/test_guardrails.py +++ b/tests/otel_tests/test_guardrails.py @@ -3,6 +3,7 @@ import asyncio import aiohttp, openai from openai import OpenAI, AsyncOpenAI from typing import Optional, List, Union +import json import uuid @@ -40,21 +41,22 @@ async def chat_completion( raise Exception(response_text) # response headers - response_headers = response.headers + response_headers = dict(response.headers) print("response headers=", response_headers) return await response.json(), response_headers -async def generate_key(session, guardrails): +async def generate_key( + session, guardrails: Optional[List] = None, team_id: Optional[str] = None +): url = "http://0.0.0.0:4000/key/generate" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = {} if guardrails: - data = { - "guardrails": guardrails, - } - else: - data = {} + data["guardrails"] = guardrails + if team_id: + data["team_id"] = team_id async with session.post(url, headers=headers, json=data) as response: status = response.status @@ -148,7 +150,6 @@ async def test_no_llm_guard_triggered(): @pytest.mark.asyncio -@pytest.mark.skip(reason="Aporia account disabled") async def test_guardrails_with_api_key_controls(): """ - Make two API Keys @@ -161,8 +162,7 @@ async def test_guardrails_with_api_key_controls(): key_with_guardrails = await generate_key( session=session, guardrails=[ - "aporia-post-guard", - "aporia-pre-guard", + "bedrock-pre-guard", ], ) @@ -185,19 +185,15 @@ async def test_guardrails_with_api_key_controls(): assert "x-litellm-applied-guardrails" not in headers # test guardrails triggered for key with guardrails - try: - response, headers = await chat_completion( - session, - key_with_guardrails, - model="fake-openai-endpoint", - messages=[ - {"role": 
"user", "content": f"Hello my name is ishaan@berri.ai"} - ], - ) - pytest.fail("Should have thrown an exception") - except Exception as e: - print(e) - assert "Aporia detected and blocked PII" in str(e) + response, headers = await chat_completion( + session, + key_with_guardrails, + model="fake-openai-endpoint", + messages=[{"role": "user", "content": f"Hello my name is ishaan@berri.ai"}], + ) + + assert "x-litellm-applied-guardrails" in headers + assert headers["x-litellm-applied-guardrails"] == "bedrock-pre-guard" @pytest.mark.asyncio @@ -241,3 +237,82 @@ async def test_custom_guardrail_during_call_triggered(): except Exception as e: print(e) assert "Guardrail failed words - `litellm` detected" in str(e) + + +async def create_team(session, guardrails: Optional[List] = None): + url = "http://0.0.0.0:4000/team/new" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = {"guardrails": guardrails} + + print("request data=", data) + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + return await response.json() + + +@pytest.mark.asyncio +async def test_guardrails_with_team_controls(): + """ + - Create a team with guardrails + - Make two API Keys + - Key 1 not associated with team + - Key 2 associated with team (inherits team guardrails) + - Request with Key 1 -> should be success with no guardrails + - Request with Key 2 -> should error since team guardrails are triggered + """ + async with aiohttp.ClientSession() as session: + + # Create team with guardrails + team = await create_team( + session=session, + guardrails=[ + "bedrock-pre-guard", + ], + ) + + print("team=", team) + + team_id = team["team_id"] + + # Create key with team association + key_with_team = await generate_key(session=session, team_id=team_id) 
+ key_with_team = key_with_team["key"] + + # Create key without team + key_without_team = await generate_key( + session=session, + ) + key_without_team = key_without_team["key"] + + # Test no guardrails triggered for key without a team + response, headers = await chat_completion( + session, + key_without_team, + model="fake-openai-endpoint", + messages=[{"role": "user", "content": "Hello my name is ishaan@berri.ai"}], + ) + await asyncio.sleep(3) + + print("response=", response, "response headers", headers) + assert "x-litellm-applied-guardrails" not in headers + + response, headers = await chat_completion( + session, + key_with_team, + model="fake-openai-endpoint", + messages=[{"role": "user", "content": "Hello my name is ishaan@berri.ai"}], + ) + + print("response headers=", json.dumps(headers, indent=4)) + + assert "x-litellm-applied-guardrails" in headers + assert headers["x-litellm-applied-guardrails"] == "bedrock-pre-guard"