init guardrails on proxy

Ishaan Jaff 2024-07-03 14:18:12 -07:00
parent a2b6baab16
commit 129c2e0c4f
4 changed files with 302 additions and 250 deletions
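Note: this change replaces an inline, 200-plus-line `elif key == "callbacks"` chain in `ProxyConfig` with two extracted helpers, `initialize_guardrails` and `initialize_callbacks_on_proxy`, and adds a premium-gated `guardrails` key to `litellm_settings`. As a rough illustration, the new branch would consume a parsed config section along these lines; the exact guardrails schema lives in `litellm/proxy/guardrails/init_guardrails.py`, not in this diff, so the shape shown here is an assumption:

from litellm.proxy.guardrails.init_guardrails import initialize_guardrails

# Hypothetical parsed `litellm_settings.guardrails` entry (schema assumed, not
# taken from this diff): a named guardrail mapping to the callbacks it enables.
guardrails_config = [
    {"prompt_injection": {"callbacks": ["lakera_prompt_injection"], "default_on": True}},
]

# The config loader forwards that section to the helper, as the diff below shows.
initialize_guardrails(
    guardrails_config=guardrails_config,
    premium_user=True,  # the loader raises ValueError when this is not True
    config_file_path="config.yaml",
    litellm_settings={"guardrails": guardrails_config},
)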


@@ -142,6 +142,8 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 from litellm.proxy.caching_routes import router as caching_router
 from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
 from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
+from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
+from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
 from litellm.proxy.health_check import perform_health_check
 from litellm.proxy.health_endpoints._health_endpoints import router as health_router
 from litellm.proxy.hooks.prompt_injection_detection import (
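Judging only from the call sites later in this diff, the two newly imported helpers take the same four keyword arguments. A sketch of the assumed signatures follows; the real definitions live in the new modules and may differ:

from typing import Any

def initialize_guardrails(
    guardrails_config: list,  # litellm_settings["guardrails"], as parsed from YAML
    premium_user: bool,       # checked by the caller before this runs
    config_file_path: str,
    litellm_settings: dict,
) -> None:
    """Sketch: register each configured guardrail as a litellm callback."""
    ...

def initialize_callbacks_on_proxy(
    value: Any,               # litellm_settings["callbacks"]: a list or a single entry
    premium_user: bool,
    config_file_path: str,
    litellm_settings: dict,
) -> None:
    """Sketch: the former inline callback dispatch, moved into one helper."""
    ...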
@@ -1443,248 +1445,28 @@ class ProxyConfig:
                     )
                 elif key == "cache" and value == False:
                     pass
-                elif key == "callbacks":
-                    if isinstance(value, list):
-                        imported_list: List[Any] = []
-                        known_compatible_callbacks = list(
-                            get_args(
-                                litellm._custom_logger_compatible_callbacks_literal
-                            )
+                elif key == "guardrails":
+                    if premium_user is not True:
+                        raise ValueError(
+                            "Trying to use `guardrails` on config.yaml "
+                            + CommonProxyErrors.not_premium_user.value
                         )
-                        for callback in value:  # ["presidio", <my-custom-callback>]
-                            if (
-                                isinstance(callback, str)
-                                and callback in known_compatible_callbacks
-                            ):
-                                imported_list.append(callback)
-                            elif isinstance(callback, str) and callback == "otel":
-                                from litellm.integrations.opentelemetry import (
-                                    OpenTelemetry,
-                                )
-                                open_telemetry_logger = OpenTelemetry()
-                                imported_list.append(open_telemetry_logger)
-                            elif isinstance(callback, str) and callback == "presidio":
-                                from litellm.proxy.hooks.presidio_pii_masking import (
-                                    _OPTIONAL_PresidioPIIMasking,
-                                )
-                                pii_masking_object = _OPTIONAL_PresidioPIIMasking()
-                                imported_list.append(pii_masking_object)
-                            elif (
-                                isinstance(callback, str)
-                                and callback == "llamaguard_moderations"
-                            ):
-                                from enterprise.enterprise_hooks.llama_guard import (
-                                    _ENTERPRISE_LlamaGuard,
-                                )
-                                if premium_user != True:
-                                    raise Exception(
-                                        "Trying to use Llama Guard"
-                                        + CommonProxyErrors.not_premium_user.value
-                                    )
-                                llama_guard_object = _ENTERPRISE_LlamaGuard()
-                                imported_list.append(llama_guard_object)
-                            elif (
-                                isinstance(callback, str) and callback == "hide_secrets"
-                            ):
-                                from enterprise.enterprise_hooks.secret_detection import (
-                                    _ENTERPRISE_SecretDetection,
-                                )
-                                if premium_user != True:
-                                    raise Exception(
-                                        "Trying to use secret hiding"
-                                        + CommonProxyErrors.not_premium_user.value
-                                    )
-                                _secret_detection_object = _ENTERPRISE_SecretDetection()
-                                imported_list.append(_secret_detection_object)
-                            elif (
-                                isinstance(callback, str)
-                                and callback == "openai_moderations"
-                            ):
-                                from enterprise.enterprise_hooks.openai_moderation import (
-                                    _ENTERPRISE_OpenAI_Moderation,
-                                )
-                                if premium_user != True:
-                                    raise Exception(
-                                        "Trying to use OpenAI Moderations Check"
-                                        + CommonProxyErrors.not_premium_user.value
-                                    )
-                                openai_moderations_object = (
-                                    _ENTERPRISE_OpenAI_Moderation()
-                                )
-                                imported_list.append(openai_moderations_object)
-                            elif (
-                                isinstance(callback, str)
-                                and callback == "lakera_prompt_injection"
-                            ):
-                                from enterprise.enterprise_hooks.lakera_ai import (
-                                    _ENTERPRISE_lakeraAI_Moderation,
-                                )
-                                if premium_user != True:
-                                    raise Exception(
-                                        "Trying to use LakeraAI Prompt Injection"
-                                        + CommonProxyErrors.not_premium_user.value
-                                    )
-                                lakera_moderations_object = (
-                                    _ENTERPRISE_lakeraAI_Moderation()
-                                )
-                                imported_list.append(lakera_moderations_object)
-                            elif (
-                                isinstance(callback, str)
-                                and callback == "google_text_moderation"
-                            ):
-                                from enterprise.enterprise_hooks.google_text_moderation import (
-                                    _ENTERPRISE_GoogleTextModeration,
-                                )
-                                if premium_user != True:
-                                    raise Exception(
-                                        "Trying to use Google Text Moderation"
-                                        + CommonProxyErrors.not_premium_user.value
-                                    )
-                                google_text_moderation_obj = (
-                                    _ENTERPRISE_GoogleTextModeration()
-                                )
-                                imported_list.append(google_text_moderation_obj)
-                            elif (
-                                isinstance(callback, str)
-                                and callback == "llmguard_moderations"
-                            ):
-                                from enterprise.enterprise_hooks.llm_guard import (
-                                    _ENTERPRISE_LLMGuard,
-                                )
-                                if premium_user != True:
-                                    raise Exception(
-                                        "Trying to use Llm Guard"
-                                        + CommonProxyErrors.not_premium_user.value
-                                    )
-                                llm_guard_moderation_obj = _ENTERPRISE_LLMGuard()
-                                imported_list.append(llm_guard_moderation_obj)
-                            elif (
-                                isinstance(callback, str)
-                                and callback == "blocked_user_check"
-                            ):
-                                from enterprise.enterprise_hooks.blocked_user_list import (
-                                    _ENTERPRISE_BlockedUserList,
-                                )
-                                if premium_user != True:
-                                    raise Exception(
-                                        "Trying to use ENTERPRISE BlockedUser"
-                                        + CommonProxyErrors.not_premium_user.value
-                                    )
-                                blocked_user_list = _ENTERPRISE_BlockedUserList(
-                                    prisma_client=prisma_client
-                                )
-                                imported_list.append(blocked_user_list)
-                            elif (
-                                isinstance(callback, str)
-                                and callback == "banned_keywords"
-                            ):
-                                from enterprise.enterprise_hooks.banned_keywords import (
-                                    _ENTERPRISE_BannedKeywords,
-                                )
-                                if premium_user != True:
-                                    raise Exception(
-                                        "Trying to use ENTERPRISE BannedKeyword"
-                                        + CommonProxyErrors.not_premium_user.value
-                                    )
-                                banned_keywords_obj = _ENTERPRISE_BannedKeywords()
-                                imported_list.append(banned_keywords_obj)
-                            elif (
-                                isinstance(callback, str)
-                                and callback == "detect_prompt_injection"
-                            ):
-                                from litellm.proxy.hooks.prompt_injection_detection import (
-                                    _OPTIONAL_PromptInjectionDetection,
-                                )
-                                prompt_injection_params = None
-                                if "prompt_injection_params" in litellm_settings:
-                                    prompt_injection_params_in_config = (
-                                        litellm_settings["prompt_injection_params"]
-                                    )
-                                    prompt_injection_params = (
-                                        LiteLLMPromptInjectionParams(
-                                            **prompt_injection_params_in_config
-                                        )
-                                    )
-                                prompt_injection_detection_obj = (
-                                    _OPTIONAL_PromptInjectionDetection(
-                                        prompt_injection_params=prompt_injection_params,
-                                    )
-                                )
-                                imported_list.append(prompt_injection_detection_obj)
-                            elif (
-                                isinstance(callback, str)
-                                and callback == "batch_redis_requests"
-                            ):
-                                from litellm.proxy.hooks.batch_redis_get import (
-                                    _PROXY_BatchRedisRequests,
-                                )
-                                batch_redis_obj = _PROXY_BatchRedisRequests()
-                                imported_list.append(batch_redis_obj)
-                            elif (
-                                isinstance(callback, str)
-                                and callback == "azure_content_safety"
-                            ):
-                                from litellm.proxy.hooks.azure_content_safety import (
-                                    _PROXY_AzureContentSafety,
-                                )
-                                azure_content_safety_params = litellm_settings[
-                                    "azure_content_safety_params"
-                                ]
-                                for k, v in azure_content_safety_params.items():
-                                    if (
-                                        v is not None
-                                        and isinstance(v, str)
-                                        and v.startswith("os.environ/")
-                                    ):
-                                        azure_content_safety_params[k] = (
-                                            litellm.get_secret(v)
-                                        )
-                                azure_content_safety_obj = _PROXY_AzureContentSafety(
-                                    **azure_content_safety_params,
-                                )
-                                imported_list.append(azure_content_safety_obj)
-                            else:
-                                imported_list.append(
-                                    get_instance_fn(
-                                        value=callback,
-                                        config_file_path=config_file_path,
-                                    )
-                                )
-                        litellm.callbacks = imported_list  # type: ignore
-                    else:
-                        litellm.callbacks = [
-                            get_instance_fn(
-                                value=value,
-                                config_file_path=config_file_path,
-                            )
-                        ]
-                    verbose_proxy_logger.debug(
-                        f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
+                    initialize_guardrails(
+                        guardrails_config=value,
+                        premium_user=premium_user,
+                        config_file_path=config_file_path,
+                        litellm_settings=litellm_settings,
                     )
+                elif key == "callbacks":
+                    initialize_callbacks_on_proxy(
+                        value=value,
+                        premium_user=premium_user,
+                        config_file_path=config_file_path,
+                        litellm_settings=litellm_settings,
+                    )
elif key == "post_call_rules":
litellm.post_call_rules = [
get_instance_fn(value=value, config_file_path=config_file_path)
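Every branch deleted above follows one pattern: match a callback name, gate enterprise hooks on `premium_user`, lazily import the hook class, instantiate it, and collect it into `litellm.callbacks`. Below is a condensed, table-driven sketch of that pattern, assuming it is roughly what `initialize_callbacks_on_proxy` now centralizes; this is a hypothetical body, not the committed one, and only two of the dozen-plus names are shown:

from typing import Any, Callable, Dict, List, Tuple

def _make_llama_guard() -> Any:
    # Lazy import so the optional enterprise dependency loads only when configured.
    from enterprise.enterprise_hooks.llama_guard import _ENTERPRISE_LlamaGuard
    return _ENTERPRISE_LlamaGuard()

def _make_batch_redis() -> Any:
    from litellm.proxy.hooks.batch_redis_get import _PROXY_BatchRedisRequests
    return _PROXY_BatchRedisRequests()

# name -> (requires_premium, factory); the full map would also cover presidio,
# otel, hide_secrets, lakera_prompt_injection, google_text_moderation, etc.
_KNOWN_CALLBACKS: Dict[str, Tuple[bool, Callable[[], Any]]] = {
    "llamaguard_moderations": (True, _make_llama_guard),
    "batch_redis_requests": (False, _make_batch_redis),
}

def _init_callbacks_sketch(value: List[str], premium_user: bool) -> List[Any]:
    # Mirrors the deleted loop: premium gate first, then import and instantiate.
    imported: List[Any] = []
    for name in value:
        requires_premium, factory = _KNOWN_CALLBACKS[name]
        if requires_premium and premium_user is not True:
            raise Exception(f"Trying to use {name}: premium feature")
        imported.append(factory())
    return imported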