(Feat) set guardrails per team (#7993)

* _add_guardrails_from_key_or_team_metadata

* e2e test test_guardrails_with_team_controls

* add try/except on team new

* test_guardrails_with_team_controls

* test_guardrails_with_api_key_controls
Ishaan Jaff 2025-01-25 10:41:11 -08:00 committed by GitHub
parent 669b4fc955
commit a7b3c664d1
4 changed files with 298 additions and 195 deletions
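
For orientation, the flow this commit enables, and which the new e2e tests below exercise, is: create a team with guardrails attached, generate a key that belongs to that team, then send a chat completion with that key and check the x-litellm-applied-guardrails response header. A minimal sketch of that flow, assuming a locally running LiteLLM proxy at http://0.0.0.0:4000 with master key sk-1234, a guardrail named bedrock-pre-guard, and the fake-openai-endpoint test model (all values taken from the tests and config in this commit, not guaranteed for other setups):

import asyncio

import aiohttp

BASE_URL = "http://0.0.0.0:4000"
ADMIN_HEADERS = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}


async def main():
    async with aiohttp.ClientSession() as session:
        # 1. Create a team with a guardrail attached
        async with session.post(
            f"{BASE_URL}/team/new",
            headers=ADMIN_HEADERS,
            json={"guardrails": ["bedrock-pre-guard"]},
        ) as resp:
            team = await resp.json()

        # 2. Generate a key that belongs to that team
        async with session.post(
            f"{BASE_URL}/key/generate",
            headers=ADMIN_HEADERS,
            json={"team_id": team["team_id"]},
        ) as resp:
            team_key = (await resp.json())["key"]

        # 3. Requests made with the team key inherit the team guardrail; the proxy
        #    reports this in the x-litellm-applied-guardrails response header.
        async with session.post(
            f"{BASE_URL}/chat/completions",
            headers={"Authorization": f"Bearer {team_key}"},
            json={
                "model": "fake-openai-endpoint",
                "messages": [{"role": "user", "content": "hi"}],
            },
        ) as resp:
            print(resp.headers.get("x-litellm-applied-guardrails"))  # bedrock-pre-guard


asyncio.run(main())

The x-litellm-applied-guardrails response header is what the new tests assert on.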


@@ -691,20 +691,24 @@ def _enforced_params_check(
     return True
 
 
-def move_guardrails_to_metadata(
+def _add_guardrails_from_key_or_team_metadata(
+    key_metadata: Optional[dict],
+    team_metadata: Optional[dict],
     data: dict,
-    _metadata_variable_name: str,
-    user_api_key_dict: UserAPIKeyAuth,
-):
+    metadata_variable_name: str,
+) -> None:
     """
-    Heper to add guardrails from request to metadata
+    Helper add guardrails from key or team metadata to request data
 
-    - If guardrails set on API Key metadata then sets guardrails on request metadata
-    - If guardrails not set on API key, then checks request metadata
+    Args:
+        key_metadata: The key metadata dictionary to check for guardrails
+        team_metadata: The team metadata dictionary to check for guardrails
+        data: The request data to update
+        metadata_variable_name: The name of the metadata field in data
     """
-    if user_api_key_dict.metadata:
-        if "guardrails" in user_api_key_dict.metadata:
+    for _management_object_metadata in [key_metadata, team_metadata]:
+        if _management_object_metadata and "guardrails" in _management_object_metadata:
             from litellm.proxy.proxy_server import premium_user
 
             if premium_user is not True:
@@ -712,11 +716,31 @@ def move_guardrails_to_metadata(
                     f"Using Guardrails on API Key {CommonProxyErrors.not_premium_user}"
                 )
-            data[_metadata_variable_name]["guardrails"] = user_api_key_dict.metadata[
+            data[metadata_variable_name]["guardrails"] = _management_object_metadata[
                 "guardrails"
             ]
             return
 
+
+def move_guardrails_to_metadata(
+    data: dict,
+    _metadata_variable_name: str,
+    user_api_key_dict: UserAPIKeyAuth,
+):
+    """
+    Helper to add guardrails from request to metadata
+
+    - If guardrails set on API Key metadata then sets guardrails on request metadata
+    - If guardrails not set on API key, then checks request metadata
+    """
+    # Check key-level guardrails
+    _add_guardrails_from_key_or_team_metadata(
+        key_metadata=user_api_key_dict.metadata,
+        team_metadata=user_api_key_dict.team_metadata,
+        data=data,
+        metadata_variable_name=_metadata_variable_name,
+    )
+
     # Check request-level guardrails
     if "guardrails" in data:
         data[_metadata_variable_name]["guardrails"] = data["guardrails"]
         del data["guardrails"]
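
The refactor above is the core of the change: guardrails can now come from either the key's metadata or the team's metadata, checked in that order, and (as the early return after the first match suggests) key-level guardrails take precedence when both are set. A stand-alone sketch of that precedence logic, written as an illustrative mirror rather than the litellm helper itself, with the premium-user check left out:

from typing import Optional


def resolve_guardrails(
    key_metadata: Optional[dict],
    team_metadata: Optional[dict],
    data: dict,
    metadata_variable_name: str = "metadata",
) -> None:
    # Mirrors the loop in _add_guardrails_from_key_or_team_metadata: key metadata is
    # checked first, then team metadata, and the first "guardrails" entry found is
    # copied into the request's metadata.
    for management_object_metadata in [key_metadata, team_metadata]:
        if management_object_metadata and "guardrails" in management_object_metadata:
            data[metadata_variable_name]["guardrails"] = management_object_metadata[
                "guardrails"
            ]
            return  # first match wins: key-level guardrails shadow team-level ones


# Key without guardrails on a team that has them -> team guardrails are applied.
data = {"metadata": {}}
resolve_guardrails(
    key_metadata={},                                      # key has no guardrails
    team_metadata={"guardrails": ["bedrock-pre-guard"]},  # team does
    data=data,
)
assert data["metadata"]["guardrails"] == ["bedrock-pre-guard"]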


@@ -58,7 +58,8 @@ from litellm.proxy.management_helpers.utils import (
     add_new_member,
     management_endpoint_wrapper,
 )
-from litellm.proxy.utils import PrismaClient, _premium_user_check
+from litellm.proxy.utils import PrismaClient, handle_exception_on_proxy, _premium_user_check
 
 router = APIRouter()
@@ -175,6 +176,7 @@ async def new_team(  # noqa: PLR0915
     }'
     ```
     """
+    try:
         from litellm.proxy.proxy_server import (
             create_audit_log_for_update,
             duration_in_seconds,
@@ -292,7 +294,9 @@ async def new_team(  # noqa: PLR0915
 
         # If budget_duration is set, set `budget_reset_at`
         if complete_team_data.budget_duration is not None:
-            duration_s = duration_in_seconds(duration=complete_team_data.budget_duration)
+            duration_s = duration_in_seconds(
+                duration=complete_team_data.budget_duration
+            )
             reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
             complete_team_data.budget_reset_at = reset_at
@@ -346,6 +350,8 @@ async def new_team(  # noqa: PLR0915
             return team_row.model_dump()
         except Exception:
             return team_row.dict()
+    except Exception as e:
+        raise handle_exception_on_proxy(e)
 
 
 async def _update_model_table(
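
The other change in this file wraps the body of the /team/new handler in a try/except so that any failure is routed through handle_exception_on_proxy rather than escaping as a raw exception. A minimal sketch of the pattern, with the real endpoint body elided and new_team_like_endpoint being a hypothetical stand-in name:

from litellm.proxy.utils import handle_exception_on_proxy  # same import the hunk adds


async def new_team_like_endpoint(data: dict) -> dict:
    try:
        # ... validate the request, resolve budget_duration, write the team row ...
        team_row = {"team_id": "example-team"}  # placeholder for the created row
        return team_row
    except Exception as e:
        # hand the failure to the proxy's error helper, as /team/new now does
        raise handle_exception_on_proxy(e)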


@@ -24,5 +24,3 @@ guardrails:
       mode: "during_call"
       guardrailIdentifier: gf3sc1mzinjw
       guardrailVersion: "DRAFT"
-      default_on: true
-
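
Dropping default_on: true presumably stops bedrock-pre-guard from running on every request by default, which is what lets the new test's non-team key complete without any guardrail applied; the guardrail now only runs when attached to a key, a team, or an individual request. For the request-level path (the `if "guardrails" in data` branch shown earlier) a client can still opt in per call; a minimal sketch assuming the same local proxy and guardrail name, passing the litellm-specific guardrails field through the OpenAI SDK's extra_body:

from openai import OpenAI

# assumes the proxy from the tests above; any virtual key works in place of sk-1234
client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

response = client.chat.completions.create(
    model="fake-openai-endpoint",
    messages=[{"role": "user", "content": "hi"}],
    extra_body={"guardrails": ["bedrock-pre-guard"]},  # request-level guardrails
)
print(response)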


@@ -3,6 +3,7 @@ import asyncio
 import aiohttp, openai
 from openai import OpenAI, AsyncOpenAI
 from typing import Optional, List, Union
+import json
 import uuid
@@ -40,21 +41,22 @@ async def chat_completion(
         raise Exception(response_text)
 
     # response headers
-    response_headers = response.headers
+    response_headers = dict(response.headers)
     print("response headers=", response_headers)
 
     return await response.json(), response_headers
 
 
-async def generate_key(session, guardrails):
+async def generate_key(
+    session, guardrails: Optional[List] = None, team_id: Optional[str] = None
+):
     url = "http://0.0.0.0:4000/key/generate"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
-    if guardrails:
-        data = {
-            "guardrails": guardrails,
-        }
-    else:
-        data = {}
+    data = {}
+    if guardrails:
+        data["guardrails"] = guardrails
+    if team_id:
+        data["team_id"] = team_id
 
     async with session.post(url, headers=headers, json=data) as response:
         status = response.status
@@ -148,7 +150,6 @@ async def test_no_llm_guard_triggered():
 
 
 @pytest.mark.asyncio
-@pytest.mark.skip(reason="Aporia account disabled")
 async def test_guardrails_with_api_key_controls():
     """
     - Make two API Keys
@@ -161,8 +162,7 @@ async def test_guardrails_with_api_key_controls():
         key_with_guardrails = await generate_key(
             session=session,
             guardrails=[
-                "aporia-post-guard",
-                "aporia-pre-guard",
+                "bedrock-pre-guard",
             ],
         )
 
@@ -185,19 +185,15 @@ async def test_guardrails_with_api_key_controls():
         assert "x-litellm-applied-guardrails" not in headers
 
         # test guardrails triggered for key with guardrails
-        try:
-            response, headers = await chat_completion(
-                session,
-                key_with_guardrails,
-                model="fake-openai-endpoint",
-                messages=[
-                    {"role": "user", "content": f"Hello my name is ishaan@berri.ai"}
-                ],
-            )
-            pytest.fail("Should have thrown an exception")
-        except Exception as e:
-            print(e)
-            assert "Aporia detected and blocked PII" in str(e)
+        response, headers = await chat_completion(
+            session,
+            key_with_guardrails,
+            model="fake-openai-endpoint",
+            messages=[{"role": "user", "content": f"Hello my name is ishaan@berri.ai"}],
+        )
+
+        assert "x-litellm-applied-guardrails" in headers
+        assert headers["x-litellm-applied-guardrails"] == "bedrock-pre-guard"
 
 
 @pytest.mark.asyncio
@@ -241,3 +237,82 @@ async def test_custom_guardrail_during_call_triggered():
     except Exception as e:
         print(e)
         assert "Guardrail failed words - `litellm` detected" in str(e)
+
+
+async def create_team(session, guardrails: Optional[List] = None):
+    url = "http://0.0.0.0:4000/team/new"
+    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
+    data = {"guardrails": guardrails}
+    print("request data=", data)
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request did not return a 200 status code: {status}")
+
+        return await response.json()
+
+
+@pytest.mark.asyncio
+async def test_guardrails_with_team_controls():
+    """
+    - Create a team with guardrails
+    - Make two API Keys
+    - Key 1 not associated with team
+    - Key 2 associated with team (inherits team guardrails)
+    - Request with Key 1 -> should be success with no guardrails
+    - Request with Key 2 -> should error since team guardrails are triggered
+    """
+    async with aiohttp.ClientSession() as session:
+        # Create team with guardrails
+        team = await create_team(
+            session=session,
+            guardrails=[
+                "bedrock-pre-guard",
+            ],
+        )
+        print("team=", team)
+        team_id = team["team_id"]
+
+        # Create key with team association
+        key_with_team = await generate_key(session=session, team_id=team_id)
+        key_with_team = key_with_team["key"]
+
+        # Create key without team
+        key_without_team = await generate_key(
+            session=session,
+        )
+        key_without_team = key_without_team["key"]
+
+        # Test no guardrails triggered for key without a team
+        response, headers = await chat_completion(
+            session,
+            key_without_team,
+            model="fake-openai-endpoint",
+            messages=[{"role": "user", "content": "Hello my name is ishaan@berri.ai"}],
+        )
+        await asyncio.sleep(3)
+
+        print("response=", response, "response headers", headers)
+        assert "x-litellm-applied-guardrails" not in headers
+
+        response, headers = await chat_completion(
+            session,
+            key_with_team,
+            model="fake-openai-endpoint",
+            messages=[{"role": "user", "content": "Hello my name is ishaan@berri.ai"}],
+        )
+        print("response headers=", json.dumps(headers, indent=4))
+
+        assert "x-litellm-applied-guardrails" in headers
+        assert headers["x-litellm-applied-guardrails"] == "bedrock-pre-guard"