rename lakera ai

This commit is contained in:
Ishaan Jaff 2024-08-20 13:44:39 -07:00
parent 042350bd74
commit cad0352f76
5 changed files with 20 additions and 26 deletions

View file

@ -42,7 +42,7 @@ class LakeraCategories(TypedDict, total=False):
prompt_injection: float
class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
class lakeraAI_Moderation(CustomLogger):
def __init__(
self,
moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel",

View file

@ -101,9 +101,7 @@ def initialize_callbacks_on_proxy(
openai_moderations_object = _ENTERPRISE_OpenAI_Moderation()
imported_list.append(openai_moderations_object)
elif isinstance(callback, str) and callback == "lakera_prompt_injection":
from enterprise.enterprise_hooks.lakera_ai import (
_ENTERPRISE_lakeraAI_Moderation,
)
from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation
if premium_user != True:
raise Exception(
@ -114,9 +112,7 @@ def initialize_callbacks_on_proxy(
init_params = {}
if "lakera_prompt_injection" in callback_specific_params:
init_params = callback_specific_params["lakera_prompt_injection"]
lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation(
**init_params
)
lakera_moderations_object = lakeraAI_Moderation(**init_params)
imported_list.append(lakera_moderations_object)
elif isinstance(callback, str) and callback == "aporia_prompt_injection":
from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import (

View file

@ -126,10 +126,10 @@ def init_guardrails_v2(all_guardrails: dict):
litellm.callbacks.append(_aporia_callback) # type: ignore
elif litellm_params["guardrail"] == "lakera":
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
_ENTERPRISE_lakeraAI_Moderation,
lakeraAI_Moderation,
)
_lakera_callback = _ENTERPRISE_lakeraAI_Moderation()
_lakera_callback = lakeraAI_Moderation()
litellm.callbacks.append(_lakera_callback) # type: ignore
parsed_guardrail = Guardrail(

View file

@ -27,9 +27,7 @@ import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
_ENTERPRISE_lakeraAI_Moderation,
)
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation
from litellm.proxy.proxy_server import embeddings
from litellm.proxy.utils import ProxyLogging, hash_token
@ -62,7 +60,7 @@ async def test_lakera_prompt_injection_detection():
Tests to see OpenAI Moderation raises an error for a flagged response
"""
lakera_ai = _ENTERPRISE_lakeraAI_Moderation()
lakera_ai = lakeraAI_Moderation()
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
@ -106,7 +104,7 @@ async def test_lakera_safe_prompt():
Nothing should get raised here
"""
lakera_ai = _ENTERPRISE_lakeraAI_Moderation()
lakera_ai = lakeraAI_Moderation()
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
@ -144,7 +142,7 @@ async def test_moderations_on_embeddings():
setattr(litellm.proxy.proxy_server, "llm_router", temp_router)
api_route = APIRoute(path="/embeddings", endpoint=embeddings)
litellm.callbacks = [_ENTERPRISE_lakeraAI_Moderation()]
litellm.callbacks = [lakeraAI_Moderation()]
request = Request(
{
"type": "http",
@ -189,7 +187,7 @@ async def test_moderations_on_embeddings():
),
)
async def test_messages_for_disabled_role(spy_post):
moderation = _ENTERPRISE_lakeraAI_Moderation()
moderation = lakeraAI_Moderation()
data = {
"messages": [
{"role": "assistant", "content": "This should be ignored."},
@ -227,7 +225,7 @@ async def test_messages_for_disabled_role(spy_post):
)
@patch("litellm.add_function_to_prompt", False)
async def test_system_message_with_function_input(spy_post):
moderation = _ENTERPRISE_lakeraAI_Moderation()
moderation = lakeraAI_Moderation()
data = {
"messages": [
{"role": "system", "content": "Initial content."},
@ -271,7 +269,7 @@ async def test_system_message_with_function_input(spy_post):
)
@patch("litellm.add_function_to_prompt", False)
async def test_multi_message_with_function_input(spy_post):
moderation = _ENTERPRISE_lakeraAI_Moderation()
moderation = lakeraAI_Moderation()
data = {
"messages": [
{
@ -318,7 +316,7 @@ async def test_multi_message_with_function_input(spy_post):
),
)
async def test_message_ordering(spy_post):
moderation = _ENTERPRISE_lakeraAI_Moderation()
moderation = lakeraAI_Moderation()
data = {
"messages": [
{"role": "assistant", "content": "Assistant message."},
@ -347,7 +345,7 @@ async def test_callback_specific_param_run_pre_call_check_lakera():
from typing import Dict, List, Optional, Union
import litellm
from enterprise.enterprise_hooks.lakera_ai import _ENTERPRISE_lakeraAI_Moderation
from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
@ -374,10 +372,10 @@ async def test_callback_specific_param_run_pre_call_check_lakera():
assert len(litellm.guardrail_name_config_map) == 1
prompt_injection_obj: Optional[_ENTERPRISE_lakeraAI_Moderation] = None
prompt_injection_obj: Optional[lakeraAI_Moderation] = None
print("litellm callbacks={}".format(litellm.callbacks))
for callback in litellm.callbacks:
if isinstance(callback, _ENTERPRISE_lakeraAI_Moderation):
if isinstance(callback, lakeraAI_Moderation):
prompt_injection_obj = callback
else:
print("Type of callback={}".format(type(callback)))
@ -393,7 +391,7 @@ async def test_callback_specific_thresholds():
from typing import Dict, List, Optional, Union
import litellm
from enterprise.enterprise_hooks.lakera_ai import _ENTERPRISE_lakeraAI_Moderation
from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
@ -426,10 +424,10 @@ async def test_callback_specific_thresholds():
assert len(litellm.guardrail_name_config_map) == 1
prompt_injection_obj: Optional[_ENTERPRISE_lakeraAI_Moderation] = None
prompt_injection_obj: Optional[lakeraAI_Moderation] = None
print("litellm callbacks={}".format(litellm.callbacks))
for callback in litellm.callbacks:
if isinstance(callback, _ENTERPRISE_lakeraAI_Moderation):
if isinstance(callback, lakeraAI_Moderation):
prompt_injection_obj = callback
else:
print("Type of callback={}".format(type(callback)))

View file

@ -48,7 +48,7 @@ def test_active_callbacks(client):
_active_callbacks = json_response["litellm.callbacks"]
expected_callback_names = [
"_ENTERPRISE_lakeraAI_Moderation",
"lakeraAI_Moderation",
"_OPTIONAL_PromptInjectionDetectio",
"_ENTERPRISE_SecretDetection",
]