forked from phoenix/litellm-mirror
feat - guardrails v2
This commit is contained in:
parent
7721b9b176
commit
8cd1963c11
9 changed files with 211 additions and 49 deletions
|
@ -1,18 +1,10 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# 🛡️ Guardrails
|
||||
# 🛡️ [Beta] Guardrails
|
||||
|
||||
Set up Prompt Injection Detection and Secret Detection on LiteLLM Proxy
|
||||
|
||||
:::info
|
||||
|
||||
✨ Enterprise Only Feature
|
||||
|
||||
Schedule a meeting with us to get an Enterprise License 👉 Talk to founders [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
||||
|
||||
:::
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Setup guardrails on litellm proxy config.yaml
|
||||
|
|
|
@ -15,7 +15,7 @@ from typing import Optional, Literal, Union, Any
|
|||
import litellm, traceback, sys, uuid
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from fastapi import HTTPException
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||
|
@ -29,19 +29,25 @@ from litellm._logging import verbose_proxy_logger
|
|||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
import httpx
|
||||
import json
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
GUARDRAIL_NAME = "aporio"
|
||||
|
||||
|
||||
class _ENTERPRISE_Aporio(CustomLogger):
|
||||
def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
|
||||
class _ENTERPRISE_Aporio(CustomGuardrail):
|
||||
def __init__(
|
||||
self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
|
||||
):
|
||||
self.async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
)
|
||||
self.aporio_api_key = api_key or os.environ["APORIO_API_KEY"]
|
||||
self.aporio_api_base = api_base or os.environ["APORIO_API_BASE"]
|
||||
self.event_hook: GuardrailEventHooks
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
#### CALL HOOKS - proxy only ####
|
||||
def transform_messages(self, messages: List[dict]) -> List[dict]:
|
||||
|
@ -140,10 +146,15 @@ class _ENTERPRISE_Aporio(CustomLogger):
|
|||
from litellm.proxy.common_utils.callback_utils import (
|
||||
add_guardrail_to_applied_guardrails_header,
|
||||
)
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
"""
|
||||
Use this for the post call moderation with Guardrails
|
||||
"""
|
||||
event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
|
||||
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||
return
|
||||
|
||||
response_str: Optional[str] = convert_litellm_response_object_to_str(response)
|
||||
if response_str is not None:
|
||||
await self.make_aporia_api_request(
|
||||
|
@ -151,7 +162,7 @@ class _ENTERPRISE_Aporio(CustomLogger):
|
|||
)
|
||||
|
||||
add_guardrail_to_applied_guardrails_header(
|
||||
request_data=data, guardrail_name=f"post_call_{GUARDRAIL_NAME}"
|
||||
request_data=data, guardrail_name=self.guardrail_name
|
||||
)
|
||||
|
||||
pass
|
||||
|
@ -165,7 +176,13 @@ class _ENTERPRISE_Aporio(CustomLogger):
|
|||
from litellm.proxy.common_utils.callback_utils import (
|
||||
add_guardrail_to_applied_guardrails_header,
|
||||
)
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
|
||||
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||
return
|
||||
|
||||
# old implementation - backwards compatibility
|
||||
if (
|
||||
await should_proceed_based_on_metadata(
|
||||
data=data,
|
||||
|
@ -182,7 +199,7 @@ class _ENTERPRISE_Aporio(CustomLogger):
|
|||
if new_messages is not None:
|
||||
await self.make_aporia_api_request(new_messages=new_messages)
|
||||
add_guardrail_to_applied_guardrails_header(
|
||||
request_data=data, guardrail_name=f"during_call_{GUARDRAIL_NAME}"
|
||||
request_data=data, guardrail_name=self.guardrail_name
|
||||
)
|
||||
else:
|
||||
verbose_proxy_logger.warning(
|
||||
|
|
32
litellm/integrations/custom_guardrail.py
Normal file
32
litellm/integrations/custom_guardrail.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
from typing import Literal
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
|
||||
class CustomGuardrail(CustomLogger):
    """
    Logger-based callback that represents one named guardrail bound to one event hook.

    Subclasses call `should_run_guardrail` from their hook implementations to
    decide whether the guardrail applies to the current request.
    """

    def __init__(self, guardrail_name: str, event_hook: GuardrailEventHooks, **kwargs):
        """
        Store the guardrail's identity and forward the remaining kwargs to `CustomLogger`.

        Args:
            guardrail_name: name requests use to opt in to this guardrail.
            event_hook: the single event hook this guardrail instance runs on.
            **kwargs: passed through unchanged to `CustomLogger.__init__`.
        """
        self.guardrail_name = guardrail_name
        self.event_hook: GuardrailEventHooks = event_hook
        super().__init__(**kwargs)

    def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
        """
        Return True only when this guardrail was requested and `event_type` is its hook.

        Args:
            data: request payload; `data["metadata"]["guardrails"]` holds the
                guardrails requested for this call (missing/empty means none).
            event_type: the hook currently being executed.
        """
        verbose_logger.debug(
            "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s",
            self.guardrail_name,
            event_type,
            self.event_hook,
        )

        request_metadata = data.get("metadata") or {}
        guardrails_for_request = request_metadata.get("guardrails") or []

        # The name check comes first so the hook comparison is skipped for
        # guardrails the request never asked for.
        not_requested = self.guardrail_name not in guardrails_for_request
        if not_requested or self.event_hook != event_type:
            return False
        return True
|
|
@ -37,32 +37,35 @@ async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> b
|
|||
|
||||
requested_callback_names = []
|
||||
|
||||
# get guardrail configs from `init_guardrails.py`
|
||||
# for all requested guardrails -> get their associated callbacks
|
||||
for _guardrail_name, should_run in request_guardrails.items():
|
||||
if should_run is False:
|
||||
verbose_proxy_logger.debug(
|
||||
"Guardrail %s skipped because request set to False",
|
||||
_guardrail_name,
|
||||
)
|
||||
continue
|
||||
# v1 implementation of this
|
||||
if isinstance(request_guardrails, dict):
|
||||
|
||||
# lookup the guardrail in guardrail_name_config_map
|
||||
guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[
|
||||
_guardrail_name
|
||||
]
|
||||
# get guardrail configs from `init_guardrails.py`
|
||||
# for all requested guardrails -> get their associated callbacks
|
||||
for _guardrail_name, should_run in request_guardrails.items():
|
||||
if should_run is False:
|
||||
verbose_proxy_logger.debug(
|
||||
"Guardrail %s skipped because request set to False",
|
||||
_guardrail_name,
|
||||
)
|
||||
continue
|
||||
|
||||
guardrail_callbacks = guardrail_item.callbacks
|
||||
requested_callback_names.extend(guardrail_callbacks)
|
||||
# lookup the guardrail in guardrail_name_config_map
|
||||
guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[
|
||||
_guardrail_name
|
||||
]
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"requested_callback_names %s", requested_callback_names
|
||||
)
|
||||
if guardrail_name in requested_callback_names:
|
||||
return True
|
||||
guardrail_callbacks = guardrail_item.callbacks
|
||||
requested_callback_names.extend(guardrail_callbacks)
|
||||
|
||||
# Do not proceed if - "metadata": { "guardrails": { "lakera_prompt_injection": false } }
|
||||
return False
|
||||
verbose_proxy_logger.debug(
|
||||
"requested_callback_names %s", requested_callback_names
|
||||
)
|
||||
if guardrail_name in requested_callback_names:
|
||||
return True
|
||||
|
||||
# Do not proceed if - "metadata": { "guardrails": { "lakera_prompt_injection": false } }
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
|
|
@ -1,12 +1,20 @@
|
|||
import traceback
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Literal
|
||||
|
||||
from pydantic import BaseModel, RootModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
|
||||
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
|
||||
|
||||
# v2 implementation
|
||||
from litellm.types.guardrails import (
|
||||
Guardrail,
|
||||
GuardrailItem,
|
||||
GuardrailItemSpec,
|
||||
LitellmParams,
|
||||
guardrailConfig,
|
||||
)
|
||||
|
||||
all_guardrails: List[GuardrailItem] = []
|
||||
|
||||
|
@ -66,3 +74,70 @@ def initialize_guardrails(
|
|||
"error initializing guardrails {}".format(str(e))
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
"""
|
||||
Map guardrail_name: <pre_call>, <post_call>, during_call
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def _resolve_guardrail_secret(value):
    """Resolve an `os.environ/...` reference via `litellm.get_secret`; return any other value unchanged."""
    if isinstance(value, str) and value.startswith("os.environ/"):
        return litellm.get_secret(value)
    return value


def init_guardrails_v2(all_guardrails: List[Dict]):
    """
    Initialize the v2 guardrails listed under `guardrails` in the proxy config.

    For each entry this builds its `LitellmParams` (resolving `os.environ/`
    references in `api_key` / `api_base`), instantiates the matching guardrail
    callback and appends it to `litellm.callbacks`, then records the parsed
    `Guardrail`. The parsed list is printed once all entries are processed.

    Args:
        all_guardrails: guardrail config entries. Each entry needs a
            `guardrail_name` and a `litellm_params` mapping with `guardrail`
            and `mode`; `api_key` and `api_base` are optional.

    Raises:
        KeyError: if an entry lacks `guardrail_name`, `litellm_params`,
            `guardrail` or `mode`.
    """
    guardrail_list = []

    for guardrail in all_guardrails:
        # Init litellm params for guardrail
        litellm_params_data = guardrail["litellm_params"]
        verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)

        # `api_key` / `api_base` are read with .get(): the lakera branch uses
        # neither, and the Aporia callback accepts None for both and falls
        # back to its own environment variables.
        litellm_params = LitellmParams(
            guardrail=litellm_params_data["guardrail"],
            mode=litellm_params_data["mode"],
            api_key=_resolve_guardrail_secret(litellm_params_data.get("api_key")),
            api_base=_resolve_guardrail_secret(litellm_params_data.get("api_base")),
        )

        # Init guardrail CustomLoggerClass
        if litellm_params["guardrail"] == "aporia":
            from litellm.proxy.enterprise.enterprise_hooks.aporio_ai import (
                _ENTERPRISE_Aporio,
            )

            _aporia_callback = _ENTERPRISE_Aporio(
                api_base=litellm_params["api_base"],
                api_key=litellm_params["api_key"],
                guardrail_name=guardrail["guardrail_name"],
                event_hook=litellm_params["mode"],
            )
            litellm.callbacks.append(_aporia_callback)  # type: ignore
        elif litellm_params["guardrail"] == "lakera":
            from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
                _ENTERPRISE_lakeraAI_Moderation,
            )

            _lakera_callback = _ENTERPRISE_lakeraAI_Moderation()
            litellm.callbacks.append(_lakera_callback)  # type: ignore
        # NOTE(review): any other `guardrail` value is still recorded below
        # but gets no callback registered — confirm this is intended.

        parsed_guardrail = Guardrail(
            guardrail_name=guardrail["guardrail_name"], litellm_params=litellm_params
        )
        guardrail_list.append(parsed_guardrail)

    # Print the parsed guardrail list so it is visible at proxy startup
    print(f"\nGuardrail List:{guardrail_list}\n")  # noqa
|
||||
|
|
|
@ -308,9 +308,20 @@ async def add_litellm_data_to_request(
|
|||
for k, v in callback_settings_obj.callback_vars.items():
|
||||
data[k] = v
|
||||
|
||||
# Guardrails
|
||||
move_guardrails_to_metadata(
|
||||
data=data, _metadata_variable_name=_metadata_variable_name
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def move_guardrails_to_metadata(data: dict, _metadata_variable_name: str):
    """
    Move a top-level `guardrails` entry of the request into its metadata dict.

    Mutates `data` in place: `data["guardrails"]` is removed and stored as
    `data[_metadata_variable_name]["guardrails"]`. Nothing happens when the
    request has no `guardrails` key.

    Args:
        data: request payload.
        _metadata_variable_name: key in `data` under which the metadata dict lives.
    """
    if "guardrails" not in data:
        return

    # Create the metadata container on demand so a request that carries
    # `guardrails` but no (or a null) metadata dict does not raise.
    metadata = data.get(_metadata_variable_name)
    if metadata is None:
        metadata = {}
        data[_metadata_variable_name] = metadata

    metadata["guardrails"] = data.pop("guardrails")
|
||||
|
||||
|
||||
def add_provider_specific_headers_to_request(
|
||||
data: dict,
|
||||
headers: dict,
|
||||
|
|
|
@ -5,14 +5,15 @@ model_list:
|
|||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
guardrails:
|
||||
- guardrail_name: prompt_injection_detection
|
||||
- guardrail_name: "aporia-pre-guard"
|
||||
litellm_params:
|
||||
guardrail_name: openai/gpt-3.5-turbo
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
api_base: os.environ/OPENAI_API_BASE
|
||||
- guardrail_name: prompt_injection_detection
|
||||
guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
|
||||
mode: "post_call"
|
||||
api_key: os.environ/APORIA_API_KEY_1
|
||||
api_base: os.environ/APORIA_API_BASE_1
|
||||
- guardrail_name: "aporia-post-guard"
|
||||
litellm_params:
|
||||
guardrail_name: openai/gpt-3.5-turbo
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
api_base: os.environ/OPENAI_API_BASE
|
||||
|
||||
guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
|
||||
mode: "post_call"
|
||||
api_key: os.environ/APORIA_API_KEY_2
|
||||
api_base: os.environ/APORIA_API_BASE_2
|
|
@ -169,7 +169,10 @@ from litellm.proxy.common_utils.openai_endpoint_utils import (
|
|||
)
|
||||
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
|
||||
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
|
||||
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
|
||||
from litellm.proxy.guardrails.init_guardrails import (
|
||||
init_guardrails_v2,
|
||||
initialize_guardrails,
|
||||
)
|
||||
from litellm.proxy.health_check import perform_health_check
|
||||
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
|
||||
from litellm.proxy.hooks.prompt_injection_detection import (
|
||||
|
@ -1939,6 +1942,11 @@ class ProxyConfig:
|
|||
async_only_mode=True # only init async clients
|
||||
),
|
||||
) # type:ignore
|
||||
|
||||
# Guardrail settings
|
||||
guardrails_v2 = config.get("guardrails", None)
|
||||
if guardrails_v2:
|
||||
init_guardrails_v2(all_guardrails=guardrails_v2)
|
||||
return router, router.get_model_list(), general_settings
|
||||
|
||||
def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, TypedDict
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
@ -63,3 +63,26 @@ class GuardrailItem(BaseModel):
|
|||
enabled_roles=enabled_roles,
|
||||
callback_args=callback_args,
|
||||
)
|
||||
|
||||
|
||||
# Define the TypedDicts
|
||||
class LitellmParams(TypedDict):
    """Settings taken from the `litellm_params` section of one guardrail config entry."""

    # Guardrail provider identifier; the guardrail initializer branches on
    # the values "aporia" and "lakera".
    guardrail: str
    # Event hook the guardrail runs on; passed to the guardrail callback as
    # its `event_hook`.
    mode: str
    # Provider API key; may be given as an `os.environ/...` reference.
    api_key: str
    # Provider API base URL; may be given as an `os.environ/...` reference.
    api_base: Optional[str]
|
||||
|
||||
|
||||
class Guardrail(TypedDict):
    """One parsed guardrail config entry: its name plus its `litellm_params`."""

    # Name identifying this guardrail entry.
    guardrail_name: str
    # Provider, mode and credentials for this guardrail.
    litellm_params: LitellmParams
|
||||
|
||||
|
||||
class guardrailConfig(TypedDict):
    """Top-level shape of a guardrails config: a list of `Guardrail` entries."""

    guardrails: List[Guardrail]
|
||||
|
||||
|
||||
class GuardrailEventHooks(str, Enum):
    """
    Event hooks a guardrail can be bound to.

    Members are also `str` instances, so a member compares equal to its plain
    string value.
    """

    pre_call = "pre_call"
    post_call = "post_call"
    during_call = "during_call"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue