rename lakera ai

Ishaan Jaff 2024-08-20 13:44:39 -07:00
parent 042350bd74
commit cad0352f76
5 changed files with 20 additions and 26 deletions
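
The diff below renames the enterprise Lakera AI moderation hook from _ENTERPRISE_lakeraAI_Moderation to lakeraAI_Moderation and updates every import, instantiation, and isinstance check. The commit ships no alias for the old name, so any out-of-tree code importing _ENTERPRISE_lakeraAI_Moderation breaks on upgrade. A minimal compatibility shim (a sketch only, not part of this commit) would be:

    # hypothetical shim, NOT in this commit: keep the old name importable
    # while downstream callers migrate to lakeraAI_Moderation
    _ENTERPRISE_lakeraAI_Moderation = lakeraAI_Moderation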

@@ -42,7 +42,7 @@ class LakeraCategories(TypedDict, total=False):
     prompt_injection: float
 
 
-class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
+class lakeraAI_Moderation(CustomLogger):
     def __init__(
         self,
         moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel",
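
Only the class name changes here; the constructor signature stays intact, so callers select the check mode exactly as before. A minimal usage sketch, assuming no other required constructor arguments:

    from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation

    # run the Lakera check before the LLM call rather than in parallel with it
    moderation = lakeraAI_Moderation(moderation_check="pre_call")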

@@ -101,9 +101,7 @@ def initialize_callbacks_on_proxy(
             openai_moderations_object = _ENTERPRISE_OpenAI_Moderation()
             imported_list.append(openai_moderations_object)
         elif isinstance(callback, str) and callback == "lakera_prompt_injection":
-            from enterprise.enterprise_hooks.lakera_ai import (
-                _ENTERPRISE_lakeraAI_Moderation,
-            )
+            from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation
 
             if premium_user != True:
                 raise Exception(
@@ -114,9 +112,7 @@ def initialize_callbacks_on_proxy(
             init_params = {}
             if "lakera_prompt_injection" in callback_specific_params:
                 init_params = callback_specific_params["lakera_prompt_injection"]
-            lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation(
-                **init_params
-            )
+            lakera_moderations_object = lakeraAI_Moderation(**init_params)
             imported_list.append(lakera_moderations_object)
         elif isinstance(callback, str) and callback == "aporia_prompt_injection":
             from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import (
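
The collapsed one-line import and constructor call keep the callback-specific params flowing through unchanged: whatever dict is registered under the "lakera_prompt_injection" key is unpacked into the constructor. A sketch of that flow, where the inner key is an assumption based on the constructor signature in the first hunk:

    # hypothetical params dict; only the outer "lakera_prompt_injection"
    # key is confirmed by this diff
    callback_specific_params = {
        "lakera_prompt_injection": {"moderation_check": "pre_call"}
    }
    init_params = callback_specific_params["lakera_prompt_injection"]
    # equivalent to lakeraAI_Moderation(moderation_check="pre_call")
    lakera_moderations_object = lakeraAI_Moderation(**init_params)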

@@ -126,10 +126,10 @@ def init_guardrails_v2(all_guardrails: dict):
             litellm.callbacks.append(_aporia_callback)  # type: ignore
         elif litellm_params["guardrail"] == "lakera":
             from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
-                _ENTERPRISE_lakeraAI_Moderation,
+                lakeraAI_Moderation,
             )
 
-            _lakera_callback = _ENTERPRISE_lakeraAI_Moderation()
+            _lakera_callback = lakeraAI_Moderation()
             litellm.callbacks.append(_lakera_callback)  # type: ignore
 
         parsed_guardrail = Guardrail(
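
This branch fires when a v2 guardrail entry declares "lakera" as its guardrail type. A sketch of such an entry, where every key except litellm_params["guardrail"] is an assumption about the surrounding schema:

    # hypothetical guardrail entry; only litellm_params["guardrail"] == "lakera"
    # is what the branch above actually checks
    guardrail_entry = {
        "guardrail_name": "lakera-guard",  # assumed key
        "litellm_params": {"guardrail": "lakera"},
    }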

@@ -27,9 +27,7 @@ import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.caching import DualCache
 from litellm.proxy._types import UserAPIKeyAuth
-from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
-    _ENTERPRISE_lakeraAI_Moderation,
-)
+from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation
 from litellm.proxy.proxy_server import embeddings
 from litellm.proxy.utils import ProxyLogging, hash_token
 
@@ -62,7 +60,7 @@ async def test_lakera_prompt_injection_detection():
     Tests to see OpenAI Moderation raises an error for a flagged response
     """
 
-    lakera_ai = _ENTERPRISE_lakeraAI_Moderation()
+    lakera_ai = lakeraAI_Moderation()
     _api_key = "sk-12345"
     _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
@@ -106,7 +104,7 @@ async def test_lakera_safe_prompt():
     Nothing should get raised here
     """
 
-    lakera_ai = _ENTERPRISE_lakeraAI_Moderation()
+    lakera_ai = lakeraAI_Moderation()
     _api_key = "sk-12345"
     _api_key = hash_token("sk-12345")
     user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
@@ -144,7 +142,7 @@ async def test_moderations_on_embeddings():
         setattr(litellm.proxy.proxy_server, "llm_router", temp_router)
 
         api_route = APIRoute(path="/embeddings", endpoint=embeddings)
-        litellm.callbacks = [_ENTERPRISE_lakeraAI_Moderation()]
+        litellm.callbacks = [lakeraAI_Moderation()]
         request = Request(
             {
                 "type": "http",
@@ -189,7 +187,7 @@ async def test_moderations_on_embeddings():
     ),
 )
 async def test_messages_for_disabled_role(spy_post):
-    moderation = _ENTERPRISE_lakeraAI_Moderation()
+    moderation = lakeraAI_Moderation()
     data = {
         "messages": [
             {"role": "assistant", "content": "This should be ignored."},
@@ -227,7 +225,7 @@ async def test_messages_for_disabled_role(spy_post):
 )
 @patch("litellm.add_function_to_prompt", False)
 async def test_system_message_with_function_input(spy_post):
-    moderation = _ENTERPRISE_lakeraAI_Moderation()
+    moderation = lakeraAI_Moderation()
     data = {
         "messages": [
             {"role": "system", "content": "Initial content."},
@@ -271,7 +269,7 @@ async def test_system_message_with_function_input(spy_post):
 )
 @patch("litellm.add_function_to_prompt", False)
 async def test_multi_message_with_function_input(spy_post):
-    moderation = _ENTERPRISE_lakeraAI_Moderation()
+    moderation = lakeraAI_Moderation()
     data = {
         "messages": [
             {
@@ -318,7 +316,7 @@ async def test_multi_message_with_function_input(spy_post):
     ),
 )
 async def test_message_ordering(spy_post):
-    moderation = _ENTERPRISE_lakeraAI_Moderation()
+    moderation = lakeraAI_Moderation()
     data = {
         "messages": [
             {"role": "assistant", "content": "Assistant message."},
@@ -347,7 +345,7 @@ async def test_callback_specific_param_run_pre_call_check_lakera():
     from typing import Dict, List, Optional, Union
 
     import litellm
-    from enterprise.enterprise_hooks.lakera_ai import _ENTERPRISE_lakeraAI_Moderation
+    from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation
     from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
     from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
 
@@ -374,10 +372,10 @@ async def test_callback_specific_param_run_pre_call_check_lakera():
 
     assert len(litellm.guardrail_name_config_map) == 1
 
-    prompt_injection_obj: Optional[_ENTERPRISE_lakeraAI_Moderation] = None
+    prompt_injection_obj: Optional[lakeraAI_Moderation] = None
     print("litellm callbacks={}".format(litellm.callbacks))
     for callback in litellm.callbacks:
-        if isinstance(callback, _ENTERPRISE_lakeraAI_Moderation):
+        if isinstance(callback, lakeraAI_Moderation):
             prompt_injection_obj = callback
         else:
             print("Type of callback={}".format(type(callback)))
@@ -393,7 +391,7 @@ async def test_callback_specific_thresholds():
     from typing import Dict, List, Optional, Union
 
     import litellm
-    from enterprise.enterprise_hooks.lakera_ai import _ENTERPRISE_lakeraAI_Moderation
+    from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation
     from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
     from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
 
@@ -426,10 +424,10 @@ async def test_callback_specific_thresholds():
 
     assert len(litellm.guardrail_name_config_map) == 1
 
-    prompt_injection_obj: Optional[_ENTERPRISE_lakeraAI_Moderation] = None
+    prompt_injection_obj: Optional[lakeraAI_Moderation] = None
     print("litellm callbacks={}".format(litellm.callbacks))
     for callback in litellm.callbacks:
-        if isinstance(callback, _ENTERPRISE_lakeraAI_Moderation):
+        if isinstance(callback, lakeraAI_Moderation):
             prompt_injection_obj = callback
         else:
             print("Type of callback={}".format(type(callback)))

@@ -48,7 +48,7 @@ def test_active_callbacks(client):
     _active_callbacks = json_response["litellm.callbacks"]
 
     expected_callback_names = [
-        "_ENTERPRISE_lakeraAI_Moderation",
+        "lakeraAI_Moderation",
         "_OPTIONAL_PromptInjectionDetectio",
         "_ENTERPRISE_SecretDetection",
     ]
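
The expected names here match on class names, so the rename has to land in this list as well. An illustrative way such names can be derived from live callback objects (this diff does not show the test's actual matching logic):

    # illustrative only: collect class names of registered callbacks
    active_names = [type(cb).__name__ for cb in litellm.callbacks]
    assert "lakeraAI_Moderation" in active_names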