support lakera ai category thresholds

This commit is contained in:
Ishaan Jaff 2024-08-20 17:19:24 -07:00
parent 30da63bd4f
commit 8d2c529e55
5 changed files with 80 additions and 15 deletions

View file

@ -12,6 +12,7 @@ from litellm.types.guardrails import (
Guardrail,
GuardrailItem,
GuardrailItemSpec,
LakeraCategoryThresholds,
LitellmParams,
guardrailConfig,
)
@ -99,6 +100,15 @@ def init_guardrails_v2(all_guardrails: dict):
api_base=litellm_params_data["api_base"],
)
if (
"category_thresholds" in litellm_params_data
and litellm_params_data["category_thresholds"]
):
lakera_category_thresholds = LakeraCategoryThresholds(
**litellm_params_data["category_thresholds"]
)
litellm_params["category_thresholds"] = lakera_category_thresholds
if litellm_params["api_key"]:
if litellm_params["api_key"].startswith("os.environ/"):
litellm_params["api_key"] = litellm.get_secret(
@ -134,6 +144,7 @@ def init_guardrails_v2(all_guardrails: dict):
api_key=litellm_params["api_key"],
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
category_thresholds=litellm_params.get("category_thresholds"),
)
litellm.callbacks.append(_lakera_callback) # type: ignore