(Feat) set guardrails per team (#7993)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 35s

* _add_guardrails_from_key_or_team_metadata

* e2e test test_guardrails_with_team_controls

* add try/except on team new

* test_guardrails_with_team_controls

* test_guardrails_with_api_key_controls
This commit is contained in:
Ishaan Jaff 2025-01-25 10:41:11 -08:00 committed by GitHub
parent 669b4fc955
commit a7b3c664d1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 298 additions and 195 deletions

View file

@ -691,20 +691,24 @@ def _enforced_params_check(
return True return True
def move_guardrails_to_metadata( def _add_guardrails_from_key_or_team_metadata(
key_metadata: Optional[dict],
team_metadata: Optional[dict],
data: dict, data: dict,
_metadata_variable_name: str, metadata_variable_name: str,
user_api_key_dict: UserAPIKeyAuth, ) -> None:
):
""" """
Heper to add guardrails from request to metadata Helper add guardrails from key or team metadata to request data
- If guardrails set on API Key metadata then sets guardrails on request metadata Args:
- If guardrails not set on API key, then checks request metadata key_metadata: The key metadata dictionary to check for guardrails
team_metadata: The team metadata dictionary to check for guardrails
data: The request data to update
metadata_variable_name: The name of the metadata field in data
""" """
if user_api_key_dict.metadata: for _management_object_metadata in [key_metadata, team_metadata]:
if "guardrails" in user_api_key_dict.metadata: if _management_object_metadata and "guardrails" in _management_object_metadata:
from litellm.proxy.proxy_server import premium_user from litellm.proxy.proxy_server import premium_user
if premium_user is not True: if premium_user is not True:
@ -712,11 +716,31 @@ def move_guardrails_to_metadata(
f"Using Guardrails on API Key {CommonProxyErrors.not_premium_user}" f"Using Guardrails on API Key {CommonProxyErrors.not_premium_user}"
) )
data[_metadata_variable_name]["guardrails"] = user_api_key_dict.metadata[ data[metadata_variable_name]["guardrails"] = _management_object_metadata[
"guardrails" "guardrails"
] ]
return
def move_guardrails_to_metadata(
    data: dict,
    _metadata_variable_name: str,
    user_api_key_dict: UserAPIKeyAuth,
):
    """
    Copy applicable guardrails into the request's metadata field.

    Resolution order:
    1. Guardrails configured on the API key or its team metadata
       (delegated to ``_add_guardrails_from_key_or_team_metadata``).
    2. Guardrails passed directly in the request body, which are moved
       under ``data[_metadata_variable_name]`` and removed from the top level.
    """
    # Key-level / team-level guardrails are written into
    # data[_metadata_variable_name] by this helper (key checked before team).
    _add_guardrails_from_key_or_team_metadata(
        key_metadata=user_api_key_dict.metadata,
        team_metadata=user_api_key_dict.team_metadata,
        data=data,
        metadata_variable_name=_metadata_variable_name,
    )

    # Request-level guardrails: relocate from the request body into metadata.
    # Runs after the key/team check, so an explicit request value wins.
    if "guardrails" in data:
        data[_metadata_variable_name]["guardrails"] = data["guardrails"]
        del data["guardrails"]

View file

@ -58,7 +58,8 @@ from litellm.proxy.management_helpers.utils import (
add_new_member, add_new_member,
management_endpoint_wrapper, management_endpoint_wrapper,
) )
from litellm.proxy.utils import PrismaClient, _premium_user_check from litellm.proxy.utils import PrismaClient, handle_exception_on_proxy, _premium_user_check
router = APIRouter() router = APIRouter()
@ -175,177 +176,182 @@ async def new_team( # noqa: PLR0915
}' }'
``` ```
""" """
from litellm.proxy.proxy_server import ( try:
create_audit_log_for_update, from litellm.proxy.proxy_server import (
duration_in_seconds, create_audit_log_for_update,
litellm_proxy_admin_name, duration_in_seconds,
prisma_client, litellm_proxy_admin_name,
) prisma_client,
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
if data.team_id is None:
data.team_id = str(uuid.uuid4())
else:
# Check if team_id exists already
_existing_team_id = await prisma_client.get_data(
team_id=data.team_id, table_name="team", query_type="find_unique"
) )
if _existing_team_id is not None:
raise HTTPException(
status_code=400,
detail={
"error": f"Team id = {data.team_id} already exists. Please use a different team id."
},
)
if ( if prisma_client is None:
user_api_key_dict.user_role is None raise HTTPException(status_code=500, detail={"error": "No db connected"})
or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN
): # don't restrict proxy admin if data.team_id is None:
if ( data.team_id = str(uuid.uuid4())
data.tpm_limit is not None else:
and user_api_key_dict.tpm_limit is not None # Check if team_id exists already
and data.tpm_limit > user_api_key_dict.tpm_limit _existing_team_id = await prisma_client.get_data(
): team_id=data.team_id, table_name="team", query_type="find_unique"
raise HTTPException(
status_code=400,
detail={
"error": f"tpm limit higher than user max. User tpm limit={user_api_key_dict.tpm_limit}. User role={user_api_key_dict.user_role}"
},
) )
if _existing_team_id is not None:
raise HTTPException(
status_code=400,
detail={
"error": f"Team id = {data.team_id} already exists. Please use a different team id."
},
)
if ( if (
data.rpm_limit is not None user_api_key_dict.user_role is None
and user_api_key_dict.rpm_limit is not None or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN
and data.rpm_limit > user_api_key_dict.rpm_limit ): # don't restrict proxy admin
): if (
raise HTTPException( data.tpm_limit is not None
status_code=400, and user_api_key_dict.tpm_limit is not None
detail={ and data.tpm_limit > user_api_key_dict.tpm_limit
"error": f"rpm limit higher than user max. User rpm limit={user_api_key_dict.rpm_limit}. User role={user_api_key_dict.user_role}" ):
raise HTTPException(
status_code=400,
detail={
"error": f"tpm limit higher than user max. User tpm limit={user_api_key_dict.tpm_limit}. User role={user_api_key_dict.user_role}"
},
)
if (
data.rpm_limit is not None
and user_api_key_dict.rpm_limit is not None
and data.rpm_limit > user_api_key_dict.rpm_limit
):
raise HTTPException(
status_code=400,
detail={
"error": f"rpm limit higher than user max. User rpm limit={user_api_key_dict.rpm_limit}. User role={user_api_key_dict.user_role}"
},
)
if (
data.max_budget is not None
and user_api_key_dict.max_budget is not None
and data.max_budget > user_api_key_dict.max_budget
):
raise HTTPException(
status_code=400,
detail={
"error": f"max budget higher than user max. User max budget={user_api_key_dict.max_budget}. User role={user_api_key_dict.user_role}"
},
)
if data.models is not None and len(user_api_key_dict.models) > 0:
for m in data.models:
if m not in user_api_key_dict.models:
raise HTTPException(
status_code=400,
detail={
"error": f"Model not in allowed user models. User allowed models={user_api_key_dict.models}. User id={user_api_key_dict.user_id}"
},
)
if user_api_key_dict.user_id is not None:
creating_user_in_list = False
for member in data.members_with_roles:
if member.user_id == user_api_key_dict.user_id:
creating_user_in_list = True
if creating_user_in_list is False:
data.members_with_roles.append(
Member(role="admin", user_id=user_api_key_dict.user_id)
)
## ADD TO MODEL TABLE
_model_id = None
if data.model_aliases is not None and isinstance(data.model_aliases, dict):
litellm_modeltable = LiteLLM_ModelTable(
model_aliases=json.dumps(data.model_aliases),
created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
)
model_dict = await prisma_client.db.litellm_modeltable.create(
{**litellm_modeltable.json(exclude_none=True)} # type: ignore
) # type: ignore
_model_id = model_dict.id
## ADD TO TEAM TABLE
complete_team_data = LiteLLM_TeamTable(
**data.json(),
model_id=_model_id,
)
# Set Management Endpoint Metadata Fields
for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
if getattr(data, field) is not None:
_set_team_metadata_field(
team_data=complete_team_data,
field_name=field,
value=getattr(data, field),
)
# If budget_duration is set, set `budget_reset_at`
if complete_team_data.budget_duration is not None:
duration_s = duration_in_seconds(
duration=complete_team_data.budget_duration
)
reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
complete_team_data.budget_reset_at = reset_at
complete_team_data_dict = complete_team_data.model_dump(exclude_none=True)
complete_team_data_dict = prisma_client.jsonify_team_object(
db_data=complete_team_data_dict
)
team_row: LiteLLM_TeamTable = await prisma_client.db.litellm_teamtable.create(
data=complete_team_data_dict,
include={"litellm_model_table": True}, # type: ignore
)
## ADD TEAM ID TO USER TABLE ##
for user in complete_team_data.members_with_roles:
## add team id to user row ##
await prisma_client.update_data(
user_id=user.user_id,
data={"user_id": user.user_id, "teams": [team_row.team_id]},
update_key_values_custom_query={
"teams": {
"push ": [team_row.team_id],
}
}, },
) )
if ( # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
data.max_budget is not None if litellm.store_audit_logs is True:
and user_api_key_dict.max_budget is not None _updated_values = complete_team_data.json(exclude_none=True)
and data.max_budget > user_api_key_dict.max_budget
):
raise HTTPException(
status_code=400,
detail={
"error": f"max budget higher than user max. User max budget={user_api_key_dict.max_budget}. User role={user_api_key_dict.user_role}"
},
)
if data.models is not None and len(user_api_key_dict.models) > 0: _updated_values = json.dumps(_updated_values, default=str)
for m in data.models:
if m not in user_api_key_dict.models: asyncio.create_task(
raise HTTPException( create_audit_log_for_update(
status_code=400, request_data=LiteLLM_AuditLogs(
detail={ id=str(uuid.uuid4()),
"error": f"Model not in allowed user models. User allowed models={user_api_key_dict.models}. User id={user_api_key_dict.user_id}" updated_at=datetime.now(timezone.utc),
}, changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.TEAM_TABLE_NAME,
object_id=data.team_id,
action="created",
updated_values=_updated_values,
before_value=None,
) )
if user_api_key_dict.user_id is not None:
creating_user_in_list = False
for member in data.members_with_roles:
if member.user_id == user_api_key_dict.user_id:
creating_user_in_list = True
if creating_user_in_list is False:
data.members_with_roles.append(
Member(role="admin", user_id=user_api_key_dict.user_id)
)
## ADD TO MODEL TABLE
_model_id = None
if data.model_aliases is not None and isinstance(data.model_aliases, dict):
litellm_modeltable = LiteLLM_ModelTable(
model_aliases=json.dumps(data.model_aliases),
created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
)
model_dict = await prisma_client.db.litellm_modeltable.create(
{**litellm_modeltable.json(exclude_none=True)} # type: ignore
) # type: ignore
_model_id = model_dict.id
## ADD TO TEAM TABLE
complete_team_data = LiteLLM_TeamTable(
**data.json(),
model_id=_model_id,
)
# Set Management Endpoint Metadata Fields
for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
if getattr(data, field) is not None:
_set_team_metadata_field(
team_data=complete_team_data,
field_name=field,
value=getattr(data, field),
)
# If budget_duration is set, set `budget_reset_at`
if complete_team_data.budget_duration is not None:
duration_s = duration_in_seconds(duration=complete_team_data.budget_duration)
reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
complete_team_data.budget_reset_at = reset_at
complete_team_data_dict = complete_team_data.model_dump(exclude_none=True)
complete_team_data_dict = prisma_client.jsonify_team_object(
db_data=complete_team_data_dict
)
team_row: LiteLLM_TeamTable = await prisma_client.db.litellm_teamtable.create(
data=complete_team_data_dict,
include={"litellm_model_table": True}, # type: ignore
)
## ADD TEAM ID TO USER TABLE ##
for user in complete_team_data.members_with_roles:
## add team id to user row ##
await prisma_client.update_data(
user_id=user.user_id,
data={"user_id": user.user_id, "teams": [team_row.team_id]},
update_key_values_custom_query={
"teams": {
"push ": [team_row.team_id],
}
},
)
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
_updated_values = complete_team_data.json(exclude_none=True)
_updated_values = json.dumps(_updated_values, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.TEAM_TABLE_NAME,
object_id=data.team_id,
action="created",
updated_values=_updated_values,
before_value=None,
) )
) )
)
try: try:
return team_row.model_dump() return team_row.model_dump()
except Exception: except Exception:
return team_row.dict() return team_row.dict()
except Exception as e:
raise handle_exception_on_proxy(e)
async def _update_model_table( async def _update_model_table(

View file

@ -24,5 +24,3 @@ guardrails:
mode: "during_call" mode: "during_call"
guardrailIdentifier: gf3sc1mzinjw guardrailIdentifier: gf3sc1mzinjw
guardrailVersion: "DRAFT" guardrailVersion: "DRAFT"
default_on: true

View file

@ -3,6 +3,7 @@ import asyncio
import aiohttp, openai import aiohttp, openai
from openai import OpenAI, AsyncOpenAI from openai import OpenAI, AsyncOpenAI
from typing import Optional, List, Union from typing import Optional, List, Union
import json
import uuid import uuid
@ -40,21 +41,22 @@ async def chat_completion(
raise Exception(response_text) raise Exception(response_text)
# response headers # response headers
response_headers = response.headers response_headers = dict(response.headers)
print("response headers=", response_headers) print("response headers=", response_headers)
return await response.json(), response_headers return await response.json(), response_headers
async def generate_key(session, guardrails): async def generate_key(
session, guardrails: Optional[List] = None, team_id: Optional[str] = None
):
url = "http://0.0.0.0:4000/key/generate" url = "http://0.0.0.0:4000/key/generate"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {}
if guardrails: if guardrails:
data = { data["guardrails"] = guardrails
"guardrails": guardrails, if team_id:
} data["team_id"] = team_id
else:
data = {}
async with session.post(url, headers=headers, json=data) as response: async with session.post(url, headers=headers, json=data) as response:
status = response.status status = response.status
@ -148,7 +150,6 @@ async def test_no_llm_guard_triggered():
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.skip(reason="Aporia account disabled")
async def test_guardrails_with_api_key_controls(): async def test_guardrails_with_api_key_controls():
""" """
- Make two API Keys - Make two API Keys
@ -161,8 +162,7 @@ async def test_guardrails_with_api_key_controls():
key_with_guardrails = await generate_key( key_with_guardrails = await generate_key(
session=session, session=session,
guardrails=[ guardrails=[
"aporia-post-guard", "bedrock-pre-guard",
"aporia-pre-guard",
], ],
) )
@ -185,19 +185,15 @@ async def test_guardrails_with_api_key_controls():
assert "x-litellm-applied-guardrails" not in headers assert "x-litellm-applied-guardrails" not in headers
# test guardrails triggered for key with guardrails # test guardrails triggered for key with guardrails
try: response, headers = await chat_completion(
response, headers = await chat_completion( session,
session, key_with_guardrails,
key_with_guardrails, model="fake-openai-endpoint",
model="fake-openai-endpoint", messages=[{"role": "user", "content": f"Hello my name is ishaan@berri.ai"}],
messages=[ )
{"role": "user", "content": f"Hello my name is ishaan@berri.ai"}
], assert "x-litellm-applied-guardrails" in headers
) assert headers["x-litellm-applied-guardrails"] == "bedrock-pre-guard"
pytest.fail("Should have thrown an exception")
except Exception as e:
print(e)
assert "Aporia detected and blocked PII" in str(e)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -241,3 +237,82 @@ async def test_custom_guardrail_during_call_triggered():
except Exception as e: except Exception as e:
print(e) print(e)
assert "Guardrail failed words - `litellm` detected" in str(e) assert "Guardrail failed words - `litellm` detected" in str(e)
async def create_team(session, guardrails: Optional[List] = None):
    """
    Create a new team via the proxy's /team/new endpoint.

    Args:
        session: aiohttp client session used to issue the request.
        guardrails: optional list of guardrail names to attach to the team.

    Returns:
        Parsed JSON response for the created team (includes "team_id").

    Raises:
        Exception: if the endpoint returns a non-200 status code.
    """
    url = "http://0.0.0.0:4000/team/new"
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
    # Only include the field when set — mirrors generate_key() and avoids
    # posting an explicit {"guardrails": null} when the caller passes none.
    data = {}
    if guardrails:
        data["guardrails"] = guardrails
    print("request data=", data)

    async with session.post(url, headers=headers, json=data) as response:
        status = response.status
        response_text = await response.text()
        print(response_text)
        print()

        if status != 200:
            raise Exception(f"Request did not return a 200 status code: {status}")

        return await response.json()
@pytest.mark.asyncio
async def test_guardrails_with_team_controls():
    """
    E2E check that team-level guardrails apply only to keys on that team.

    - Create a team with guardrails
    - Make two API Keys
        - Key 1 not associated with team
        - Key 2 associated with team (inherits team guardrails)
    - Request with Key 1 -> should be success with no guardrails
    - Request with Key 2 -> should error since team guardrails are triggered
    """
    async with aiohttp.ClientSession() as session:
        # Create team with guardrails
        team = await create_team(
            session=session,
            guardrails=[
                "bedrock-pre-guard",
            ],
        )
        print("team=", team)
        team_id = team["team_id"]

        # Create key with team association — it should inherit the team guardrail
        key_with_team = await generate_key(session=session, team_id=team_id)
        key_with_team = key_with_team["key"]

        # Create key without team — no guardrails should apply to it
        key_without_team = await generate_key(
            session=session,
        )
        key_without_team = key_without_team["key"]

        # Test no guardrails triggered for key without a team
        response, headers = await chat_completion(
            session,
            key_without_team,
            model="fake-openai-endpoint",
            messages=[{"role": "user", "content": "Hello my name is ishaan@berri.ai"}],
        )
        # NOTE(review): presumably gives the proxy time to finish async
        # post-call processing before asserting — confirm if still needed
        await asyncio.sleep(3)

        print("response=", response, "response headers", headers)
        # No applied-guardrails header expected: key has neither team nor
        # key-level guardrails configured
        assert "x-litellm-applied-guardrails" not in headers

        # Same request with the team-scoped key: the team guardrail should run
        # and be reported back in the response headers
        response, headers = await chat_completion(
            session,
            key_with_team,
            model="fake-openai-endpoint",
            messages=[{"role": "user", "content": "Hello my name is ishaan@berri.ai"}],
        )

        print("response headers=", json.dumps(headers, indent=4))
        assert "x-litellm-applied-guardrails" in headers
        assert headers["x-litellm-applied-guardrails"] == "bedrock-pre-guard"