mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
feat - guardrails v2
This commit is contained in:
parent
7721b9b176
commit
8cd1963c11
9 changed files with 211 additions and 49 deletions
|
@ -1,18 +1,10 @@
|
||||||
import Tabs from '@theme/Tabs';
|
import Tabs from '@theme/Tabs';
|
||||||
import TabItem from '@theme/TabItem';
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
# 🛡️ Guardrails
|
# 🛡️ [Beta] Guardrails
|
||||||
|
|
||||||
Setup Prompt Injection Detection, Secret Detection on LiteLLM Proxy
|
Setup Prompt Injection Detection, Secret Detection on LiteLLM Proxy
|
||||||
|
|
||||||
:::info
|
|
||||||
|
|
||||||
✨ Enterprise Only Feature
|
|
||||||
|
|
||||||
Schedule a meeting with us to get an Enterprise License 👉 Talk to founders [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
|
||||||
|
|
||||||
:::
|
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
### 1. Setup guardrails on litellm proxy config.yaml
|
### 1. Setup guardrails on litellm proxy config.yaml
|
||||||
|
|
|
@ -15,7 +15,7 @@ from typing import Optional, Literal, Union, Any
|
||||||
import litellm, traceback, sys, uuid
|
import litellm, traceback, sys, uuid
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||||
|
@ -29,19 +29,25 @@ from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
import httpx
|
import httpx
|
||||||
import json
|
import json
|
||||||
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
|
||||||
GUARDRAIL_NAME = "aporio"
|
GUARDRAIL_NAME = "aporio"
|
||||||
|
|
||||||
|
|
||||||
class _ENTERPRISE_Aporio(CustomLogger):
|
class _ENTERPRISE_Aporio(CustomGuardrail):
|
||||||
def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
|
def __init__(
|
||||||
|
self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
|
||||||
|
):
|
||||||
self.async_handler = AsyncHTTPHandler(
|
self.async_handler = AsyncHTTPHandler(
|
||||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||||
)
|
)
|
||||||
self.aporio_api_key = api_key or os.environ["APORIO_API_KEY"]
|
self.aporio_api_key = api_key or os.environ["APORIO_API_KEY"]
|
||||||
self.aporio_api_base = api_base or os.environ["APORIO_API_BASE"]
|
self.aporio_api_base = api_base or os.environ["APORIO_API_BASE"]
|
||||||
|
self.event_hook: GuardrailEventHooks
|
||||||
|
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
#### CALL HOOKS - proxy only ####
|
#### CALL HOOKS - proxy only ####
|
||||||
def transform_messages(self, messages: List[dict]) -> List[dict]:
|
def transform_messages(self, messages: List[dict]) -> List[dict]:
|
||||||
|
@ -140,10 +146,15 @@ class _ENTERPRISE_Aporio(CustomLogger):
|
||||||
from litellm.proxy.common_utils.callback_utils import (
|
from litellm.proxy.common_utils.callback_utils import (
|
||||||
add_guardrail_to_applied_guardrails_header,
|
add_guardrail_to_applied_guardrails_header,
|
||||||
)
|
)
|
||||||
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Use this for the post call moderation with Guardrails
|
Use this for the post call moderation with Guardrails
|
||||||
"""
|
"""
|
||||||
|
event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
|
||||||
|
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||||
|
return
|
||||||
|
|
||||||
response_str: Optional[str] = convert_litellm_response_object_to_str(response)
|
response_str: Optional[str] = convert_litellm_response_object_to_str(response)
|
||||||
if response_str is not None:
|
if response_str is not None:
|
||||||
await self.make_aporia_api_request(
|
await self.make_aporia_api_request(
|
||||||
|
@ -151,7 +162,7 @@ class _ENTERPRISE_Aporio(CustomLogger):
|
||||||
)
|
)
|
||||||
|
|
||||||
add_guardrail_to_applied_guardrails_header(
|
add_guardrail_to_applied_guardrails_header(
|
||||||
request_data=data, guardrail_name=f"post_call_{GUARDRAIL_NAME}"
|
request_data=data, guardrail_name=self.guardrail_name
|
||||||
)
|
)
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
@ -165,7 +176,13 @@ class _ENTERPRISE_Aporio(CustomLogger):
|
||||||
from litellm.proxy.common_utils.callback_utils import (
|
from litellm.proxy.common_utils.callback_utils import (
|
||||||
add_guardrail_to_applied_guardrails_header,
|
add_guardrail_to_applied_guardrails_header,
|
||||||
)
|
)
|
||||||
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
|
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
|
||||||
|
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||||
|
return
|
||||||
|
|
||||||
|
# old implementation - backwards compatibility
|
||||||
if (
|
if (
|
||||||
await should_proceed_based_on_metadata(
|
await should_proceed_based_on_metadata(
|
||||||
data=data,
|
data=data,
|
||||||
|
@ -182,7 +199,7 @@ class _ENTERPRISE_Aporio(CustomLogger):
|
||||||
if new_messages is not None:
|
if new_messages is not None:
|
||||||
await self.make_aporia_api_request(new_messages=new_messages)
|
await self.make_aporia_api_request(new_messages=new_messages)
|
||||||
add_guardrail_to_applied_guardrails_header(
|
add_guardrail_to_applied_guardrails_header(
|
||||||
request_data=data, guardrail_name=f"during_call_{GUARDRAIL_NAME}"
|
request_data=data, guardrail_name=self.guardrail_name
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
verbose_proxy_logger.warning(
|
verbose_proxy_logger.warning(
|
||||||
|
|
32
litellm/integrations/custom_guardrail.py
Normal file
32
litellm/integrations/custom_guardrail.py
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
|
|
||||||
|
class CustomGuardrail(CustomLogger):
|
||||||
|
|
||||||
|
def __init__(self, guardrail_name: str, event_hook: GuardrailEventHooks, **kwargs):
|
||||||
|
self.guardrail_name = guardrail_name
|
||||||
|
self.event_hook: GuardrailEventHooks = event_hook
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
|
||||||
|
verbose_logger.debug(
|
||||||
|
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s",
|
||||||
|
self.guardrail_name,
|
||||||
|
event_type,
|
||||||
|
self.event_hook,
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = data.get("metadata") or {}
|
||||||
|
requested_guardrails = metadata.get("guardrails") or []
|
||||||
|
|
||||||
|
if self.guardrail_name not in requested_guardrails:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.event_hook != event_type:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
|
@ -37,32 +37,35 @@ async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> b
|
||||||
|
|
||||||
requested_callback_names = []
|
requested_callback_names = []
|
||||||
|
|
||||||
# get guardrail configs from `init_guardrails.py`
|
# v1 implementation of this
|
||||||
# for all requested guardrails -> get their associated callbacks
|
if isinstance(request_guardrails, dict):
|
||||||
for _guardrail_name, should_run in request_guardrails.items():
|
|
||||||
if should_run is False:
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
"Guardrail %s skipped because request set to False",
|
|
||||||
_guardrail_name,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# lookup the guardrail in guardrail_name_config_map
|
# get guardrail configs from `init_guardrails.py`
|
||||||
guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[
|
# for all requested guardrails -> get their associated callbacks
|
||||||
_guardrail_name
|
for _guardrail_name, should_run in request_guardrails.items():
|
||||||
]
|
if should_run is False:
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
"Guardrail %s skipped because request set to False",
|
||||||
|
_guardrail_name,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
guardrail_callbacks = guardrail_item.callbacks
|
# lookup the guardrail in guardrail_name_config_map
|
||||||
requested_callback_names.extend(guardrail_callbacks)
|
guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[
|
||||||
|
_guardrail_name
|
||||||
|
]
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
guardrail_callbacks = guardrail_item.callbacks
|
||||||
"requested_callback_names %s", requested_callback_names
|
requested_callback_names.extend(guardrail_callbacks)
|
||||||
)
|
|
||||||
if guardrail_name in requested_callback_names:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Do no proceeed if - "metadata": { "guardrails": { "lakera_prompt_injection": false } }
|
verbose_proxy_logger.debug(
|
||||||
return False
|
"requested_callback_names %s", requested_callback_names
|
||||||
|
)
|
||||||
|
if guardrail_name in requested_callback_names:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Do no proceeed if - "metadata": { "guardrails": { "lakera_prompt_injection": false } }
|
||||||
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,20 @@
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, RootModel
|
from pydantic import BaseModel, RootModel
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
|
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
|
||||||
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
|
|
||||||
|
# v2 implementation
|
||||||
|
from litellm.types.guardrails import (
|
||||||
|
Guardrail,
|
||||||
|
GuardrailItem,
|
||||||
|
GuardrailItemSpec,
|
||||||
|
LitellmParams,
|
||||||
|
guardrailConfig,
|
||||||
|
)
|
||||||
|
|
||||||
all_guardrails: List[GuardrailItem] = []
|
all_guardrails: List[GuardrailItem] = []
|
||||||
|
|
||||||
|
@ -66,3 +74,70 @@ def initialize_guardrails(
|
||||||
"error initializing guardrails {}".format(str(e))
|
"error initializing guardrails {}".format(str(e))
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Map guardrail_name: <pre_call>, <post_call>, during_call
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def init_guardrails_v2(all_guardrails: dict):
|
||||||
|
# Convert the loaded data to the TypedDict structure
|
||||||
|
guardrail_list = []
|
||||||
|
|
||||||
|
# Parse each guardrail and replace environment variables
|
||||||
|
for guardrail in all_guardrails:
|
||||||
|
|
||||||
|
# Init litellm params for guardrail
|
||||||
|
litellm_params_data = guardrail["litellm_params"]
|
||||||
|
verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)
|
||||||
|
litellm_params = LitellmParams(
|
||||||
|
guardrail=litellm_params_data["guardrail"],
|
||||||
|
mode=litellm_params_data["mode"],
|
||||||
|
api_key=litellm_params_data["api_key"],
|
||||||
|
api_base=litellm_params_data["api_base"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if litellm_params["api_key"]:
|
||||||
|
if litellm_params["api_key"].startswith("os.environ/"):
|
||||||
|
litellm_params["api_key"] = litellm.get_secret(
|
||||||
|
litellm_params["api_key"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if litellm_params["api_base"]:
|
||||||
|
if litellm_params["api_base"].startswith("os.environ/"):
|
||||||
|
litellm_params["api_base"] = litellm.get_secret(
|
||||||
|
litellm_params["api_base"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Init guardrail CustomLoggerClass
|
||||||
|
if litellm_params["guardrail"] == "aporia":
|
||||||
|
from litellm.proxy.enterprise.enterprise_hooks.aporio_ai import (
|
||||||
|
_ENTERPRISE_Aporio,
|
||||||
|
)
|
||||||
|
|
||||||
|
_aporia_callback = _ENTERPRISE_Aporio(
|
||||||
|
api_base=litellm_params["api_base"],
|
||||||
|
api_key=litellm_params["api_key"],
|
||||||
|
guardrail_name=guardrail["guardrail_name"],
|
||||||
|
event_hook=litellm_params["mode"],
|
||||||
|
)
|
||||||
|
litellm.callbacks.append(_aporia_callback) # type: ignore
|
||||||
|
elif litellm_params["guardrail"] == "lakera":
|
||||||
|
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
|
||||||
|
_ENTERPRISE_lakeraAI_Moderation,
|
||||||
|
)
|
||||||
|
|
||||||
|
_lakera_callback = _ENTERPRISE_lakeraAI_Moderation()
|
||||||
|
litellm.callbacks.append(_lakera_callback) # type: ignore
|
||||||
|
|
||||||
|
parsed_guardrail = Guardrail(
|
||||||
|
guardrail_name=guardrail["guardrail_name"], litellm_params=litellm_params
|
||||||
|
)
|
||||||
|
|
||||||
|
guardrail_list.append(parsed_guardrail)
|
||||||
|
guardrail_name = guardrail["guardrail_name"]
|
||||||
|
|
||||||
|
# pretty print guardrail_list in green
|
||||||
|
print(f"\nGuardrail List:{guardrail_list}\n") # noqa
|
||||||
|
|
|
@ -308,9 +308,20 @@ async def add_litellm_data_to_request(
|
||||||
for k, v in callback_settings_obj.callback_vars.items():
|
for k, v in callback_settings_obj.callback_vars.items():
|
||||||
data[k] = v
|
data[k] = v
|
||||||
|
|
||||||
|
# Guardrails
|
||||||
|
move_guardrails_to_metadata(
|
||||||
|
data=data, _metadata_variable_name=_metadata_variable_name
|
||||||
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def move_guardrails_to_metadata(data: dict, _metadata_variable_name: str):
|
||||||
|
if "guardrails" in data:
|
||||||
|
data[_metadata_variable_name]["guardrails"] = data["guardrails"]
|
||||||
|
del data["guardrails"]
|
||||||
|
|
||||||
|
|
||||||
def add_provider_specific_headers_to_request(
|
def add_provider_specific_headers_to_request(
|
||||||
data: dict,
|
data: dict,
|
||||||
headers: dict,
|
headers: dict,
|
||||||
|
|
|
@ -5,14 +5,15 @@ model_list:
|
||||||
api_key: os.environ/OPENAI_API_KEY
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
|
|
||||||
guardrails:
|
guardrails:
|
||||||
- guardrail_name: prompt_injection_detection
|
- guardrail_name: "aporia-pre-guard"
|
||||||
litellm_params:
|
litellm_params:
|
||||||
guardrail_name: openai/gpt-3.5-turbo
|
guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
|
||||||
api_key: os.environ/OPENAI_API_KEY
|
mode: "post_call"
|
||||||
api_base: os.environ/OPENAI_API_BASE
|
api_key: os.environ/APORIA_API_KEY_1
|
||||||
- guardrail_name: prompt_injection_detection
|
api_base: os.environ/APORIA_API_BASE_1
|
||||||
|
- guardrail_name: "aporia-post-guard"
|
||||||
litellm_params:
|
litellm_params:
|
||||||
guardrail_name: openai/gpt-3.5-turbo
|
guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
|
||||||
api_key: os.environ/OPENAI_API_KEY
|
mode: "post_call"
|
||||||
api_base: os.environ/OPENAI_API_BASE
|
api_key: os.environ/APORIA_API_KEY_2
|
||||||
|
api_base: os.environ/APORIA_API_BASE_2
|
|
@ -169,7 +169,10 @@ from litellm.proxy.common_utils.openai_endpoint_utils import (
|
||||||
)
|
)
|
||||||
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
|
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
|
||||||
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
|
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
|
||||||
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
|
from litellm.proxy.guardrails.init_guardrails import (
|
||||||
|
init_guardrails_v2,
|
||||||
|
initialize_guardrails,
|
||||||
|
)
|
||||||
from litellm.proxy.health_check import perform_health_check
|
from litellm.proxy.health_check import perform_health_check
|
||||||
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
|
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
|
||||||
from litellm.proxy.hooks.prompt_injection_detection import (
|
from litellm.proxy.hooks.prompt_injection_detection import (
|
||||||
|
@ -1939,6 +1942,11 @@ class ProxyConfig:
|
||||||
async_only_mode=True # only init async clients
|
async_only_mode=True # only init async clients
|
||||||
),
|
),
|
||||||
) # type:ignore
|
) # type:ignore
|
||||||
|
|
||||||
|
# Guardrail settings
|
||||||
|
guardrails_v2 = config.get("guardrails", None)
|
||||||
|
if guardrails_v2:
|
||||||
|
init_guardrails_v2(all_guardrails=guardrails_v2)
|
||||||
return router, router.get_model_list(), general_settings
|
return router, router.get_model_list(), general_settings
|
||||||
|
|
||||||
def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
|
def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from typing_extensions import Required, TypedDict
|
from typing_extensions import Required, TypedDict
|
||||||
|
@ -63,3 +63,26 @@ class GuardrailItem(BaseModel):
|
||||||
enabled_roles=enabled_roles,
|
enabled_roles=enabled_roles,
|
||||||
callback_args=callback_args,
|
callback_args=callback_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Define the TypedDicts
|
||||||
|
class LitellmParams(TypedDict):
|
||||||
|
guardrail: str
|
||||||
|
mode: str
|
||||||
|
api_key: str
|
||||||
|
api_base: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class Guardrail(TypedDict):
|
||||||
|
guardrail_name: str
|
||||||
|
litellm_params: LitellmParams
|
||||||
|
|
||||||
|
|
||||||
|
class guardrailConfig(TypedDict):
|
||||||
|
guardrails: List[Guardrail]
|
||||||
|
|
||||||
|
|
||||||
|
class GuardrailEventHooks(str, Enum):
|
||||||
|
pre_call = "pre_call"
|
||||||
|
post_call = "post_call"
|
||||||
|
during_call = "during_call"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue