mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
feat - guardrails v2
This commit is contained in:
parent
7721b9b176
commit
8cd1963c11
9 changed files with 211 additions and 49 deletions
|
@ -1,12 +1,20 @@
|
|||
import traceback
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Literal
|
||||
|
||||
from pydantic import BaseModel, RootModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
|
||||
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
|
||||
|
||||
# v2 implementation
|
||||
from litellm.types.guardrails import (
|
||||
Guardrail,
|
||||
GuardrailItem,
|
||||
GuardrailItemSpec,
|
||||
LitellmParams,
|
||||
guardrailConfig,
|
||||
)
|
||||
|
||||
all_guardrails: List[GuardrailItem] = []
|
||||
|
||||
|
@ -66,3 +74,70 @@ def initialize_guardrails(
|
|||
"error initializing guardrails {}".format(str(e))
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
"""
|
||||
Map guardrail_name: <pre_call>, <post_call>, during_call
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def init_guardrails_v2(all_guardrails: dict):
|
||||
# Convert the loaded data to the TypedDict structure
|
||||
guardrail_list = []
|
||||
|
||||
# Parse each guardrail and replace environment variables
|
||||
for guardrail in all_guardrails:
|
||||
|
||||
# Init litellm params for guardrail
|
||||
litellm_params_data = guardrail["litellm_params"]
|
||||
verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)
|
||||
litellm_params = LitellmParams(
|
||||
guardrail=litellm_params_data["guardrail"],
|
||||
mode=litellm_params_data["mode"],
|
||||
api_key=litellm_params_data["api_key"],
|
||||
api_base=litellm_params_data["api_base"],
|
||||
)
|
||||
|
||||
if litellm_params["api_key"]:
|
||||
if litellm_params["api_key"].startswith("os.environ/"):
|
||||
litellm_params["api_key"] = litellm.get_secret(
|
||||
litellm_params["api_key"]
|
||||
)
|
||||
|
||||
if litellm_params["api_base"]:
|
||||
if litellm_params["api_base"].startswith("os.environ/"):
|
||||
litellm_params["api_base"] = litellm.get_secret(
|
||||
litellm_params["api_base"]
|
||||
)
|
||||
|
||||
# Init guardrail CustomLoggerClass
|
||||
if litellm_params["guardrail"] == "aporia":
|
||||
from litellm.proxy.enterprise.enterprise_hooks.aporio_ai import (
|
||||
_ENTERPRISE_Aporio,
|
||||
)
|
||||
|
||||
_aporia_callback = _ENTERPRISE_Aporio(
|
||||
api_base=litellm_params["api_base"],
|
||||
api_key=litellm_params["api_key"],
|
||||
guardrail_name=guardrail["guardrail_name"],
|
||||
event_hook=litellm_params["mode"],
|
||||
)
|
||||
litellm.callbacks.append(_aporia_callback) # type: ignore
|
||||
elif litellm_params["guardrail"] == "lakera":
|
||||
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
|
||||
_ENTERPRISE_lakeraAI_Moderation,
|
||||
)
|
||||
|
||||
_lakera_callback = _ENTERPRISE_lakeraAI_Moderation()
|
||||
litellm.callbacks.append(_lakera_callback) # type: ignore
|
||||
|
||||
parsed_guardrail = Guardrail(
|
||||
guardrail_name=guardrail["guardrail_name"], litellm_params=litellm_params
|
||||
)
|
||||
|
||||
guardrail_list.append(parsed_guardrail)
|
||||
guardrail_name = guardrail["guardrail_name"]
|
||||
|
||||
# pretty print guardrail_list in green
|
||||
print(f"\nGuardrail List:{guardrail_list}\n") # noqa
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue