mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
init custom guardrail class
This commit is contained in:
parent
7e064f2dcd
commit
dcd39dac00
2 changed files with 26 additions and 3 deletions
|
@ -84,7 +84,10 @@ Map guardrail_name: <pre_call>, <post_call>, during_call
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def init_guardrails_v2(all_guardrails: dict):
|
def init_guardrails_v2(
|
||||||
|
all_guardrails: dict,
|
||||||
|
config_file_path: str,
|
||||||
|
):
|
||||||
# Convert the loaded data to the TypedDict structure
|
# Convert the loaded data to the TypedDict structure
|
||||||
guardrail_list = []
|
guardrail_list = []
|
||||||
|
|
||||||
|
@ -166,6 +169,10 @@ def init_guardrails_v2(all_guardrails: dict):
|
||||||
isinstance(litellm_params["guardrail"], str)
|
isinstance(litellm_params["guardrail"], str)
|
||||||
and "." in litellm_params["guardrail"]
|
and "." in litellm_params["guardrail"]
|
||||||
):
|
):
|
||||||
|
import os
|
||||||
|
|
||||||
|
from litellm.proxy.utils import get_instance_fn
|
||||||
|
|
||||||
# Custom guardrail
|
# Custom guardrail
|
||||||
_guardrail = litellm_params["guardrail"]
|
_guardrail = litellm_params["guardrail"]
|
||||||
_file_name, _class_name = _guardrail.split(".")
|
_file_name, _class_name = _guardrail.split(".")
|
||||||
|
@ -175,7 +182,21 @@ def init_guardrails_v2(all_guardrails: dict):
|
||||||
_file_name,
|
_file_name,
|
||||||
_class_name,
|
_class_name,
|
||||||
)
|
)
|
||||||
_guardrail_class = getattr(importlib.import_module(_file_name), _class_name)
|
|
||||||
|
directory = os.path.dirname(config_file_path)
|
||||||
|
module_file_path = os.path.join(directory, _file_name)
|
||||||
|
module_file_path += ".py"
|
||||||
|
|
||||||
|
spec = importlib.util.spec_from_file_location(_class_name, module_file_path) # type: ignore
|
||||||
|
if spec is None:
|
||||||
|
raise ImportError(
|
||||||
|
f"Could not find a module specification for {module_file_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
module = importlib.util.module_from_spec(spec) # type: ignore
|
||||||
|
spec.loader.exec_module(module) # type: ignore
|
||||||
|
_guardrail_class = getattr(module, _class_name)
|
||||||
|
|
||||||
_guardrail_callback = _guardrail_class(
|
_guardrail_callback = _guardrail_class(
|
||||||
guardrail_name=guardrail["guardrail_name"],
|
guardrail_name=guardrail["guardrail_name"],
|
||||||
event_hook=litellm_params["mode"],
|
event_hook=litellm_params["mode"],
|
||||||
|
|
|
@ -1959,7 +1959,9 @@ class ProxyConfig:
|
||||||
# Guardrail settings
|
# Guardrail settings
|
||||||
guardrails_v2 = config.get("guardrails", None)
|
guardrails_v2 = config.get("guardrails", None)
|
||||||
if guardrails_v2:
|
if guardrails_v2:
|
||||||
init_guardrails_v2(all_guardrails=guardrails_v2)
|
init_guardrails_v2(
|
||||||
|
all_guardrails=guardrails_v2, config_file_path=config_file_path
|
||||||
|
)
|
||||||
return router, router.get_model_list(), general_settings
|
return router, router.get_model_list(), general_settings
|
||||||
|
|
||||||
def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
|
def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue