Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
(Feat) set guardrails per team (#7993)
All checks were successful: Read Version from pyproject.toml / read-version (push), successful in 35s
* _add_guardrails_from_key_or_team_metadata
* e2e test test_guardrails_with_team_controls
* add try/except on team new
* test_guardrails_with_team_controls
* test_guardrails_with_api_key_controls
Parent: 669b4fc955
Commit: a7b3c664d1
4 changed files with 298 additions and 195 deletions
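The end-to-end flow this commit enables, sketched below: attach guardrails to a team at creation time, generate a key under that team, and calls made with that key get the team's guardrails applied (reported via the x-litellm-applied-guardrails response header). This is a minimal sketch mirroring the e2e test added in this commit, assuming a locally running proxy; the base URL, admin key (sk-1234), guardrail name (bedrock-pre-guard), and model (fake-openai-endpoint) come from the test fixtures, and the requests-based client is illustrative, not part of the diff.

```python
import requests

BASE_URL = "http://0.0.0.0:4000"  # local LiteLLM proxy, as in the e2e tests (assumption)
ADMIN_HEADERS = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}

# 1. Create a team with guardrails attached (POST /team/new)
team = requests.post(
    f"{BASE_URL}/team/new",
    headers=ADMIN_HEADERS,
    json={"guardrails": ["bedrock-pre-guard"]},
).json()

# 2. Generate a key that belongs to that team; it inherits the team's guardrails
key = requests.post(
    f"{BASE_URL}/key/generate",
    headers=ADMIN_HEADERS,
    json={"team_id": team["team_id"]},
).json()["key"]

# 3. A request made with the team key reports the applied guardrail in a response header
resp = requests.post(
    f"{BASE_URL}/chat/completions",  # OpenAI-compatible proxy route
    headers={"Authorization": f"Bearer {key}", "Content-Type": "application/json"},
    json={
        "model": "fake-openai-endpoint",
        "messages": [{"role": "user", "content": "Hello my name is ishaan@berri.ai"}],
    },
)
print(resp.headers.get("x-litellm-applied-guardrails"))  # expected: "bedrock-pre-guard"
```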
Changed file 1 of 4 (pre-call request utilities; filename not preserved in this capture):

@@ -691,20 +691,24 @@ def _enforced_params_check(
         return True
 
 
-def move_guardrails_to_metadata(
+def _add_guardrails_from_key_or_team_metadata(
+    key_metadata: Optional[dict],
+    team_metadata: Optional[dict],
     data: dict,
-    _metadata_variable_name: str,
-    user_api_key_dict: UserAPIKeyAuth,
-):
+    metadata_variable_name: str,
+) -> None:
     """
-    Heper to add guardrails from request to metadata
+    Helper add guardrails from key or team metadata to request data
 
-    - If guardrails set on API Key metadata then sets guardrails on request metadata
-    - If guardrails not set on API key, then checks request metadata
+    Args:
+        key_metadata: The key metadata dictionary to check for guardrails
+        team_metadata: The team metadata dictionary to check for guardrails
+        data: The request data to update
+        metadata_variable_name: The name of the metadata field in data
     """
-    if user_api_key_dict.metadata:
-        if "guardrails" in user_api_key_dict.metadata:
-            from litellm.proxy.proxy_server import premium_user
+    for _management_object_metadata in [key_metadata, team_metadata]:
+        if _management_object_metadata and "guardrails" in _management_object_metadata:
+            from litellm.proxy.proxy_server import premium_user
 
             if premium_user is not True:
@@ -712,11 +716,31 @@ def move_guardrails_to_metadata(
                     f"Using Guardrails on API Key {CommonProxyErrors.not_premium_user}"
                 )
-            data[_metadata_variable_name]["guardrails"] = user_api_key_dict.metadata[
+            data[metadata_variable_name]["guardrails"] = _management_object_metadata[
                 "guardrails"
             ]
+            return
+
+
+def move_guardrails_to_metadata(
+    data: dict,
+    _metadata_variable_name: str,
+    user_api_key_dict: UserAPIKeyAuth,
+):
+    """
+    Helper to add guardrails from request to metadata
+
+    - If guardrails set on API Key metadata then sets guardrails on request metadata
+    - If guardrails not set on API key, then checks request metadata
+    """
+    # Check key-level guardrails
+    _add_guardrails_from_key_or_team_metadata(
+        key_metadata=user_api_key_dict.metadata,
+        team_metadata=user_api_key_dict.team_metadata,
+        data=data,
+        metadata_variable_name=_metadata_variable_name,
+    )
+
+    # Check request-level guardrails
     if "guardrails" in data:
         data[_metadata_variable_name]["guardrails"] = data["guardrails"]
         del data["guardrails"]
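A simplified, standalone sketch of the precedence the new helper implements: key metadata is checked before team metadata, and the first level that defines guardrails wins, after which move_guardrails_to_metadata separately handles request-level guardrails. resolve_guardrails is a hypothetical name used only for illustration; the premium-user check and the UserAPIKeyAuth type from the real code are omitted.

```python
from typing import Optional


def resolve_guardrails(
    key_metadata: Optional[dict],
    team_metadata: Optional[dict],
    data: dict,
    metadata_variable_name: str = "metadata",
) -> None:
    # Key-level settings are checked first, then team-level; first match wins.
    for management_object_metadata in [key_metadata, team_metadata]:
        if management_object_metadata and "guardrails" in management_object_metadata:
            data[metadata_variable_name]["guardrails"] = management_object_metadata[
                "guardrails"
            ]
            return  # stop at the first level that defines guardrails


data = {"metadata": {}}
resolve_guardrails(
    key_metadata={"guardrails": ["key-level-guard"]},
    team_metadata={"guardrails": ["team-level-guard"]},
    data=data,
)
print(data["metadata"]["guardrails"])  # ['key-level-guard'] -> key settings take precedence
```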
Changed file 2 of 4 (team management endpoints; filename not preserved in this capture):

@@ -58,7 +58,8 @@ from litellm.proxy.management_helpers.utils import (
     add_new_member,
     management_endpoint_wrapper,
 )
-from litellm.proxy.utils import PrismaClient, _premium_user_check
+from litellm.proxy.utils import PrismaClient, handle_exception_on_proxy, _premium_user_check
 
 
 router = APIRouter()

@@ -175,177 +176,182 @@ async def new_team( # noqa: PLR0915
[This hunk wraps the entire pre-existing body of new_team in a try/except that re-raises via handle_exception_on_proxy; apart from re-indentation and line re-wrapping, the body is unchanged. The resulting code:]

    }'
    ```
    """
    try:
        from litellm.proxy.proxy_server import (
            create_audit_log_for_update,
            duration_in_seconds,
            litellm_proxy_admin_name,
            prisma_client,
        )

        if prisma_client is None:
            raise HTTPException(status_code=500, detail={"error": "No db connected"})

        if data.team_id is None:
            data.team_id = str(uuid.uuid4())
        else:
            # Check if team_id exists already
            _existing_team_id = await prisma_client.get_data(
                team_id=data.team_id, table_name="team", query_type="find_unique"
            )
            if _existing_team_id is not None:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": f"Team id = {data.team_id} already exists. Please use a different team id."
                    },
                )

        if (
            user_api_key_dict.user_role is None
            or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN
        ):  # don't restrict proxy admin
            if (
                data.tpm_limit is not None
                and user_api_key_dict.tpm_limit is not None
                and data.tpm_limit > user_api_key_dict.tpm_limit
            ):
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": f"tpm limit higher than user max. User tpm limit={user_api_key_dict.tpm_limit}. User role={user_api_key_dict.user_role}"
                    },
                )

            if (
                data.rpm_limit is not None
                and user_api_key_dict.rpm_limit is not None
                and data.rpm_limit > user_api_key_dict.rpm_limit
            ):
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": f"rpm limit higher than user max. User rpm limit={user_api_key_dict.rpm_limit}. User role={user_api_key_dict.user_role}"
                    },
                )

            if (
                data.max_budget is not None
                and user_api_key_dict.max_budget is not None
                and data.max_budget > user_api_key_dict.max_budget
            ):
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": f"max budget higher than user max. User max budget={user_api_key_dict.max_budget}. User role={user_api_key_dict.user_role}"
                    },
                )

            if data.models is not None and len(user_api_key_dict.models) > 0:
                for m in data.models:
                    if m not in user_api_key_dict.models:
                        raise HTTPException(
                            status_code=400,
                            detail={
                                "error": f"Model not in allowed user models. User allowed models={user_api_key_dict.models}. User id={user_api_key_dict.user_id}"
                            },
                        )

        if user_api_key_dict.user_id is not None:
            creating_user_in_list = False
            for member in data.members_with_roles:
                if member.user_id == user_api_key_dict.user_id:
                    creating_user_in_list = True

            if creating_user_in_list is False:
                data.members_with_roles.append(
                    Member(role="admin", user_id=user_api_key_dict.user_id)
                )

        ## ADD TO MODEL TABLE
        _model_id = None
        if data.model_aliases is not None and isinstance(data.model_aliases, dict):
            litellm_modeltable = LiteLLM_ModelTable(
                model_aliases=json.dumps(data.model_aliases),
                created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
                updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
            )
            model_dict = await prisma_client.db.litellm_modeltable.create(
                {**litellm_modeltable.json(exclude_none=True)}  # type: ignore
            )  # type: ignore

            _model_id = model_dict.id

        ## ADD TO TEAM TABLE
        complete_team_data = LiteLLM_TeamTable(
            **data.json(),
            model_id=_model_id,
        )

        # Set Management Endpoint Metadata Fields
        for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
            if getattr(data, field) is not None:
                _set_team_metadata_field(
                    team_data=complete_team_data,
                    field_name=field,
                    value=getattr(data, field),
                )

        # If budget_duration is set, set `budget_reset_at`
        if complete_team_data.budget_duration is not None:
            duration_s = duration_in_seconds(
                duration=complete_team_data.budget_duration
            )
            reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
            complete_team_data.budget_reset_at = reset_at

        complete_team_data_dict = complete_team_data.model_dump(exclude_none=True)
        complete_team_data_dict = prisma_client.jsonify_team_object(
            db_data=complete_team_data_dict
        )
        team_row: LiteLLM_TeamTable = await prisma_client.db.litellm_teamtable.create(
            data=complete_team_data_dict,
            include={"litellm_model_table": True},  # type: ignore
        )

        ## ADD TEAM ID TO USER TABLE ##
        for user in complete_team_data.members_with_roles:
            ## add team id to user row ##
            await prisma_client.update_data(
                user_id=user.user_id,
                data={"user_id": user.user_id, "teams": [team_row.team_id]},
                update_key_values_custom_query={
                    "teams": {
                        "push ": [team_row.team_id],
                    }
                },
            )

        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        if litellm.store_audit_logs is True:
            _updated_values = complete_team_data.json(exclude_none=True)

            _updated_values = json.dumps(_updated_values, default=str)

            asyncio.create_task(
                create_audit_log_for_update(
                    request_data=LiteLLM_AuditLogs(
                        id=str(uuid.uuid4()),
                        updated_at=datetime.now(timezone.utc),
                        changed_by=litellm_changed_by
                        or user_api_key_dict.user_id
                        or litellm_proxy_admin_name,
                        changed_by_api_key=user_api_key_dict.api_key,
                        table_name=LitellmTableNames.TEAM_TABLE_NAME,
                        object_id=data.team_id,
                        action="created",
                        updated_values=_updated_values,
                        before_value=None,
                    )
                )
            )

        try:
            return team_row.model_dump()
        except Exception:
            return team_row.dict()
    except Exception as e:
        raise handle_exception_on_proxy(e)


async def _update_model_table(
Changed file 3 of 4 (proxy guardrails config, YAML; filename not preserved in this capture):

@@ -24,5 +24,3 @@ guardrails:
       mode: "during_call"
       guardrailIdentifier: gf3sc1mzinjw
       guardrailVersion: "DRAFT"
-      default_on: true
-
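With default_on: true removed from this config entry, the guardrail presumably no longer runs on every request by default and has to be selected per key, per team, or per request. A hedged per-request sketch follows (move_guardrails_to_metadata, in the first changed file, moves the top-level guardrails field into request metadata); the URL, admin key, guardrail, and model names are assumptions carried over from the test setup.

```python
import requests

resp = requests.post(
    "http://0.0.0.0:4000/chat/completions",
    headers={"Authorization": "Bearer sk-1234", "Content-Type": "application/json"},
    json={
        "model": "fake-openai-endpoint",
        "messages": [{"role": "user", "content": "Hello my name is ishaan@berri.ai"}],
        # Request-level guardrails: the proxy moves this top-level field into
        # request metadata before routing the call (see file 1 of this diff).
        "guardrails": ["bedrock-pre-guard"],
    },
)
print(resp.headers.get("x-litellm-applied-guardrails"))
```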
Changed file 4 of 4 (guardrails e2e tests; filename not preserved in this capture):

@@ -3,6 +3,7 @@ import asyncio
 import aiohttp, openai
 from openai import OpenAI, AsyncOpenAI
 from typing import Optional, List, Union
+import json
 import uuid
 
 
@@ -40,21 +41,22 @@ async def chat_completion(
         raise Exception(response_text)
 
     # response headers
-    response_headers = response.headers
+    response_headers = dict(response.headers)
     print("response headers=", response_headers)
 
     return await response.json(), response_headers
 
 
-async def generate_key(session, guardrails):
+async def generate_key(
+    session, guardrails: Optional[List] = None, team_id: Optional[str] = None
+):
     url = "http://0.0.0.0:4000/key/generate"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
+    data = {}
     if guardrails:
-        data = {
-            "guardrails": guardrails,
-        }
-    else:
-        data = {}
+        data["guardrails"] = guardrails
+    if team_id:
+        data["team_id"] = team_id
 
     async with session.post(url, headers=headers, json=data) as response:
         status = response.status
@@ -148,7 +150,6 @@ async def test_no_llm_guard_triggered():
 
 
 @pytest.mark.asyncio
-@pytest.mark.skip(reason="Aporia account disabled")
 async def test_guardrails_with_api_key_controls():
     """
     - Make two API Keys
@@ -161,8 +162,7 @@ async def test_guardrails_with_api_key_controls():
         key_with_guardrails = await generate_key(
             session=session,
             guardrails=[
-                "aporia-post-guard",
-                "aporia-pre-guard",
+                "bedrock-pre-guard",
             ],
         )
 
@@ -185,19 +185,15 @@ async def test_guardrails_with_api_key_controls():
         assert "x-litellm-applied-guardrails" not in headers
 
         # test guardrails triggered for key with guardrails
-        try:
-            response, headers = await chat_completion(
-                session,
-                key_with_guardrails,
-                model="fake-openai-endpoint",
-                messages=[
-                    {"role": "user", "content": f"Hello my name is ishaan@berri.ai"}
-                ],
-            )
-            pytest.fail("Should have thrown an exception")
-        except Exception as e:
-            print(e)
-            assert "Aporia detected and blocked PII" in str(e)
+        response, headers = await chat_completion(
+            session,
+            key_with_guardrails,
+            model="fake-openai-endpoint",
+            messages=[{"role": "user", "content": f"Hello my name is ishaan@berri.ai"}],
+        )
+
+        assert "x-litellm-applied-guardrails" in headers
+        assert headers["x-litellm-applied-guardrails"] == "bedrock-pre-guard"
 
 
 @pytest.mark.asyncio
@@ -241,3 +237,82 @@ async def test_custom_guardrail_during_call_triggered():
     except Exception as e:
         print(e)
         assert "Guardrail failed words - `litellm` detected" in str(e)
+
+
+async def create_team(session, guardrails: Optional[List] = None):
+    url = "http://0.0.0.0:4000/team/new"
+    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
+    data = {"guardrails": guardrails}
+
+    print("request data=", data)
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request did not return a 200 status code: {status}")
+
+        return await response.json()
+
+
+@pytest.mark.asyncio
+async def test_guardrails_with_team_controls():
+    """
+    - Create a team with guardrails
+    - Make two API Keys
+        - Key 1 not associated with team
+        - Key 2 associated with team (inherits team guardrails)
+    - Request with Key 1 -> should be success with no guardrails
+    - Request with Key 2 -> should error since team guardrails are triggered
+    """
+    async with aiohttp.ClientSession() as session:
+
+        # Create team with guardrails
+        team = await create_team(
+            session=session,
+            guardrails=[
+                "bedrock-pre-guard",
+            ],
+        )
+
+        print("team=", team)
+
+        team_id = team["team_id"]
+
+        # Create key with team association
+        key_with_team = await generate_key(session=session, team_id=team_id)
+        key_with_team = key_with_team["key"]
+
+        # Create key without team
+        key_without_team = await generate_key(
+            session=session,
+        )
+        key_without_team = key_without_team["key"]
+
+        # Test no guardrails triggered for key without a team
+        response, headers = await chat_completion(
+            session,
+            key_without_team,
+            model="fake-openai-endpoint",
+            messages=[{"role": "user", "content": "Hello my name is ishaan@berri.ai"}],
+        )
+        await asyncio.sleep(3)
+
+        print("response=", response, "response headers", headers)
+        assert "x-litellm-applied-guardrails" not in headers
+
+        response, headers = await chat_completion(
+            session,
+            key_with_team,
+            model="fake-openai-endpoint",
+            messages=[{"role": "user", "content": "Hello my name is ishaan@berri.ai"}],
+        )
+
+        print("response headers=", json.dumps(headers, indent=4))
+
+        assert "x-litellm-applied-guardrails" in headers
+        assert headers["x-litellm-applied-guardrails"] == "bedrock-pre-guard"