mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
rename lakera ai
This commit is contained in:
parent
042350bd74
commit
cad0352f76
5 changed files with 20 additions and 26 deletions
|
@ -42,7 +42,7 @@ class LakeraCategories(TypedDict, total=False):
|
||||||
prompt_injection: float
|
prompt_injection: float
|
||||||
|
|
||||||
|
|
||||||
class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
|
class lakeraAI_Moderation(CustomLogger):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel",
|
moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel",
|
||||||
|
|
|
@ -101,9 +101,7 @@ def initialize_callbacks_on_proxy(
|
||||||
openai_moderations_object = _ENTERPRISE_OpenAI_Moderation()
|
openai_moderations_object = _ENTERPRISE_OpenAI_Moderation()
|
||||||
imported_list.append(openai_moderations_object)
|
imported_list.append(openai_moderations_object)
|
||||||
elif isinstance(callback, str) and callback == "lakera_prompt_injection":
|
elif isinstance(callback, str) and callback == "lakera_prompt_injection":
|
||||||
from enterprise.enterprise_hooks.lakera_ai import (
|
from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation
|
||||||
_ENTERPRISE_lakeraAI_Moderation,
|
|
||||||
)
|
|
||||||
|
|
||||||
if premium_user != True:
|
if premium_user != True:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -114,9 +112,7 @@ def initialize_callbacks_on_proxy(
|
||||||
init_params = {}
|
init_params = {}
|
||||||
if "lakera_prompt_injection" in callback_specific_params:
|
if "lakera_prompt_injection" in callback_specific_params:
|
||||||
init_params = callback_specific_params["lakera_prompt_injection"]
|
init_params = callback_specific_params["lakera_prompt_injection"]
|
||||||
lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation(
|
lakera_moderations_object = lakeraAI_Moderation(**init_params)
|
||||||
**init_params
|
|
||||||
)
|
|
||||||
imported_list.append(lakera_moderations_object)
|
imported_list.append(lakera_moderations_object)
|
||||||
elif isinstance(callback, str) and callback == "aporia_prompt_injection":
|
elif isinstance(callback, str) and callback == "aporia_prompt_injection":
|
||||||
from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import (
|
from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import (
|
||||||
|
|
|
@ -126,10 +126,10 @@ def init_guardrails_v2(all_guardrails: dict):
|
||||||
litellm.callbacks.append(_aporia_callback) # type: ignore
|
litellm.callbacks.append(_aporia_callback) # type: ignore
|
||||||
elif litellm_params["guardrail"] == "lakera":
|
elif litellm_params["guardrail"] == "lakera":
|
||||||
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
|
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
|
||||||
_ENTERPRISE_lakeraAI_Moderation,
|
lakeraAI_Moderation,
|
||||||
)
|
)
|
||||||
|
|
||||||
_lakera_callback = _ENTERPRISE_lakeraAI_Moderation()
|
_lakera_callback = lakeraAI_Moderation()
|
||||||
litellm.callbacks.append(_lakera_callback) # type: ignore
|
litellm.callbacks.append(_lakera_callback) # type: ignore
|
||||||
|
|
||||||
parsed_guardrail = Guardrail(
|
parsed_guardrail = Guardrail(
|
||||||
|
|
|
@ -27,9 +27,7 @@ import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
|
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation
|
||||||
_ENTERPRISE_lakeraAI_Moderation,
|
|
||||||
)
|
|
||||||
from litellm.proxy.proxy_server import embeddings
|
from litellm.proxy.proxy_server import embeddings
|
||||||
from litellm.proxy.utils import ProxyLogging, hash_token
|
from litellm.proxy.utils import ProxyLogging, hash_token
|
||||||
|
|
||||||
|
@ -62,7 +60,7 @@ async def test_lakera_prompt_injection_detection():
|
||||||
Tests to see OpenAI Moderation raises an error for a flagged response
|
Tests to see OpenAI Moderation raises an error for a flagged response
|
||||||
"""
|
"""
|
||||||
|
|
||||||
lakera_ai = _ENTERPRISE_lakeraAI_Moderation()
|
lakera_ai = lakeraAI_Moderation()
|
||||||
_api_key = "sk-12345"
|
_api_key = "sk-12345"
|
||||||
_api_key = hash_token("sk-12345")
|
_api_key = hash_token("sk-12345")
|
||||||
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
|
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
|
||||||
|
@ -106,7 +104,7 @@ async def test_lakera_safe_prompt():
|
||||||
Nothing should get raised here
|
Nothing should get raised here
|
||||||
"""
|
"""
|
||||||
|
|
||||||
lakera_ai = _ENTERPRISE_lakeraAI_Moderation()
|
lakera_ai = lakeraAI_Moderation()
|
||||||
_api_key = "sk-12345"
|
_api_key = "sk-12345"
|
||||||
_api_key = hash_token("sk-12345")
|
_api_key = hash_token("sk-12345")
|
||||||
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
|
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
|
||||||
|
@ -144,7 +142,7 @@ async def test_moderations_on_embeddings():
|
||||||
setattr(litellm.proxy.proxy_server, "llm_router", temp_router)
|
setattr(litellm.proxy.proxy_server, "llm_router", temp_router)
|
||||||
|
|
||||||
api_route = APIRoute(path="/embeddings", endpoint=embeddings)
|
api_route = APIRoute(path="/embeddings", endpoint=embeddings)
|
||||||
litellm.callbacks = [_ENTERPRISE_lakeraAI_Moderation()]
|
litellm.callbacks = [lakeraAI_Moderation()]
|
||||||
request = Request(
|
request = Request(
|
||||||
{
|
{
|
||||||
"type": "http",
|
"type": "http",
|
||||||
|
@ -189,7 +187,7 @@ async def test_moderations_on_embeddings():
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
async def test_messages_for_disabled_role(spy_post):
|
async def test_messages_for_disabled_role(spy_post):
|
||||||
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
moderation = lakeraAI_Moderation()
|
||||||
data = {
|
data = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "assistant", "content": "This should be ignored."},
|
{"role": "assistant", "content": "This should be ignored."},
|
||||||
|
@ -227,7 +225,7 @@ async def test_messages_for_disabled_role(spy_post):
|
||||||
)
|
)
|
||||||
@patch("litellm.add_function_to_prompt", False)
|
@patch("litellm.add_function_to_prompt", False)
|
||||||
async def test_system_message_with_function_input(spy_post):
|
async def test_system_message_with_function_input(spy_post):
|
||||||
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
moderation = lakeraAI_Moderation()
|
||||||
data = {
|
data = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": "Initial content."},
|
{"role": "system", "content": "Initial content."},
|
||||||
|
@ -271,7 +269,7 @@ async def test_system_message_with_function_input(spy_post):
|
||||||
)
|
)
|
||||||
@patch("litellm.add_function_to_prompt", False)
|
@patch("litellm.add_function_to_prompt", False)
|
||||||
async def test_multi_message_with_function_input(spy_post):
|
async def test_multi_message_with_function_input(spy_post):
|
||||||
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
moderation = lakeraAI_Moderation()
|
||||||
data = {
|
data = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
|
@ -318,7 +316,7 @@ async def test_multi_message_with_function_input(spy_post):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
async def test_message_ordering(spy_post):
|
async def test_message_ordering(spy_post):
|
||||||
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
moderation = lakeraAI_Moderation()
|
||||||
data = {
|
data = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "assistant", "content": "Assistant message."},
|
{"role": "assistant", "content": "Assistant message."},
|
||||||
|
@ -347,7 +345,7 @@ async def test_callback_specific_param_run_pre_call_check_lakera():
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from enterprise.enterprise_hooks.lakera_ai import _ENTERPRISE_lakeraAI_Moderation
|
from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation
|
||||||
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
|
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
|
||||||
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
|
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
|
||||||
|
|
||||||
|
@ -374,10 +372,10 @@ async def test_callback_specific_param_run_pre_call_check_lakera():
|
||||||
|
|
||||||
assert len(litellm.guardrail_name_config_map) == 1
|
assert len(litellm.guardrail_name_config_map) == 1
|
||||||
|
|
||||||
prompt_injection_obj: Optional[_ENTERPRISE_lakeraAI_Moderation] = None
|
prompt_injection_obj: Optional[lakeraAI_Moderation] = None
|
||||||
print("litellm callbacks={}".format(litellm.callbacks))
|
print("litellm callbacks={}".format(litellm.callbacks))
|
||||||
for callback in litellm.callbacks:
|
for callback in litellm.callbacks:
|
||||||
if isinstance(callback, _ENTERPRISE_lakeraAI_Moderation):
|
if isinstance(callback, lakeraAI_Moderation):
|
||||||
prompt_injection_obj = callback
|
prompt_injection_obj = callback
|
||||||
else:
|
else:
|
||||||
print("Type of callback={}".format(type(callback)))
|
print("Type of callback={}".format(type(callback)))
|
||||||
|
@ -393,7 +391,7 @@ async def test_callback_specific_thresholds():
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from enterprise.enterprise_hooks.lakera_ai import _ENTERPRISE_lakeraAI_Moderation
|
from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation
|
||||||
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
|
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
|
||||||
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
|
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
|
||||||
|
|
||||||
|
@ -426,10 +424,10 @@ async def test_callback_specific_thresholds():
|
||||||
|
|
||||||
assert len(litellm.guardrail_name_config_map) == 1
|
assert len(litellm.guardrail_name_config_map) == 1
|
||||||
|
|
||||||
prompt_injection_obj: Optional[_ENTERPRISE_lakeraAI_Moderation] = None
|
prompt_injection_obj: Optional[lakeraAI_Moderation] = None
|
||||||
print("litellm callbacks={}".format(litellm.callbacks))
|
print("litellm callbacks={}".format(litellm.callbacks))
|
||||||
for callback in litellm.callbacks:
|
for callback in litellm.callbacks:
|
||||||
if isinstance(callback, _ENTERPRISE_lakeraAI_Moderation):
|
if isinstance(callback, lakeraAI_Moderation):
|
||||||
prompt_injection_obj = callback
|
prompt_injection_obj = callback
|
||||||
else:
|
else:
|
||||||
print("Type of callback={}".format(type(callback)))
|
print("Type of callback={}".format(type(callback)))
|
||||||
|
|
|
@ -48,7 +48,7 @@ def test_active_callbacks(client):
|
||||||
_active_callbacks = json_response["litellm.callbacks"]
|
_active_callbacks = json_response["litellm.callbacks"]
|
||||||
|
|
||||||
expected_callback_names = [
|
expected_callback_names = [
|
||||||
"_ENTERPRISE_lakeraAI_Moderation",
|
"lakeraAI_Moderation",
|
||||||
"_OPTIONAL_PromptInjectionDetectio",
|
"_OPTIONAL_PromptInjectionDetectio",
|
||||||
"_ENTERPRISE_SecretDetection",
|
"_ENTERPRISE_SecretDetection",
|
||||||
]
|
]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue