init guardrails on proxy

This commit is contained in:
Ishaan Jaff 2024-07-03 14:18:12 -07:00
parent a2b6baab16
commit 129c2e0c4f
4 changed files with 302 additions and 250 deletions

View file

@ -0,0 +1,217 @@
from typing import Any, List, Optional, get_args
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams
from litellm.proxy.utils import get_instance_fn
# ANSI escape sequences wrapped around the proxy debug log messages below:
# "\033[94m" switches the terminal foreground to bright blue and "\033[0m"
# resets all attributes afterwards.
blue_color_code = "\033[94m"
reset_color_code = "\033[0m"
def initialize_callbacks_on_proxy(
    value: Any,
    premium_user: bool,
    config_file_path: str,
    litellm_settings: dict,
):
    """Build the callbacks requested in the proxy config and register them.

    Args:
        value: Either a list of callback entries or a single entry. Each list
            entry is resolved by ``_resolve_proxy_callback``; a non-list value
            is handed straight to ``get_instance_fn``.
        premium_user: Whether enterprise-only callbacks may be used.
        config_file_path: Forwarded to ``get_instance_fn`` for custom callbacks.
        litellm_settings: The ``litellm_settings`` config section. Read for
            ``prompt_injection_params`` and ``azure_content_safety_params``.

    Raises:
        Exception: If an enterprise-only callback is requested while
            ``premium_user`` is not ``True``.
        KeyError: If ``azure_content_safety`` is requested but
            ``litellm_settings`` has no ``azure_content_safety_params`` entry.

    Side effects:
        For a list, the resolved callbacks are appended to ``litellm.callbacks``
        when that is already a list, otherwise they replace it. For a non-list
        value, ``litellm.callbacks`` is replaced by a one-element list.
    """
    # Imported at call time rather than at module import time —
    # presumably to avoid a circular import with proxy_server; TODO confirm.
    from litellm.proxy.proxy_server import prisma_client

    verbose_proxy_logger.debug(
        f"{blue_color_code}initializing callbacks={value} on proxy{reset_color_code}"
    )

    if isinstance(value, list):
        # Hoisted: the set of natively supported names is the same for every entry.
        known_compatible_callbacks = list(
            get_args(litellm._custom_logger_compatible_callbacks_literal)
        )
        imported_list: List[Any] = []
        for callback in value:  # ["presidio", <my-custom-callback>]
            imported_list.append(
                _resolve_proxy_callback(
                    callback=callback,
                    known_compatible_callbacks=known_compatible_callbacks,
                    premium_user=premium_user,
                    config_file_path=config_file_path,
                    litellm_settings=litellm_settings,
                    prisma_client=prisma_client,
                )
            )
        # Nothing is registered unless every entry resolved successfully.
        if isinstance(litellm.callbacks, list):
            litellm.callbacks.extend(imported_list)
        else:
            litellm.callbacks = imported_list  # type: ignore
    else:
        litellm.callbacks = [
            get_instance_fn(
                value=value,
                config_file_path=config_file_path,
            )
        ]

    verbose_proxy_logger.debug(
        f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
    )


def _require_premium(premium_user: bool, feature: str) -> None:
    """Raise the standard "not a premium user" error for an enterprise feature.

    Called *before* the enterprise module is imported so that a non-premium
    deployment gets this message rather than an import error when the
    ``enterprise`` package is not installed.
    """
    if premium_user is not True:
        raise Exception(
            "Trying to use " + feature + CommonProxyErrors.not_premium_user.value
        )


def _resolve_proxy_callback(
    callback: Any,
    known_compatible_callbacks: List[Any],
    premium_user: bool,
    config_file_path: str,
    litellm_settings: dict,
    prisma_client: Any,
) -> Any:
    """Turn one config entry into the object stored in ``litellm.callbacks``.

    Names listed in ``known_compatible_callbacks`` are returned unchanged (as
    strings). The other recognised names are mapped to a freshly constructed
    hook/logger instance. Anything else, including non-string entries, is
    passed to ``get_instance_fn`` as a custom callback.
    """
    if isinstance(callback, str):
        if callback in known_compatible_callbacks:
            return callback

        if callback == "otel":
            from litellm.integrations.opentelemetry import OpenTelemetry

            return OpenTelemetry()

        if callback == "presidio":
            from litellm.proxy.hooks.presidio_pii_masking import (
                _OPTIONAL_PresidioPIIMasking,
            )

            return _OPTIONAL_PresidioPIIMasking()

        if callback == "llamaguard_moderations":
            _require_premium(premium_user, "Llama Guard")
            from enterprise.enterprise_hooks.llama_guard import (
                _ENTERPRISE_LlamaGuard,
            )

            return _ENTERPRISE_LlamaGuard()

        if callback == "hide_secrets":
            _require_premium(premium_user, "secret hiding")
            from enterprise.enterprise_hooks.secret_detection import (
                _ENTERPRISE_SecretDetection,
            )

            return _ENTERPRISE_SecretDetection()

        if callback == "openai_moderations":
            _require_premium(premium_user, "OpenAI Moderations Check")
            from enterprise.enterprise_hooks.openai_moderation import (
                _ENTERPRISE_OpenAI_Moderation,
            )

            return _ENTERPRISE_OpenAI_Moderation()

        if callback == "lakera_prompt_injection":
            _require_premium(premium_user, "LakeraAI Prompt Injection")
            from enterprise.enterprise_hooks.lakera_ai import (
                _ENTERPRISE_lakeraAI_Moderation,
            )

            return _ENTERPRISE_lakeraAI_Moderation()

        if callback == "google_text_moderation":
            _require_premium(premium_user, "Google Text Moderation")
            from enterprise.enterprise_hooks.google_text_moderation import (
                _ENTERPRISE_GoogleTextModeration,
            )

            return _ENTERPRISE_GoogleTextModeration()

        if callback == "llmguard_moderations":
            _require_premium(premium_user, "Llm Guard")
            from enterprise.enterprise_hooks.llm_guard import _ENTERPRISE_LLMGuard

            return _ENTERPRISE_LLMGuard()

        if callback == "blocked_user_check":
            _require_premium(premium_user, "ENTERPRISE BlockedUser")
            from enterprise.enterprise_hooks.blocked_user_list import (
                _ENTERPRISE_BlockedUserList,
            )

            return _ENTERPRISE_BlockedUserList(prisma_client=prisma_client)

        if callback == "banned_keywords":
            _require_premium(premium_user, "ENTERPRISE BannedKeyword")
            from enterprise.enterprise_hooks.banned_keywords import (
                _ENTERPRISE_BannedKeywords,
            )

            return _ENTERPRISE_BannedKeywords()

        if callback == "detect_prompt_injection":
            from litellm.proxy.hooks.prompt_injection_detection import (
                _OPTIONAL_PromptInjectionDetection,
            )

            # Params are optional; the hook is built with None when absent.
            prompt_injection_params = None
            if "prompt_injection_params" in litellm_settings:
                prompt_injection_params = LiteLLMPromptInjectionParams(
                    **litellm_settings["prompt_injection_params"]
                )
            return _OPTIONAL_PromptInjectionDetection(
                prompt_injection_params=prompt_injection_params,
            )

        if callback == "batch_redis_requests":
            from litellm.proxy.hooks.batch_redis_get import (
                _PROXY_BatchRedisRequests,
            )

            return _PROXY_BatchRedisRequests()

        if callback == "azure_content_safety":
            from litellm.proxy.hooks.azure_content_safety import (
                _PROXY_AzureContentSafety,
            )

            # Unlike prompt_injection_params, this section is required here.
            azure_content_safety_params = litellm_settings[
                "azure_content_safety_params"
            ]
            # "os.environ/..." string values are swapped for the result of
            # litellm.get_secret. This rewrites the dict taken from
            # litellm_settings in place (existing keys only, so iterating
            # while assigning is safe).
            for k, v in azure_content_safety_params.items():
                if isinstance(v, str) and v.startswith("os.environ/"):
                    azure_content_safety_params[k] = litellm.get_secret(v)
            return _PROXY_AzureContentSafety(**azure_content_safety_params)

    # Unrecognised name or non-string entry: treat it as a user-supplied callback.
    verbose_proxy_logger.debug(
        f"{blue_color_code} attempting to import custom calback={callback} {reset_color_code}"
    )
    return get_instance_fn(
        value=callback,
        config_file_path=config_file_path,
    )

View file

@ -0,0 +1,56 @@
import traceback
from typing import Dict, List
from pydantic import BaseModel, RootModel
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
from litellm.types.guardrails import GuardrailItem
def initialize_guardrails(
    guardrails_config: list,
    premium_user: bool,
    config_file_path: str,
    litellm_settings: dict,
):
    """Validate the ``guardrails`` config and enable callbacks of default-on guardrails.

    Args:
        guardrails_config: List of single-key dicts mapping a guardrail name to
            its settings, e.g.
            ``{'prompt_injection': {'callbacks': [...], 'default_on': True}}``.
        premium_user: Forwarded to ``initialize_callbacks_on_proxy``.
        config_file_path: Forwarded to ``initialize_callbacks_on_proxy``.
        litellm_settings: Forwarded to ``initialize_callbacks_on_proxy``.

    Raises:
        Exception: Any error raised while building the ``GuardrailItem``
            objects or initializing the callbacks is logged, its traceback
            printed, and then re-raised unchanged.
    """
    try:
        verbose_proxy_logger.debug(f"validating guardrails passed {guardrails_config}")

        all_guardrails: List[GuardrailItem] = []
        for item in guardrails_config:
            # one item looks like this:
            # {'prompt_injection': {'callbacks': ['lakera_prompt_injection', 'prompt_injection_api_2'], 'default_on': True}}
            for k, v in item.items():
                all_guardrails.append(GuardrailItem(**v, guardrail_name=k))

        # set appropriate callbacks if they are default on
        default_on_callbacks = []
        for guardrail in all_guardrails:
            verbose_proxy_logger.debug(guardrail.guardrail_name)
            verbose_proxy_logger.debug(guardrail.default_on)

            if guardrail.default_on is not True:
                continue

            for callback in guardrail.callbacks:
                # Skip callbacks that are already registered, and callbacks
                # already queued by an earlier guardrail: the same name may
                # appear under several guardrails and must only be
                # initialized once.
                # NOTE(review): litellm.callbacks may hold instantiated hook
                # objects rather than their config names, in which case the
                # first test will not match a name — confirm.
                if callback in litellm.callbacks or callback in default_on_callbacks:
                    continue
                default_on_callbacks.append(callback)

        if default_on_callbacks:
            initialize_callbacks_on_proxy(
                value=default_on_callbacks,
                premium_user=premium_user,
                config_file_path=config_file_path,
                litellm_settings=litellm_settings,
            )
    except Exception as e:
        verbose_proxy_logger.error(f"error initializing guardrails {str(e)}")
        traceback.print_exc()
        # Bare raise re-raises the active exception with its original traceback.
        raise

View file

@ -19,7 +19,6 @@ model_list:
model: mistral/mistral-embed
general_settings:
master_key: sk-1234
pass_through_endpoints:
- path: "/v1/rerank"
target: "https://api.cohere.com/v1/rerank"
@ -36,15 +35,13 @@ general_settings:
LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_DEV_SK_KEY"
litellm_settings:
return_response_headers: true
success_callback: ["prometheus"]
callbacks: ["otel", "hide_secrets"]
failure_callback: ["prometheus"]
store_audit_logs: true
redact_messages_in_exceptions: True
enforced_params:
- user
- metadata
- metadata.generation_name
guardrails:
- prompt_injection:
callbacks: [lakera_prompt_injection, hide_secrets]
default_on: true
- hide_secrets:
callbacks: [hide_secrets]
default_on: true

View file

@ -142,6 +142,8 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.caching_routes import router as caching_router
from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
from litellm.proxy.health_check import perform_health_check
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
from litellm.proxy.hooks.prompt_injection_detection import (
@ -1443,248 +1445,28 @@ class ProxyConfig:
)
elif key == "cache" and value == False:
pass
elif key == "guardrails":
if premium_user is not True:
raise ValueError(
"Trying to use `guardrails` on config.yaml "
+ CommonProxyErrors.not_premium_user.value
)
initialize_guardrails(
guardrails_config=value,
premium_user=premium_user,
config_file_path=config_file_path,
litellm_settings=litellm_settings,
)
elif key == "callbacks":
if isinstance(value, list):
imported_list: List[Any] = []
known_compatible_callbacks = list(
get_args(
litellm._custom_logger_compatible_callbacks_literal
)
)
for callback in value: # ["presidio", <my-custom-callback>]
if (
isinstance(callback, str)
and callback in known_compatible_callbacks
):
imported_list.append(callback)
elif isinstance(callback, str) and callback == "otel":
from litellm.integrations.opentelemetry import (
OpenTelemetry,
)
open_telemetry_logger = OpenTelemetry()
imported_list.append(open_telemetry_logger)
elif isinstance(callback, str) and callback == "presidio":
from litellm.proxy.hooks.presidio_pii_masking import (
_OPTIONAL_PresidioPIIMasking,
)
pii_masking_object = _OPTIONAL_PresidioPIIMasking()
imported_list.append(pii_masking_object)
elif (
isinstance(callback, str)
and callback == "llamaguard_moderations"
):
from enterprise.enterprise_hooks.llama_guard import (
_ENTERPRISE_LlamaGuard,
)
if premium_user != True:
raise Exception(
"Trying to use Llama Guard"
+ CommonProxyErrors.not_premium_user.value
)
llama_guard_object = _ENTERPRISE_LlamaGuard()
imported_list.append(llama_guard_object)
elif (
isinstance(callback, str) and callback == "hide_secrets"
):
from enterprise.enterprise_hooks.secret_detection import (
_ENTERPRISE_SecretDetection,
)
if premium_user != True:
raise Exception(
"Trying to use secret hiding"
+ CommonProxyErrors.not_premium_user.value
)
_secret_detection_object = _ENTERPRISE_SecretDetection()
imported_list.append(_secret_detection_object)
elif (
isinstance(callback, str)
and callback == "openai_moderations"
):
from enterprise.enterprise_hooks.openai_moderation import (
_ENTERPRISE_OpenAI_Moderation,
)
if premium_user != True:
raise Exception(
"Trying to use OpenAI Moderations Check"
+ CommonProxyErrors.not_premium_user.value
)
openai_moderations_object = (
_ENTERPRISE_OpenAI_Moderation()
)
imported_list.append(openai_moderations_object)
elif (
isinstance(callback, str)
and callback == "lakera_prompt_injection"
):
from enterprise.enterprise_hooks.lakera_ai import (
_ENTERPRISE_lakeraAI_Moderation,
)
if premium_user != True:
raise Exception(
"Trying to use LakeraAI Prompt Injection"
+ CommonProxyErrors.not_premium_user.value
)
lakera_moderations_object = (
_ENTERPRISE_lakeraAI_Moderation()
)
imported_list.append(lakera_moderations_object)
elif (
isinstance(callback, str)
and callback == "google_text_moderation"
):
from enterprise.enterprise_hooks.google_text_moderation import (
_ENTERPRISE_GoogleTextModeration,
)
if premium_user != True:
raise Exception(
"Trying to use Google Text Moderation"
+ CommonProxyErrors.not_premium_user.value
)
google_text_moderation_obj = (
_ENTERPRISE_GoogleTextModeration()
)
imported_list.append(google_text_moderation_obj)
elif (
isinstance(callback, str)
and callback == "llmguard_moderations"
):
from enterprise.enterprise_hooks.llm_guard import (
_ENTERPRISE_LLMGuard,
)
if premium_user != True:
raise Exception(
"Trying to use Llm Guard"
+ CommonProxyErrors.not_premium_user.value
)
llm_guard_moderation_obj = _ENTERPRISE_LLMGuard()
imported_list.append(llm_guard_moderation_obj)
elif (
isinstance(callback, str)
and callback == "blocked_user_check"
):
from enterprise.enterprise_hooks.blocked_user_list import (
_ENTERPRISE_BlockedUserList,
)
if premium_user != True:
raise Exception(
"Trying to use ENTERPRISE BlockedUser"
+ CommonProxyErrors.not_premium_user.value
)
blocked_user_list = _ENTERPRISE_BlockedUserList(
prisma_client=prisma_client
)
imported_list.append(blocked_user_list)
elif (
isinstance(callback, str)
and callback == "banned_keywords"
):
from enterprise.enterprise_hooks.banned_keywords import (
_ENTERPRISE_BannedKeywords,
)
if premium_user != True:
raise Exception(
"Trying to use ENTERPRISE BannedKeyword"
+ CommonProxyErrors.not_premium_user.value
)
banned_keywords_obj = _ENTERPRISE_BannedKeywords()
imported_list.append(banned_keywords_obj)
elif (
isinstance(callback, str)
and callback == "detect_prompt_injection"
):
from litellm.proxy.hooks.prompt_injection_detection import (
_OPTIONAL_PromptInjectionDetection,
)
prompt_injection_params = None
if "prompt_injection_params" in litellm_settings:
prompt_injection_params_in_config = (
litellm_settings["prompt_injection_params"]
)
prompt_injection_params = (
LiteLLMPromptInjectionParams(
**prompt_injection_params_in_config
)
)
prompt_injection_detection_obj = (
_OPTIONAL_PromptInjectionDetection(
prompt_injection_params=prompt_injection_params,
)
)
imported_list.append(prompt_injection_detection_obj)
elif (
isinstance(callback, str)
and callback == "batch_redis_requests"
):
from litellm.proxy.hooks.batch_redis_get import (
_PROXY_BatchRedisRequests,
)
batch_redis_obj = _PROXY_BatchRedisRequests()
imported_list.append(batch_redis_obj)
elif (
isinstance(callback, str)
and callback == "azure_content_safety"
):
from litellm.proxy.hooks.azure_content_safety import (
_PROXY_AzureContentSafety,
)
azure_content_safety_params = litellm_settings[
"azure_content_safety_params"
]
for k, v in azure_content_safety_params.items():
if (
v is not None
and isinstance(v, str)
and v.startswith("os.environ/")
):
azure_content_safety_params[k] = (
litellm.get_secret(v)
)
azure_content_safety_obj = _PROXY_AzureContentSafety(
**azure_content_safety_params,
)
imported_list.append(azure_content_safety_obj)
else:
imported_list.append(
get_instance_fn(
value=callback,
config_file_path=config_file_path,
)
)
litellm.callbacks = imported_list # type: ignore
else:
litellm.callbacks = [
get_instance_fn(
initialize_callbacks_on_proxy(
value=value,
premium_user=premium_user,
config_file_path=config_file_path,
litellm_settings=litellm_settings,
)
]
verbose_proxy_logger.debug(
f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
)
elif key == "post_call_rules":
litellm.post_call_rules = [
get_instance_fn(value=value, config_file_path=config_file_path)