mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
working lakera ai during call hook
This commit is contained in:
parent
9d809e8404
commit
cdbd245c3d
5 changed files with 28 additions and 22 deletions
|
@ -1,4 +1,4 @@
|
||||||
from typing import Literal
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
@ -7,9 +7,14 @@ from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
class CustomGuardrail(CustomLogger):
|
class CustomGuardrail(CustomLogger):
|
||||||
|
|
||||||
def __init__(self, guardrail_name: str, event_hook: GuardrailEventHooks, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
guardrail_name: Optional[str] = None,
|
||||||
|
event_hook: Optional[GuardrailEventHooks] = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
self.guardrail_name = guardrail_name
|
self.guardrail_name = guardrail_name
|
||||||
self.event_hook: GuardrailEventHooks = event_hook
|
self.event_hook: Optional[GuardrailEventHooks] = event_hook
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
|
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
|
||||||
|
|
|
@ -101,13 +101,9 @@ def initialize_callbacks_on_proxy(
|
||||||
openai_moderations_object = _ENTERPRISE_OpenAI_Moderation()
|
openai_moderations_object = _ENTERPRISE_OpenAI_Moderation()
|
||||||
imported_list.append(openai_moderations_object)
|
imported_list.append(openai_moderations_object)
|
||||||
elif isinstance(callback, str) and callback == "lakera_prompt_injection":
|
elif isinstance(callback, str) and callback == "lakera_prompt_injection":
|
||||||
from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation
|
from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import (
|
||||||
|
lakeraAI_Moderation,
|
||||||
if premium_user != True:
|
)
|
||||||
raise Exception(
|
|
||||||
"Trying to use LakeraAI Prompt Injection"
|
|
||||||
+ CommonProxyErrors.not_premium_user.value
|
|
||||||
)
|
|
||||||
|
|
||||||
init_params = {}
|
init_params = {}
|
||||||
if "lakera_prompt_injection" in callback_specific_params:
|
if "lakera_prompt_injection" in callback_specific_params:
|
||||||
|
@ -119,12 +115,6 @@ def initialize_callbacks_on_proxy(
|
||||||
AporiaGuardrail,
|
AporiaGuardrail,
|
||||||
)
|
)
|
||||||
|
|
||||||
if premium_user is not True:
|
|
||||||
raise Exception(
|
|
||||||
"Trying to use Aporia AI Guardrail"
|
|
||||||
+ CommonProxyErrors.not_premium_user.value
|
|
||||||
)
|
|
||||||
|
|
||||||
aporia_guardrail_object = AporiaGuardrail()
|
aporia_guardrail_object = AporiaGuardrail()
|
||||||
imported_list.append(aporia_guardrail_object)
|
imported_list.append(aporia_guardrail_object)
|
||||||
elif isinstance(callback, str) and callback == "google_text_moderation":
|
elif isinstance(callback, str) and callback == "google_text_moderation":
|
||||||
|
@ -305,7 +295,11 @@ def get_applied_guardrails_header(request_data: Dict) -> Optional[Dict]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def add_guardrail_to_applied_guardrails_header(request_data: Dict, guardrail_name: str):
|
def add_guardrail_to_applied_guardrails_header(
|
||||||
|
request_data: Dict, guardrail_name: Optional[str]
|
||||||
|
):
|
||||||
|
if guardrail_name is None:
|
||||||
|
return
|
||||||
_metadata = request_data.get("metadata", None) or {}
|
_metadata = request_data.get("metadata", None) or {}
|
||||||
if "applied_guardrails" in _metadata:
|
if "applied_guardrails" in _metadata:
|
||||||
_metadata["applied_guardrails"].append(guardrail_name)
|
_metadata["applied_guardrails"].append(guardrail_name)
|
||||||
|
|
|
@ -301,8 +301,9 @@ class lakeraAI_Moderation(CustomGuardrail):
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
call_type: Literal["completion", "embeddings", "image_generation"],
|
call_type: Literal["completion", "embeddings", "image_generation"],
|
||||||
):
|
):
|
||||||
if self.moderation_check == "pre_call":
|
if self.event_hook is None:
|
||||||
return
|
if self.moderation_check == "pre_call":
|
||||||
|
return
|
||||||
|
|
||||||
from litellm.types.guardrails import GuardrailEventHooks
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
|
|
|
@ -129,11 +129,17 @@ def init_guardrails_v2(all_guardrails: dict):
|
||||||
lakeraAI_Moderation,
|
lakeraAI_Moderation,
|
||||||
)
|
)
|
||||||
|
|
||||||
_lakera_callback = lakeraAI_Moderation()
|
_lakera_callback = lakeraAI_Moderation(
|
||||||
|
api_base=litellm_params["api_base"],
|
||||||
|
api_key=litellm_params["api_key"],
|
||||||
|
guardrail_name=guardrail["guardrail_name"],
|
||||||
|
event_hook=litellm_params["mode"],
|
||||||
|
)
|
||||||
litellm.callbacks.append(_lakera_callback) # type: ignore
|
litellm.callbacks.append(_lakera_callback) # type: ignore
|
||||||
|
|
||||||
parsed_guardrail = Guardrail(
|
parsed_guardrail = Guardrail(
|
||||||
guardrail_name=guardrail["guardrail_name"], litellm_params=litellm_params
|
guardrail_name=guardrail["guardrail_name"],
|
||||||
|
litellm_params=litellm_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
guardrail_list.append(parsed_guardrail)
|
guardrail_list.append(parsed_guardrail)
|
||||||
|
|
|
@ -15,7 +15,7 @@ guardrails:
|
||||||
- guardrail_name: "lakera-pre-guard"
|
- guardrail_name: "lakera-pre-guard"
|
||||||
litellm_params:
|
litellm_params:
|
||||||
guardrail: lakera # supported values: "aporia", "bedrock", "lakera"
|
guardrail: lakera # supported values: "aporia", "bedrock", "lakera"
|
||||||
mode: "pre_call"
|
mode: "during_call"
|
||||||
api_key: os.environ/LAKERA_API_KEY
|
api_key: os.environ/LAKERA_API_KEY
|
||||||
api_base: os.environ/LAKERA_API_BASE
|
api_base: os.environ/LAKERA_API_BASE
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue