feat - guardrails v2

This commit is contained in:
Ishaan Jaff 2024-08-19 18:24:20 -07:00
parent 7721b9b176
commit 8cd1963c11
9 changed files with 211 additions and 49 deletions

View file

@ -1,18 +1,10 @@
import Tabs from '@theme/Tabs'; import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem'; import TabItem from '@theme/TabItem';
# 🛡️ Guardrails # 🛡️ [Beta] Guardrails
Setup Prompt Injection Detection, Secret Detection on LiteLLM Proxy Setup Prompt Injection Detection, Secret Detection on LiteLLM Proxy
:::info
✨ Enterprise Only Feature
Schedule a meeting with us to get an Enterprise License 👉 Talk to founders [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
:::
## Quick Start ## Quick Start
### 1. Setup guardrails on litellm proxy config.yaml ### 1. Setup guardrails on litellm proxy config.yaml

View file

@ -15,7 +15,7 @@ from typing import Optional, Literal, Union, Any
import litellm, traceback, sys, uuid import litellm, traceback, sys, uuid
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_guardrail import CustomGuardrail
from fastapi import HTTPException from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
@ -29,19 +29,25 @@ from litellm._logging import verbose_proxy_logger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
import httpx import httpx
import json import json
from litellm.types.guardrails import GuardrailEventHooks
litellm.set_verbose = True litellm.set_verbose = True
GUARDRAIL_NAME = "aporio" GUARDRAIL_NAME = "aporio"
class _ENTERPRISE_Aporio(CustomLogger): class _ENTERPRISE_Aporio(CustomGuardrail):
def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None): def __init__(
self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
):
self.async_handler = AsyncHTTPHandler( self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0) timeout=httpx.Timeout(timeout=600.0, connect=5.0)
) )
self.aporio_api_key = api_key or os.environ["APORIO_API_KEY"] self.aporio_api_key = api_key or os.environ["APORIO_API_KEY"]
self.aporio_api_base = api_base or os.environ["APORIO_API_BASE"] self.aporio_api_base = api_base or os.environ["APORIO_API_BASE"]
self.event_hook: GuardrailEventHooks
super().__init__(**kwargs)
#### CALL HOOKS - proxy only #### #### CALL HOOKS - proxy only ####
def transform_messages(self, messages: List[dict]) -> List[dict]: def transform_messages(self, messages: List[dict]) -> List[dict]:
@ -140,10 +146,15 @@ class _ENTERPRISE_Aporio(CustomLogger):
from litellm.proxy.common_utils.callback_utils import ( from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header, add_guardrail_to_applied_guardrails_header,
) )
from litellm.types.guardrails import GuardrailEventHooks
""" """
Use this for the post call moderation with Guardrails Use this for the post call moderation with Guardrails
""" """
event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
return
response_str: Optional[str] = convert_litellm_response_object_to_str(response) response_str: Optional[str] = convert_litellm_response_object_to_str(response)
if response_str is not None: if response_str is not None:
await self.make_aporia_api_request( await self.make_aporia_api_request(
@ -151,7 +162,7 @@ class _ENTERPRISE_Aporio(CustomLogger):
) )
add_guardrail_to_applied_guardrails_header( add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=f"post_call_{GUARDRAIL_NAME}" request_data=data, guardrail_name=self.guardrail_name
) )
pass pass
@ -165,7 +176,13 @@ class _ENTERPRISE_Aporio(CustomLogger):
from litellm.proxy.common_utils.callback_utils import ( from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header, add_guardrail_to_applied_guardrails_header,
) )
from litellm.types.guardrails import GuardrailEventHooks
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
return
# old implementation - backwards compatibility
if ( if (
await should_proceed_based_on_metadata( await should_proceed_based_on_metadata(
data=data, data=data,
@ -182,7 +199,7 @@ class _ENTERPRISE_Aporio(CustomLogger):
if new_messages is not None: if new_messages is not None:
await self.make_aporia_api_request(new_messages=new_messages) await self.make_aporia_api_request(new_messages=new_messages)
add_guardrail_to_applied_guardrails_header( add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=f"during_call_{GUARDRAIL_NAME}" request_data=data, guardrail_name=self.guardrail_name
) )
else: else:
verbose_proxy_logger.warning( verbose_proxy_logger.warning(

View file

@ -0,0 +1,32 @@
from typing import Literal
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.guardrails import GuardrailEventHooks
class CustomGuardrail(CustomLogger):
    """Base class for guardrail callbacks.

    Couples a guardrail's configured name with the single lifecycle hook
    (pre/post/during call) it should run on, and decides per-request
    whether the guardrail applies.
    """

    def __init__(self, guardrail_name: str, event_hook: GuardrailEventHooks, **kwargs):
        # Name matched against the per-request `metadata["guardrails"]` list.
        self.guardrail_name = guardrail_name
        # The one event hook this guardrail is registered for.
        self.event_hook: GuardrailEventHooks = event_hook
        super().__init__(**kwargs)

    def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
        """Return True iff this guardrail was requested for `data` AND
        `event_type` matches the configured hook."""
        verbose_logger.debug(
            "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s",
            self.guardrail_name,
            event_type,
            self.event_hook,
        )

        requested = (data.get("metadata") or {}).get("guardrails") or []
        return self.guardrail_name in requested and self.event_hook == event_type

View file

@ -37,6 +37,9 @@ async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> b
requested_callback_names = [] requested_callback_names = []
# v1 implementation of this
if isinstance(request_guardrails, dict):
# get guardrail configs from `init_guardrails.py` # get guardrail configs from `init_guardrails.py`
# for all requested guardrails -> get their associated callbacks # for all requested guardrails -> get their associated callbacks
for _guardrail_name, should_run in request_guardrails.items(): for _guardrail_name, should_run in request_guardrails.items():

View file

@ -1,12 +1,20 @@
import traceback import traceback
from typing import Dict, List from typing import Dict, List, Literal
from pydantic import BaseModel, RootModel from pydantic import BaseModel, RootModel
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
# v2 implementation
from litellm.types.guardrails import (
Guardrail,
GuardrailItem,
GuardrailItemSpec,
LitellmParams,
guardrailConfig,
)
all_guardrails: List[GuardrailItem] = [] all_guardrails: List[GuardrailItem] = []
@ -66,3 +74,70 @@ def initialize_guardrails(
"error initializing guardrails {}".format(str(e)) "error initializing guardrails {}".format(str(e))
) )
raise e raise e
"""
Map guardrail_name: <pre_call>, <post_call>, during_call
"""
def init_guardrails_v2(all_guardrails: dict):
    """
    Initialize v2 guardrails from the proxy config's `guardrails:` section.

    For each guardrail entry: build a `LitellmParams` dict, resolve
    `os.environ/<VAR>` secret references via `litellm.get_secret`, then
    instantiate the provider-specific callback and append it to
    `litellm.callbacks`.

    Args:
        all_guardrails: list of guardrail dicts, each with
            `guardrail_name` and `litellm_params` keys.
    """
    guardrail_list = []

    for guardrail in all_guardrails:
        litellm_params_data = guardrail["litellm_params"]
        verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)

        # `api_key` / `api_base` are optional in the config — use .get()
        # so a missing key does not raise KeyError (LitellmParams declares
        # api_base as Optional).
        litellm_params = LitellmParams(
            guardrail=litellm_params_data["guardrail"],
            mode=litellm_params_data["mode"],
            api_key=litellm_params_data.get("api_key"),
            api_base=litellm_params_data.get("api_base"),
        )

        # Resolve "os.environ/<VAR>" style secret references in place.
        for _param in ("api_key", "api_base"):
            _value = litellm_params.get(_param)
            if _value and _value.startswith("os.environ/"):
                litellm_params[_param] = litellm.get_secret(_value)

        # Init guardrail CustomLoggerClass for the configured provider.
        if litellm_params["guardrail"] == "aporia":
            from litellm.proxy.enterprise.enterprise_hooks.aporio_ai import (
                _ENTERPRISE_Aporio,
            )

            _aporia_callback = _ENTERPRISE_Aporio(
                api_base=litellm_params["api_base"],
                api_key=litellm_params["api_key"],
                guardrail_name=guardrail["guardrail_name"],
                event_hook=litellm_params["mode"],
            )
            litellm.callbacks.append(_aporia_callback)  # type: ignore
        elif litellm_params["guardrail"] == "lakera":
            from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
                _ENTERPRISE_lakeraAI_Moderation,
            )

            _lakera_callback = _ENTERPRISE_lakeraAI_Moderation()
            litellm.callbacks.append(_lakera_callback)  # type: ignore
        else:
            # Surface misconfiguration instead of skipping silently.
            verbose_proxy_logger.warning(
                "Unsupported guardrail provider: %s", litellm_params["guardrail"]
            )

        parsed_guardrail = Guardrail(
            guardrail_name=guardrail["guardrail_name"], litellm_params=litellm_params
        )
        guardrail_list.append(parsed_guardrail)

    # pretty print guardrail_list
    print(f"\nGuardrail List:{guardrail_list}\n")  # noqa

View file

@ -308,9 +308,20 @@ async def add_litellm_data_to_request(
for k, v in callback_settings_obj.callback_vars.items(): for k, v in callback_settings_obj.callback_vars.items():
data[k] = v data[k] = v
# Guardrails
move_guardrails_to_metadata(
data=data, _metadata_variable_name=_metadata_variable_name
)
return data return data
def move_guardrails_to_metadata(data: dict, _metadata_variable_name: str):
    """Relocate a top-level `guardrails` entry into the request metadata.

    Mutates `data` in place: `data["guardrails"]` is moved under
    `data[_metadata_variable_name]["guardrails"]`. No-op when the key is
    absent. Assumes the metadata dict already exists in `data` whenever
    `guardrails` is present.
    """
    if "guardrails" not in data:
        return
    data[_metadata_variable_name]["guardrails"] = data.pop("guardrails")
def add_provider_specific_headers_to_request( def add_provider_specific_headers_to_request(
data: dict, data: dict,
headers: dict, headers: dict,

View file

@ -5,14 +5,15 @@ model_list:
api_key: os.environ/OPENAI_API_KEY api_key: os.environ/OPENAI_API_KEY
guardrails: guardrails:
- guardrail_name: prompt_injection_detection - guardrail_name: "aporia-pre-guard"
litellm_params: litellm_params:
guardrail_name: openai/gpt-3.5-turbo guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
api_key: os.environ/OPENAI_API_KEY mode: "post_call"
api_base: os.environ/OPENAI_API_BASE api_key: os.environ/APORIA_API_KEY_1
- guardrail_name: prompt_injection_detection api_base: os.environ/APORIA_API_BASE_1
- guardrail_name: "aporia-post-guard"
litellm_params: litellm_params:
guardrail_name: openai/gpt-3.5-turbo guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
api_key: os.environ/OPENAI_API_KEY mode: "post_call"
api_base: os.environ/OPENAI_API_BASE api_key: os.environ/APORIA_API_KEY_2
api_base: os.environ/APORIA_API_BASE_2

View file

@ -169,7 +169,10 @@ from litellm.proxy.common_utils.openai_endpoint_utils import (
) )
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails from litellm.proxy.guardrails.init_guardrails import (
init_guardrails_v2,
initialize_guardrails,
)
from litellm.proxy.health_check import perform_health_check from litellm.proxy.health_check import perform_health_check
from litellm.proxy.health_endpoints._health_endpoints import router as health_router from litellm.proxy.health_endpoints._health_endpoints import router as health_router
from litellm.proxy.hooks.prompt_injection_detection import ( from litellm.proxy.hooks.prompt_injection_detection import (
@ -1939,6 +1942,11 @@ class ProxyConfig:
async_only_mode=True # only init async clients async_only_mode=True # only init async clients
), ),
) # type:ignore ) # type:ignore
# Guardrail settings
guardrails_v2 = config.get("guardrails", None)
if guardrails_v2:
init_guardrails_v2(all_guardrails=guardrails_v2)
return router, router.get_model_list(), general_settings return router, router.get_model_list(), general_settings
def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo: def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:

View file

@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import Dict, List, Optional from typing import Dict, List, Optional, TypedDict
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from typing_extensions import Required, TypedDict from typing_extensions import Required, TypedDict
@ -63,3 +63,26 @@ class GuardrailItem(BaseModel):
enabled_roles=enabled_roles, enabled_roles=enabled_roles,
callback_args=callback_args, callback_args=callback_args,
) )
# Define the TypedDicts
class LitellmParams(TypedDict):
    # Guardrail provider identifier, e.g. "aporia" or "lakera".
    guardrail: str
    # Event hook the guardrail runs on ("pre_call" / "post_call" / "during_call").
    mode: str
    # Provider API key; may be an "os.environ/<VAR>" secret reference in config.
    api_key: str
    # Optional provider API base URL; may also be an "os.environ/<VAR>" reference.
    api_base: Optional[str]
class Guardrail(TypedDict):
    # User-facing name from the config, e.g. "aporia-pre-guard".
    guardrail_name: str
    # Provider settings (guardrail, mode, api_key, api_base).
    litellm_params: LitellmParams
# NOTE(review): class name is not PascalCase — renaming would break any
# existing importers, so it is left as-is.
class guardrailConfig(TypedDict):
    # The full `guardrails:` section of the proxy config.
    guardrails: List[Guardrail]
class GuardrailEventHooks(str, Enum):
    """Lifecycle stage at which a guardrail runs relative to the LLM call.

    Subclasses `str` so members compare equal to their plain-string
    config values (e.g. mode: "post_call").
    """

    pre_call = "pre_call"
    post_call = "post_call"
    during_call = "during_call"