feat - guardrails v2

This commit is contained in:
Ishaan Jaff 2024-08-19 18:24:20 -07:00
parent 7721b9b176
commit 8cd1963c11
9 changed files with 211 additions and 49 deletions

View file

@ -1,12 +1,20 @@
import traceback
from typing import Dict, List
from typing import Dict, List, Literal
from pydantic import BaseModel, RootModel
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
# v2 implementation
from litellm.types.guardrails import (
Guardrail,
GuardrailItem,
GuardrailItemSpec,
LitellmParams,
guardrailConfig,
)
all_guardrails: List[GuardrailItem] = []
@ -66,3 +74,70 @@ def initialize_guardrails(
"error initializing guardrails {}".format(str(e))
)
raise e
"""
Map guardrail_name: <pre_call>, <post_call>, during_call
"""
def init_guardrails_v2(all_guardrails: dict):
# Convert the loaded data to the TypedDict structure
guardrail_list = []
# Parse each guardrail and replace environment variables
for guardrail in all_guardrails:
# Init litellm params for guardrail
litellm_params_data = guardrail["litellm_params"]
verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)
litellm_params = LitellmParams(
guardrail=litellm_params_data["guardrail"],
mode=litellm_params_data["mode"],
api_key=litellm_params_data["api_key"],
api_base=litellm_params_data["api_base"],
)
if litellm_params["api_key"]:
if litellm_params["api_key"].startswith("os.environ/"):
litellm_params["api_key"] = litellm.get_secret(
litellm_params["api_key"]
)
if litellm_params["api_base"]:
if litellm_params["api_base"].startswith("os.environ/"):
litellm_params["api_base"] = litellm.get_secret(
litellm_params["api_base"]
)
# Init guardrail CustomLoggerClass
if litellm_params["guardrail"] == "aporia":
from litellm.proxy.enterprise.enterprise_hooks.aporio_ai import (
_ENTERPRISE_Aporio,
)
_aporia_callback = _ENTERPRISE_Aporio(
api_base=litellm_params["api_base"],
api_key=litellm_params["api_key"],
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
)
litellm.callbacks.append(_aporia_callback) # type: ignore
elif litellm_params["guardrail"] == "lakera":
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
_ENTERPRISE_lakeraAI_Moderation,
)
_lakera_callback = _ENTERPRISE_lakeraAI_Moderation()
litellm.callbacks.append(_lakera_callback) # type: ignore
parsed_guardrail = Guardrail(
guardrail_name=guardrail["guardrail_name"], litellm_params=litellm_params
)
guardrail_list.append(parsed_guardrail)
guardrail_name = guardrail["guardrail_name"]
# pretty print guardrail_list in green
print(f"\nGuardrail List:{guardrail_list}\n") # noqa