forked from phoenix/litellm-mirror
rename lakera ai
This commit is contained in:
parent
042350bd74
commit
cad0352f76
5 changed files with 20 additions and 26 deletions
|
@ -42,7 +42,7 @@ class LakeraCategories(TypedDict, total=False):
|
|||
prompt_injection: float
|
||||
|
||||
|
||||
class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
|
||||
class lakeraAI_Moderation(CustomLogger):
|
||||
def __init__(
|
||||
self,
|
||||
moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel",
|
||||
|
|
|
@ -101,9 +101,7 @@ def initialize_callbacks_on_proxy(
|
|||
openai_moderations_object = _ENTERPRISE_OpenAI_Moderation()
|
||||
imported_list.append(openai_moderations_object)
|
||||
elif isinstance(callback, str) and callback == "lakera_prompt_injection":
|
||||
from enterprise.enterprise_hooks.lakera_ai import (
|
||||
_ENTERPRISE_lakeraAI_Moderation,
|
||||
)
|
||||
from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation
|
||||
|
||||
if premium_user != True:
|
||||
raise Exception(
|
||||
|
@ -114,9 +112,7 @@ def initialize_callbacks_on_proxy(
|
|||
init_params = {}
|
||||
if "lakera_prompt_injection" in callback_specific_params:
|
||||
init_params = callback_specific_params["lakera_prompt_injection"]
|
||||
lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation(
|
||||
**init_params
|
||||
)
|
||||
lakera_moderations_object = lakeraAI_Moderation(**init_params)
|
||||
imported_list.append(lakera_moderations_object)
|
||||
elif isinstance(callback, str) and callback == "aporia_prompt_injection":
|
||||
from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import (
|
||||
|
|
|
@ -126,10 +126,10 @@ def init_guardrails_v2(all_guardrails: dict):
|
|||
litellm.callbacks.append(_aporia_callback) # type: ignore
|
||||
elif litellm_params["guardrail"] == "lakera":
|
||||
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
|
||||
_ENTERPRISE_lakeraAI_Moderation,
|
||||
lakeraAI_Moderation,
|
||||
)
|
||||
|
||||
_lakera_callback = _ENTERPRISE_lakeraAI_Moderation()
|
||||
_lakera_callback = lakeraAI_Moderation()
|
||||
litellm.callbacks.append(_lakera_callback) # type: ignore
|
||||
|
||||
parsed_guardrail = Guardrail(
|
||||
|
|
|
@ -27,9 +27,7 @@ import litellm
|
|||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
|
||||
_ENTERPRISE_lakeraAI_Moderation,
|
||||
)
|
||||
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation
|
||||
from litellm.proxy.proxy_server import embeddings
|
||||
from litellm.proxy.utils import ProxyLogging, hash_token
|
||||
|
||||
|
@ -62,7 +60,7 @@ async def test_lakera_prompt_injection_detection():
|
|||
Tests to see OpenAI Moderation raises an error for a flagged response
|
||||
"""
|
||||
|
||||
lakera_ai = _ENTERPRISE_lakeraAI_Moderation()
|
||||
lakera_ai = lakeraAI_Moderation()
|
||||
_api_key = "sk-12345"
|
||||
_api_key = hash_token("sk-12345")
|
||||
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
|
||||
|
@ -106,7 +104,7 @@ async def test_lakera_safe_prompt():
|
|||
Nothing should get raised here
|
||||
"""
|
||||
|
||||
lakera_ai = _ENTERPRISE_lakeraAI_Moderation()
|
||||
lakera_ai = lakeraAI_Moderation()
|
||||
_api_key = "sk-12345"
|
||||
_api_key = hash_token("sk-12345")
|
||||
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
|
||||
|
@ -144,7 +142,7 @@ async def test_moderations_on_embeddings():
|
|||
setattr(litellm.proxy.proxy_server, "llm_router", temp_router)
|
||||
|
||||
api_route = APIRoute(path="/embeddings", endpoint=embeddings)
|
||||
litellm.callbacks = [_ENTERPRISE_lakeraAI_Moderation()]
|
||||
litellm.callbacks = [lakeraAI_Moderation()]
|
||||
request = Request(
|
||||
{
|
||||
"type": "http",
|
||||
|
@ -189,7 +187,7 @@ async def test_moderations_on_embeddings():
|
|||
),
|
||||
)
|
||||
async def test_messages_for_disabled_role(spy_post):
|
||||
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
||||
moderation = lakeraAI_Moderation()
|
||||
data = {
|
||||
"messages": [
|
||||
{"role": "assistant", "content": "This should be ignored."},
|
||||
|
@ -227,7 +225,7 @@ async def test_messages_for_disabled_role(spy_post):
|
|||
)
|
||||
@patch("litellm.add_function_to_prompt", False)
|
||||
async def test_system_message_with_function_input(spy_post):
|
||||
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
||||
moderation = lakeraAI_Moderation()
|
||||
data = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "Initial content."},
|
||||
|
@ -271,7 +269,7 @@ async def test_system_message_with_function_input(spy_post):
|
|||
)
|
||||
@patch("litellm.add_function_to_prompt", False)
|
||||
async def test_multi_message_with_function_input(spy_post):
|
||||
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
||||
moderation = lakeraAI_Moderation()
|
||||
data = {
|
||||
"messages": [
|
||||
{
|
||||
|
@ -318,7 +316,7 @@ async def test_multi_message_with_function_input(spy_post):
|
|||
),
|
||||
)
|
||||
async def test_message_ordering(spy_post):
|
||||
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
||||
moderation = lakeraAI_Moderation()
|
||||
data = {
|
||||
"messages": [
|
||||
{"role": "assistant", "content": "Assistant message."},
|
||||
|
@ -347,7 +345,7 @@ async def test_callback_specific_param_run_pre_call_check_lakera():
|
|||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from enterprise.enterprise_hooks.lakera_ai import _ENTERPRISE_lakeraAI_Moderation
|
||||
from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation
|
||||
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
|
||||
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
|
||||
|
||||
|
@ -374,10 +372,10 @@ async def test_callback_specific_param_run_pre_call_check_lakera():
|
|||
|
||||
assert len(litellm.guardrail_name_config_map) == 1
|
||||
|
||||
prompt_injection_obj: Optional[_ENTERPRISE_lakeraAI_Moderation] = None
|
||||
prompt_injection_obj: Optional[lakeraAI_Moderation] = None
|
||||
print("litellm callbacks={}".format(litellm.callbacks))
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, _ENTERPRISE_lakeraAI_Moderation):
|
||||
if isinstance(callback, lakeraAI_Moderation):
|
||||
prompt_injection_obj = callback
|
||||
else:
|
||||
print("Type of callback={}".format(type(callback)))
|
||||
|
@ -393,7 +391,7 @@ async def test_callback_specific_thresholds():
|
|||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from enterprise.enterprise_hooks.lakera_ai import _ENTERPRISE_lakeraAI_Moderation
|
||||
from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation
|
||||
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
|
||||
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
|
||||
|
||||
|
@ -426,10 +424,10 @@ async def test_callback_specific_thresholds():
|
|||
|
||||
assert len(litellm.guardrail_name_config_map) == 1
|
||||
|
||||
prompt_injection_obj: Optional[_ENTERPRISE_lakeraAI_Moderation] = None
|
||||
prompt_injection_obj: Optional[lakeraAI_Moderation] = None
|
||||
print("litellm callbacks={}".format(litellm.callbacks))
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, _ENTERPRISE_lakeraAI_Moderation):
|
||||
if isinstance(callback, lakeraAI_Moderation):
|
||||
prompt_injection_obj = callback
|
||||
else:
|
||||
print("Type of callback={}".format(type(callback)))
|
||||
|
|
|
@ -48,7 +48,7 @@ def test_active_callbacks(client):
|
|||
_active_callbacks = json_response["litellm.callbacks"]
|
||||
|
||||
expected_callback_names = [
|
||||
"_ENTERPRISE_lakeraAI_Moderation",
|
||||
"lakeraAI_Moderation",
|
||||
"_OPTIONAL_PromptInjectionDetectio",
|
||||
"_ENTERPRISE_SecretDetection",
|
||||
]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue