feat - guardrails v2

Ishaan Jaff 2024-08-19 18:24:20 -07:00
parent 7721b9b176
commit 8cd1963c11
9 changed files with 211 additions and 49 deletions

docs/my-website/docs/proxy/guardrails.md

@@ -1,18 +1,10 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

-# 🛡️ Guardrails
+# 🛡️ [Beta] Guardrails

Set up Prompt Injection Detection and Secret Detection on LiteLLM Proxy

:::info

✨ Enterprise Only Feature

Schedule a meeting with us to get an Enterprise License 👉 Talk to founders [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)

:::

## Quick Start

### 1. Set up guardrails on litellm proxy config.yaml
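A sketch of the v2 config this commit introduces (mirroring the proxy_config.yaml change further down; the guardrail name and env vars are illustrative):

```yaml
guardrails:
  - guardrail_name: "aporia-pre-guard"   # any name; callers reference it per request
    litellm_params:
      guardrail: aporia                  # supported values: "aporia", "bedrock", "lakera"
      mode: "during_call"                # pre_call | during_call | post_call
      api_key: os.environ/APORIA_API_KEY_1
      api_base: os.environ/APORIA_API_BASE_1
```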

enterprise/enterprise_hooks/aporio_ai.py

@@ -15,7 +15,7 @@ from typing import Optional, Literal, Union, Any
import litellm, traceback, sys, uuid
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
-from litellm.integrations.custom_logger import CustomLogger
+from litellm.integrations.custom_guardrail import CustomGuardrail
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
@@ -29,19 +29,25 @@ from litellm._logging import verbose_proxy_logger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
import httpx
import json
+from litellm.types.guardrails import GuardrailEventHooks

litellm.set_verbose = True

GUARDRAIL_NAME = "aporio"

-class _ENTERPRISE_Aporio(CustomLogger):
-    def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
+class _ENTERPRISE_Aporio(CustomGuardrail):
+    def __init__(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
+    ):
        self.async_handler = AsyncHTTPHandler(
            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
        )
        self.aporio_api_key = api_key or os.environ["APORIO_API_KEY"]
        self.aporio_api_base = api_base or os.environ["APORIO_API_BASE"]
+        self.event_hook: GuardrailEventHooks
+        super().__init__(**kwargs)

    #### CALL HOOKS - proxy only ####

    def transform_messages(self, messages: List[dict]) -> List[dict]:
@@ -140,10 +146,15 @@ class _ENTERPRISE_Aporio(CustomLogger):
        from litellm.proxy.common_utils.callback_utils import (
            add_guardrail_to_applied_guardrails_header,
        )
+        from litellm.types.guardrails import GuardrailEventHooks

        """
        Use this for the post call moderation with Guardrails
        """
+        event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
+        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
+            return

        response_str: Optional[str] = convert_litellm_response_object_to_str(response)
        if response_str is not None:
            await self.make_aporia_api_request(
@@ -151,7 +162,7 @@ class _ENTERPRISE_Aporio(CustomLogger):
            )
            add_guardrail_to_applied_guardrails_header(
-                request_data=data, guardrail_name=f"post_call_{GUARDRAIL_NAME}"
+                request_data=data, guardrail_name=self.guardrail_name
            )

        pass
@@ -165,7 +176,13 @@ class _ENTERPRISE_Aporio(CustomLogger):
        from litellm.proxy.common_utils.callback_utils import (
            add_guardrail_to_applied_guardrails_header,
        )
+        from litellm.types.guardrails import GuardrailEventHooks

+        event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
+        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
+            return

+        # old implementation - backwards compatibility
        if (
            await should_proceed_based_on_metadata(
                data=data,
@@ -182,7 +199,7 @@ class _ENTERPRISE_Aporio(CustomLogger):
        if new_messages is not None:
            await self.make_aporia_api_request(new_messages=new_messages)
            add_guardrail_to_applied_guardrails_header(
-                request_data=data, guardrail_name=f"during_call_{GUARDRAIL_NAME}"
+                request_data=data, guardrail_name=self.guardrail_name
            )
        else:
            verbose_proxy_logger.warning(

litellm/integrations/custom_guardrail.py

@@ -0,0 +1,32 @@
+from typing import Literal
+
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.types.guardrails import GuardrailEventHooks
+
+
+class CustomGuardrail(CustomLogger):
+
+    def __init__(self, guardrail_name: str, event_hook: GuardrailEventHooks, **kwargs):
+        self.guardrail_name = guardrail_name
+        self.event_hook: GuardrailEventHooks = event_hook
+        super().__init__(**kwargs)
+
+    def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
+        verbose_logger.debug(
+            "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s",
+            self.guardrail_name,
+            event_type,
+            self.event_hook,
+        )
+
+        metadata = data.get("metadata") or {}
+        requested_guardrails = metadata.get("guardrails") or []
+
+        if self.guardrail_name not in requested_guardrails:
+            return False
+
+        if self.event_hook != event_type:
+            return False
+
+        return True
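A minimal usage sketch (assuming LiteLLM is installed; the guardrail name is illustrative). `should_run_guardrail` only returns `True` when the caller listed this guardrail in request metadata and the current event matches the hook it was registered for:

```python
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.types.guardrails import GuardrailEventHooks

guard = CustomGuardrail(
    guardrail_name="aporia-pre-guard",  # illustrative name
    event_hook=GuardrailEventHooks.during_call,
)

data = {"metadata": {"guardrails": ["aporia-pre-guard"]}}

# requested by the caller + matching event -> True
assert guard.should_run_guardrail(data, GuardrailEventHooks.during_call)

# wrong event, or guardrail never requested -> False
assert not guard.should_run_guardrail(data, GuardrailEventHooks.post_call)
assert not guard.should_run_guardrail({"metadata": {}}, GuardrailEventHooks.during_call)
```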

litellm/proxy/guardrails/guardrail_helpers.py

@@ -37,32 +37,35 @@ async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> bool:
            requested_callback_names = []

-            # get guardrail configs from `init_guardrails.py`
-            # for all requested guardrails -> get their associated callbacks
-            for _guardrail_name, should_run in request_guardrails.items():
-                if should_run is False:
-                    verbose_proxy_logger.debug(
-                        "Guardrail %s skipped because request set to False",
-                        _guardrail_name,
-                    )
-                    continue
+            # v1 implementation of this
+            if isinstance(request_guardrails, dict):

-                # lookup the guardrail in guardrail_name_config_map
-                guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[
-                    _guardrail_name
-                ]
+                # get guardrail configs from `init_guardrails.py`
+                # for all requested guardrails -> get their associated callbacks
+                for _guardrail_name, should_run in request_guardrails.items():
+                    if should_run is False:
+                        verbose_proxy_logger.debug(
+                            "Guardrail %s skipped because request set to False",
+                            _guardrail_name,
+                        )
+                        continue

-                guardrail_callbacks = guardrail_item.callbacks
-                requested_callback_names.extend(guardrail_callbacks)
+                    # lookup the guardrail in guardrail_name_config_map
+                    guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[
+                        _guardrail_name
+                    ]

-            verbose_proxy_logger.debug(
-                "requested_callback_names %s", requested_callback_names
-            )
-            if guardrail_name in requested_callback_names:
-                return True
+                    guardrail_callbacks = guardrail_item.callbacks
+                    requested_callback_names.extend(guardrail_callbacks)

-            # Do not proceed if - "metadata": { "guardrails": { "lakera_prompt_injection": false } }
-            return False
+                verbose_proxy_logger.debug(
+                    "requested_callback_names %s", requested_callback_names
+                )
+                if guardrail_name in requested_callback_names:
+                    return True

+                # Do not proceed if - "metadata": { "guardrails": { "lakera_prompt_injection": false } }
+                return False

    return True
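For context, the `isinstance` guard above exists because v1 and v2 pass different shapes in `metadata["guardrails"]`. A sketch of the two shapes (names illustrative):

```python
# v1: dict of {guardrail_name: enabled} -> handled by the loop above
v1_metadata = {"guardrails": {"prompt_injection": True, "hide_secrets": False}}

# v2: plain list of guardrail names -> skips the isinstance block, so
# should_proceed_based_on_metadata falls through to `return True` and the
# new CustomGuardrail.should_run_guardrail does the per-event gating instead
v2_metadata = {"guardrails": ["aporia-pre-guard", "aporia-post-guard"]}
```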

litellm/proxy/guardrails/init_guardrails.py

@@ -1,12 +1,20 @@
import traceback
-from typing import Dict, List
+from typing import Dict, List, Literal

from pydantic import BaseModel, RootModel

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
-from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
+
+# v2 implementation
+from litellm.types.guardrails import (
+    Guardrail,
+    GuardrailItem,
+    GuardrailItemSpec,
+    LitellmParams,
+    guardrailConfig,
+)

all_guardrails: List[GuardrailItem] = []
@@ -66,3 +74,70 @@ def initialize_guardrails(
            "error initializing guardrails {}".format(str(e))
        )
        raise e
+
+
+"""
+Map guardrail_name: <pre_call>, <post_call>, during_call
+"""
+
+
+def init_guardrails_v2(all_guardrails: dict):
+    # Convert the loaded data to the TypedDict structure
+    guardrail_list = []
+
+    # Parse each guardrail and replace environment variables
+    for guardrail in all_guardrails:
+
+        # Init litellm params for guardrail
+        litellm_params_data = guardrail["litellm_params"]
+        verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)
+        litellm_params = LitellmParams(
+            guardrail=litellm_params_data["guardrail"],
+            mode=litellm_params_data["mode"],
+            api_key=litellm_params_data["api_key"],
+            api_base=litellm_params_data["api_base"],
+        )
+
+        if litellm_params["api_key"]:
+            if litellm_params["api_key"].startswith("os.environ/"):
+                litellm_params["api_key"] = litellm.get_secret(
+                    litellm_params["api_key"]
+                )
+
+        if litellm_params["api_base"]:
+            if litellm_params["api_base"].startswith("os.environ/"):
+                litellm_params["api_base"] = litellm.get_secret(
+                    litellm_params["api_base"]
+                )
+
+        # Init guardrail CustomLoggerClass
+        if litellm_params["guardrail"] == "aporia":
+            from litellm.proxy.enterprise.enterprise_hooks.aporio_ai import (
+                _ENTERPRISE_Aporio,
+            )
+
+            _aporia_callback = _ENTERPRISE_Aporio(
+                api_base=litellm_params["api_base"],
+                api_key=litellm_params["api_key"],
+                guardrail_name=guardrail["guardrail_name"],
+                event_hook=litellm_params["mode"],
+            )
+            litellm.callbacks.append(_aporia_callback)  # type: ignore
+        elif litellm_params["guardrail"] == "lakera":
+            from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
+                _ENTERPRISE_lakeraAI_Moderation,
+            )
+
+            _lakera_callback = _ENTERPRISE_lakeraAI_Moderation()
+            litellm.callbacks.append(_lakera_callback)  # type: ignore
+
+        parsed_guardrail = Guardrail(
+            guardrail_name=guardrail["guardrail_name"], litellm_params=litellm_params
+        )
+        guardrail_list.append(parsed_guardrail)
+        guardrail_name = guardrail["guardrail_name"]
+
+    # pretty print guardrail_list in green
+    print(f"\nGuardrail List:{guardrail_list}\n")  # noqa
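A sketch of the input `init_guardrails_v2` expects: the parsed `guardrails:` section of the proxy config (despite the `dict` annotation, it is iterated as a list of guardrail entries; values here are illustrative):

```python
from litellm.proxy.guardrails.init_guardrails import init_guardrails_v2

init_guardrails_v2(
    all_guardrails=[
        {
            "guardrail_name": "aporia-pre-guard",  # illustrative
            "litellm_params": {
                "guardrail": "aporia",
                "mode": "during_call",
                "api_key": "os.environ/APORIA_API_KEY_1",
                "api_base": "os.environ/APORIA_API_BASE_1",
            },
        }
    ]
)
# registers an _ENTERPRISE_Aporio callback on litellm.callbacks and prints
# the parsed Guardrail list
```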

litellm/proxy/litellm_pre_call_utils.py

@@ -308,9 +308,20 @@ async def add_litellm_data_to_request(
        for k, v in callback_settings_obj.callback_vars.items():
            data[k] = v

+    # Guardrails
+    move_guardrails_to_metadata(
+        data=data, _metadata_variable_name=_metadata_variable_name
+    )
+
    return data


+def move_guardrails_to_metadata(data: dict, _metadata_variable_name: str):
+    if "guardrails" in data:
+        data[_metadata_variable_name]["guardrails"] = data["guardrails"]
+        del data["guardrails"]
+
+
def add_provider_specific_headers_to_request(
    data: dict,
    headers: dict,
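A before/after sketch of `move_guardrails_to_metadata` (assuming the endpoint's metadata field is literally named `"metadata"`):

```python
data = {
    "model": "gpt-3.5-turbo",
    "guardrails": ["aporia-pre-guard"],  # top-level field sent by the caller
    "metadata": {},
}
move_guardrails_to_metadata(data=data, _metadata_variable_name="metadata")

# data is now:
# {"model": "gpt-3.5-turbo", "metadata": {"guardrails": ["aporia-pre-guard"]}}
```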

litellm/proxy/proxy_config.yaml

@@ -5,14 +5,15 @@ model_list:
      api_key: os.environ/OPENAI_API_KEY

guardrails:
-  - guardrail_name: prompt_injection_detection
+  - guardrail_name: "aporia-pre-guard"
    litellm_params:
-      guardrail_name: openai/gpt-3.5-turbo
-      api_key: os.environ/OPENAI_API_KEY
-      api_base: os.environ/OPENAI_API_BASE
-  - guardrail_name: prompt_injection_detection
+      guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
+      mode: "post_call"
+      api_key: os.environ/APORIA_API_KEY_1
+      api_base: os.environ/APORIA_API_BASE_1
+  - guardrail_name: "aporia-post-guard"
    litellm_params:
-      guardrail_name: openai/gpt-3.5-turbo
-      api_key: os.environ/OPENAI_API_KEY
-      api_base: os.environ/OPENAI_API_BASE
+      guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
+      mode: "post_call"
+      api_key: os.environ/APORIA_API_KEY_2
+      api_base: os.environ/APORIA_API_BASE_2
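With this config, callers opt into guardrails per request via a top-level `guardrails` field, which `move_guardrails_to_metadata` (above) moves into request metadata. A sketch using the OpenAI SDK pointed at the proxy (URL and key illustrative):

```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    extra_body={"guardrails": ["aporia-post-guard"]},  # names from the config above
)
```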

litellm/proxy/proxy_server.py

@@ -169,7 +169,10 @@ from litellm.proxy.common_utils.openai_endpoint_utils import (
)
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
-from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
+from litellm.proxy.guardrails.init_guardrails import (
+    init_guardrails_v2,
+    initialize_guardrails,
+)
from litellm.proxy.health_check import perform_health_check
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
from litellm.proxy.hooks.prompt_injection_detection import (
@@ -1939,6 +1942,11 @@ class ProxyConfig:
                    async_only_mode=True  # only init async clients
                ),
            )  # type:ignore

+        # Guardrail settings
+        guardrails_v2 = config.get("guardrails", None)
+        if guardrails_v2:
+            init_guardrails_v2(all_guardrails=guardrails_v2)

        return router, router.get_model_list(), general_settings

    def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:

litellm/types/guardrails.py

@@ -1,5 +1,5 @@
from enum import Enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, TypedDict

from pydantic import BaseModel, ConfigDict
from typing_extensions import Required, TypedDict

@@ -63,3 +63,26 @@ class GuardrailItem(BaseModel):
            enabled_roles=enabled_roles,
            callback_args=callback_args,
        )
+
+
+# Define the TypedDicts
+class LitellmParams(TypedDict):
+    guardrail: str
+    mode: str
+    api_key: str
+    api_base: Optional[str]
+
+
+class Guardrail(TypedDict):
+    guardrail_name: str
+    litellm_params: LitellmParams
+
+
+class guardrailConfig(TypedDict):
+    guardrails: List[Guardrail]
+
+
+class GuardrailEventHooks(str, Enum):
+    pre_call = "pre_call"
+    post_call = "post_call"
+    during_call = "during_call"
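A small sketch using the new types (TypedDicts are plain dicts at runtime; values illustrative):

```python
from litellm.types.guardrails import Guardrail, GuardrailEventHooks, LitellmParams

params = LitellmParams(
    guardrail="aporia",
    mode=GuardrailEventHooks.during_call.value,  # "during_call"
    api_key="os.environ/APORIA_API_KEY_1",
    api_base="os.environ/APORIA_API_BASE_1",
)
guardrail = Guardrail(guardrail_name="aporia-pre-guard", litellm_params=params)

# equivalent at runtime to:
# {"guardrail_name": "aporia-pre-guard", "litellm_params": {...}}
```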