(Feat) set guardrails per team (#7993)

* _add_guardrails_from_key_or_team_metadata

* e2e test test_guardrails_with_team_controls

* add try/except on team new

* test_guardrails_with_team_controls

* test_guardrails_with_api_key_controls
Ishaan Jaff 2025-01-25 10:41:11 -08:00 committed by GitHub
parent 9e64c7ca0c
commit fe24e729a9
4 changed files with 298 additions and 195 deletions
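
For context, a minimal sketch of what this change enables, mirroring the new e2e test flow: create a team with guardrails, issue a key under that team, and check which guardrail the proxy reports in the x-litellm-applied-guardrails response header. It assumes a locally running LiteLLM proxy at http://0.0.0.0:4000 with master key sk-1234, a configured guardrail named bedrock-pre-guard, and a fake-openai-endpoint model, all taken from the test setup below.

import asyncio
import aiohttp

PROXY = "http://0.0.0.0:4000"
ADMIN_HEADERS = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}


async def main():
    async with aiohttp.ClientSession() as session:
        # 1. Create a team with a guardrail attached (assumes "bedrock-pre-guard" is configured)
        async with session.post(
            f"{PROXY}/team/new",
            headers=ADMIN_HEADERS,
            json={"guardrails": ["bedrock-pre-guard"]},
        ) as resp:
            team = await resp.json()

        # 2. Issue a key under the team; per this commit it inherits the team's guardrails
        async with session.post(
            f"{PROXY}/key/generate",
            headers=ADMIN_HEADERS,
            json={"team_id": team["team_id"]},
        ) as resp:
            key = (await resp.json())["key"]

        # 3. Call the proxy with the team key and inspect which guardrails were applied
        async with session.post(
            f"{PROXY}/chat/completions",
            headers={"Authorization": f"Bearer {key}", "Content-Type": "application/json"},
            json={
                "model": "fake-openai-endpoint",
                "messages": [{"role": "user", "content": "Hello my name is ishaan@berri.ai"}],
            },
        ) as resp:
            print("applied guardrails:", resp.headers.get("x-litellm-applied-guardrails"))


asyncio.run(main())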


@@ -691,20 +691,24 @@ def _enforced_params_check(
     return True
 
 
-def move_guardrails_to_metadata(
+def _add_guardrails_from_key_or_team_metadata(
+    key_metadata: Optional[dict],
+    team_metadata: Optional[dict],
     data: dict,
-    _metadata_variable_name: str,
-    user_api_key_dict: UserAPIKeyAuth,
-):
+    metadata_variable_name: str,
+) -> None:
     """
-    Heper to add guardrails from request to metadata
-    - If guardrails set on API Key metadata then sets guardrails on request metadata
-    - If guardrails not set on API key, then checks request metadata
+    Helper add guardrails from key or team metadata to request data
+
+    Args:
+        key_metadata: The key metadata dictionary to check for guardrails
+        team_metadata: The team metadata dictionary to check for guardrails
+        data: The request data to update
+        metadata_variable_name: The name of the metadata field in data
     """
-    if user_api_key_dict.metadata:
-        if "guardrails" in user_api_key_dict.metadata:
+    for _management_object_metadata in [key_metadata, team_metadata]:
+        if _management_object_metadata and "guardrails" in _management_object_metadata:
             from litellm.proxy.proxy_server import premium_user
 
             if premium_user is not True:
@@ -712,11 +716,31 @@ def move_guardrails_to_metadata(
                     f"Using Guardrails on API Key {CommonProxyErrors.not_premium_user}"
                 )
 
-            data[_metadata_variable_name]["guardrails"] = user_api_key_dict.metadata[
+            data[metadata_variable_name]["guardrails"] = _management_object_metadata[
                 "guardrails"
             ]
+            return
+
+
+def move_guardrails_to_metadata(
+    data: dict,
+    _metadata_variable_name: str,
+    user_api_key_dict: UserAPIKeyAuth,
+):
+    """
+    Helper to add guardrails from request to metadata
+    - If guardrails set on API Key metadata then sets guardrails on request metadata
+    - If guardrails not set on API key, then checks request metadata
+    """
+    # Check key-level guardrails
+    _add_guardrails_from_key_or_team_metadata(
+        key_metadata=user_api_key_dict.metadata,
+        team_metadata=user_api_key_dict.team_metadata,
+        data=data,
+        metadata_variable_name=_metadata_variable_name,
+    )
 
+    # Check request-level guardrails
     if "guardrails" in data:
         data[_metadata_variable_name]["guardrails"] = data["guardrails"]
         del data["guardrails"]
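
A standalone sketch of the resolution order visible in this hunk, using plain dicts and no proxy imports: key metadata is checked first, team metadata is the fallback, the first match wins, and request-level guardrails are then moved into metadata and override both. The helper name resolve_guardrails is hypothetical; it is an illustration only, not the proxy's implementation.

from typing import Optional


def resolve_guardrails(
    key_metadata: Optional[dict],
    team_metadata: Optional[dict],
    request_data: dict,
) -> dict:
    """Illustration only: mirrors the first-match-wins loop over key then team metadata."""
    metadata = request_data.setdefault("metadata", {})

    # key metadata takes precedence; team metadata is the fallback
    for management_object_metadata in [key_metadata, team_metadata]:
        if management_object_metadata and "guardrails" in management_object_metadata:
            metadata["guardrails"] = management_object_metadata["guardrails"]
            break

    # request-level guardrails are moved into metadata, as in the existing branch
    if "guardrails" in request_data:
        metadata["guardrails"] = request_data.pop("guardrails")
    return request_data


# key-level guardrails win over team-level ones
print(
    resolve_guardrails(
        key_metadata={"guardrails": ["key-guard"]},
        team_metadata={"guardrails": ["team-guard"]},
        request_data={"model": "fake-openai-endpoint"},
    )
)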


@@ -58,7 +58,8 @@ from litellm.proxy.management_helpers.utils import (
     add_new_member,
     management_endpoint_wrapper,
 )
-from litellm.proxy.utils import PrismaClient, _premium_user_check
+from litellm.proxy.utils import PrismaClient, handle_exception_on_proxy, _premium_user_check
 
 router = APIRouter()
@@ -175,6 +176,7 @@ async def new_team(  # noqa: PLR0915
     }'
     ```
     """
+    try:
         from litellm.proxy.proxy_server import (
             create_audit_log_for_update,
             duration_in_seconds,
@@ -292,7 +294,9 @@ async def new_team(  # noqa: PLR0915
         # If budget_duration is set, set `budget_reset_at`
         if complete_team_data.budget_duration is not None:
-            duration_s = duration_in_seconds(duration=complete_team_data.budget_duration)
+            duration_s = duration_in_seconds(
+                duration=complete_team_data.budget_duration
+            )
             reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
             complete_team_data.budget_reset_at = reset_at
@@ -346,6 +350,8 @@ async def new_team(  # noqa: PLR0915
                 return team_row.model_dump()
             except Exception:
                 return team_row.dict()
+    except Exception as e:
+        raise handle_exception_on_proxy(e)
 
 
 async def _update_model_table(
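
The commit message notes a try/except was added on team new; with the endpoint body wrapped this way, errors raised while creating a team are routed through handle_exception_on_proxy rather than propagating raw. A client-side sketch of checking for a non-200 from /team/new, under the same assumed local proxy and master key; create_team_or_report is a hypothetical helper.

import asyncio
import aiohttp


async def create_team_or_report(guardrails):
    # assumes a local proxy at http://0.0.0.0:4000 with master key sk-1234
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://0.0.0.0:4000/team/new",
            headers=headers,
            json={"guardrails": guardrails},
        ) as resp:
            body = await resp.text()
            if resp.status != 200:
                # the proxy should now return a structured error body here
                print(f"team creation failed ({resp.status}): {body}")
                return None
            return await resp.json()


print(asyncio.run(create_team_or_report(["bedrock-pre-guard"])))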


@@ -24,5 +24,3 @@ guardrails:
       mode: "during_call"
       guardrailIdentifier: gf3sc1mzinjw
       guardrailVersion: "DRAFT"
-      default_on: true
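
In the test config, default_on: true is dropped from the bedrock guardrail, so it presumably no longer runs on every request and must be attached explicitly through a key, a team, or the request body. A minimal sketch of the request-body route, which the `if "guardrails" in data` branch above handles, under the same assumed proxy, master key, and fake-openai-endpoint model.

import asyncio
import aiohttp


async def main():
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
    payload = {
        "model": "fake-openai-endpoint",
        "messages": [{"role": "user", "content": "Hello my name is ishaan@berri.ai"}],
        # attach the guardrail for this single request only
        "guardrails": ["bedrock-pre-guard"],
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://0.0.0.0:4000/chat/completions", headers=headers, json=payload
        ) as resp:
            print(resp.status, resp.headers.get("x-litellm-applied-guardrails"))
            print(await resp.json())


asyncio.run(main())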


@@ -3,6 +3,7 @@ import asyncio
 import aiohttp, openai
 from openai import OpenAI, AsyncOpenAI
 from typing import Optional, List, Union
+import json
 import uuid
@@ -40,21 +41,22 @@ async def chat_completion(
         raise Exception(response_text)
 
     # response headers
-    response_headers = response.headers
+    response_headers = dict(response.headers)
     print("response headers=", response_headers)
 
     return await response.json(), response_headers
 
 
-async def generate_key(session, guardrails):
+async def generate_key(
+    session, guardrails: Optional[List] = None, team_id: Optional[str] = None
+):
     url = "http://0.0.0.0:4000/key/generate"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
-    if guardrails:
-        data = {
-            "guardrails": guardrails,
-        }
-    else:
-        data = {}
+    data = {}
+    if guardrails:
+        data["guardrails"] = guardrails
+    if team_id:
+        data["team_id"] = team_id
 
     async with session.post(url, headers=headers, json=data) as response:
         status = response.status
@@ -148,7 +150,6 @@ async def test_no_llm_guard_triggered():
 
 
 @pytest.mark.asyncio
-@pytest.mark.skip(reason="Aporia account disabled")
 async def test_guardrails_with_api_key_controls():
     """
     - Make two API Keys
@@ -161,8 +162,7 @@ async def test_guardrails_with_api_key_controls():
         key_with_guardrails = await generate_key(
             session=session,
             guardrails=[
-                "aporia-post-guard",
-                "aporia-pre-guard",
+                "bedrock-pre-guard",
             ],
         )
@@ -185,19 +185,15 @@ async def test_guardrails_with_api_key_controls():
         assert "x-litellm-applied-guardrails" not in headers
 
         # test guardrails triggered for key with guardrails
-        try:
-            response, headers = await chat_completion(
-                session,
-                key_with_guardrails,
-                model="fake-openai-endpoint",
-                messages=[
-                    {"role": "user", "content": f"Hello my name is ishaan@berri.ai"}
-                ],
-            )
-            pytest.fail("Should have thrown an exception")
-        except Exception as e:
-            print(e)
-            assert "Aporia detected and blocked PII" in str(e)
+        response, headers = await chat_completion(
+            session,
+            key_with_guardrails,
+            model="fake-openai-endpoint",
+            messages=[{"role": "user", "content": f"Hello my name is ishaan@berri.ai"}],
+        )
+
+        assert "x-litellm-applied-guardrails" in headers
+        assert headers["x-litellm-applied-guardrails"] == "bedrock-pre-guard"
 
 
 @pytest.mark.asyncio
@@ -241,3 +237,82 @@ async def test_custom_guardrail_during_call_triggered():
     except Exception as e:
         print(e)
         assert "Guardrail failed words - `litellm` detected" in str(e)
+
+
+async def create_team(session, guardrails: Optional[List] = None):
+    url = "http://0.0.0.0:4000/team/new"
+    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
+    data = {"guardrails": guardrails}
+    print("request data=", data)
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request did not return a 200 status code: {status}")
+
+        return await response.json()
+
+
+@pytest.mark.asyncio
+async def test_guardrails_with_team_controls():
+    """
+    - Create a team with guardrails
+    - Make two API Keys
+        - Key 1 not associated with team
+        - Key 2 associated with team (inherits team guardrails)
+    - Request with Key 1 -> should be success with no guardrails
+    - Request with Key 2 -> should error since team guardrails are triggered
+    """
+    async with aiohttp.ClientSession() as session:
+
+        # Create team with guardrails
+        team = await create_team(
+            session=session,
+            guardrails=[
+                "bedrock-pre-guard",
+            ],
+        )
+        print("team=", team)
+        team_id = team["team_id"]
+
+        # Create key with team association
+        key_with_team = await generate_key(session=session, team_id=team_id)
+        key_with_team = key_with_team["key"]
+
+        # Create key without team
+        key_without_team = await generate_key(
+            session=session,
+        )
+        key_without_team = key_without_team["key"]
+
+        # Test no guardrails triggered for key without a team
+        response, headers = await chat_completion(
+            session,
+            key_without_team,
+            model="fake-openai-endpoint",
+            messages=[{"role": "user", "content": "Hello my name is ishaan@berri.ai"}],
+        )
+        await asyncio.sleep(3)
+
+        print("response=", response, "response headers", headers)
+        assert "x-litellm-applied-guardrails" not in headers
+
+        response, headers = await chat_completion(
+            session,
+            key_with_team,
+            model="fake-openai-endpoint",
+            messages=[{"role": "user", "content": "Hello my name is ishaan@berri.ai"}],
+        )
+
+        print("response headers=", json.dumps(headers, indent=4))
+        assert "x-litellm-applied-guardrails" in headers
+        assert headers["x-litellm-applied-guardrails"] == "bedrock-pre-guard"
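
To exercise the new tests locally against a proxy started with the config above, something like the following works; the module path is hypothetical and depends on your checkout.

import pytest

# hypothetical path to the e2e guardrails test module; adjust to your checkout
pytest.main(
    [
        "-x",
        "-k",
        "guardrails_with_team_controls or guardrails_with_api_key_controls",
        "tests/otel_tests/test_guardrails.py",
    ]
)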