working lakera ai during call hook

This commit is contained in:
Ishaan Jaff 2024-08-20 14:39:04 -07:00
parent 1a142053e5
commit 1fdebfb0b7
5 changed files with 28 additions and 22 deletions

View file

@ -1,4 +1,4 @@
from typing import Literal from typing import Literal, Optional
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
@ -7,9 +7,14 @@ from litellm.types.guardrails import GuardrailEventHooks
class CustomGuardrail(CustomLogger): class CustomGuardrail(CustomLogger):
def __init__(self, guardrail_name: str, event_hook: GuardrailEventHooks, **kwargs): def __init__(
self,
guardrail_name: Optional[str] = None,
event_hook: Optional[GuardrailEventHooks] = None,
**kwargs
):
self.guardrail_name = guardrail_name self.guardrail_name = guardrail_name
self.event_hook: GuardrailEventHooks = event_hook self.event_hook: Optional[GuardrailEventHooks] = event_hook
super().__init__(**kwargs) super().__init__(**kwargs)
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool: def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:

View file

@ -101,12 +101,8 @@ def initialize_callbacks_on_proxy(
openai_moderations_object = _ENTERPRISE_OpenAI_Moderation() openai_moderations_object = _ENTERPRISE_OpenAI_Moderation()
imported_list.append(openai_moderations_object) imported_list.append(openai_moderations_object)
elif isinstance(callback, str) and callback == "lakera_prompt_injection": elif isinstance(callback, str) and callback == "lakera_prompt_injection":
from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import (
lakeraAI_Moderation,
if premium_user != True:
raise Exception(
"Trying to use LakeraAI Prompt Injection"
+ CommonProxyErrors.not_premium_user.value
) )
init_params = {} init_params = {}
@ -119,12 +115,6 @@ def initialize_callbacks_on_proxy(
AporiaGuardrail, AporiaGuardrail,
) )
if premium_user is not True:
raise Exception(
"Trying to use Aporia AI Guardrail"
+ CommonProxyErrors.not_premium_user.value
)
aporia_guardrail_object = AporiaGuardrail() aporia_guardrail_object = AporiaGuardrail()
imported_list.append(aporia_guardrail_object) imported_list.append(aporia_guardrail_object)
elif isinstance(callback, str) and callback == "google_text_moderation": elif isinstance(callback, str) and callback == "google_text_moderation":
@ -305,7 +295,11 @@ def get_applied_guardrails_header(request_data: Dict) -> Optional[Dict]:
return None return None
def add_guardrail_to_applied_guardrails_header(request_data: Dict, guardrail_name: str): def add_guardrail_to_applied_guardrails_header(
request_data: Dict, guardrail_name: Optional[str]
):
if guardrail_name is None:
return
_metadata = request_data.get("metadata", None) or {} _metadata = request_data.get("metadata", None) or {}
if "applied_guardrails" in _metadata: if "applied_guardrails" in _metadata:
_metadata["applied_guardrails"].append(guardrail_name) _metadata["applied_guardrails"].append(guardrail_name)

View file

@ -301,6 +301,7 @@ class lakeraAI_Moderation(CustomGuardrail):
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"], call_type: Literal["completion", "embeddings", "image_generation"],
): ):
if self.event_hook is None:
if self.moderation_check == "pre_call": if self.moderation_check == "pre_call":
return return

View file

@ -129,11 +129,17 @@ def init_guardrails_v2(all_guardrails: dict):
lakeraAI_Moderation, lakeraAI_Moderation,
) )
_lakera_callback = lakeraAI_Moderation() _lakera_callback = lakeraAI_Moderation(
api_base=litellm_params["api_base"],
api_key=litellm_params["api_key"],
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
)
litellm.callbacks.append(_lakera_callback) # type: ignore litellm.callbacks.append(_lakera_callback) # type: ignore
parsed_guardrail = Guardrail( parsed_guardrail = Guardrail(
guardrail_name=guardrail["guardrail_name"], litellm_params=litellm_params guardrail_name=guardrail["guardrail_name"],
litellm_params=litellm_params,
) )
guardrail_list.append(parsed_guardrail) guardrail_list.append(parsed_guardrail)

View file

@ -15,7 +15,7 @@ guardrails:
- guardrail_name: "lakera-pre-guard" - guardrail_name: "lakera-pre-guard"
litellm_params: litellm_params:
guardrail: lakera # supported values: "aporia", "bedrock", "lakera" guardrail: lakera # supported values: "aporia", "bedrock", "lakera"
mode: "pre_call" mode: "during_call"
api_key: os.environ/LAKERA_API_KEY api_key: os.environ/LAKERA_API_KEY
api_base: os.environ/LAKERA_API_BASE api_base: os.environ/LAKERA_API_BASE