diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py index 1ff16b59e5..4cf4510196 100644 --- a/litellm/proxy/guardrails/init_guardrails.py +++ b/litellm/proxy/guardrails/init_guardrails.py @@ -31,7 +31,7 @@ def initialize_guardrails( all_guardrails.append(guardrail_item) # set appropriate callbacks if they are default on - default_on_callbacks = [] + default_on_callbacks = set() for guardrail in all_guardrails: verbose_proxy_logger.debug(guardrail.guardrail_name) verbose_proxy_logger.debug(guardrail.default_on) @@ -40,11 +40,12 @@ def initialize_guardrails( # add these to litellm callbacks if they don't exist for callback in guardrail.callbacks: if callback not in litellm.callbacks: - default_on_callbacks.append(callback) + default_on_callbacks.add(callback) - if len(default_on_callbacks) > 0: + default_on_callbacks_list = list(default_on_callbacks) + if len(default_on_callbacks_list) > 0: initialize_callbacks_on_proxy( - value=default_on_callbacks, + value=default_on_callbacks_list, premium_user=premium_user, config_file_path=config_file_path, litellm_settings=litellm_settings, diff --git a/litellm/tests/test_configs/test_guardrails_config.yaml b/litellm/tests/test_configs/test_guardrails_config.yaml new file mode 100644 index 0000000000..f09ff9d1bc --- /dev/null +++ b/litellm/tests/test_configs/test_guardrails_config.yaml @@ -0,0 +1,32 @@ + + +model_list: +- litellm_params: + api_base: https://my-endpoint-europe-berri-992.openai.azure.com/ + api_key: os.environ/AZURE_EUROPE_API_KEY + model: azure/gpt-35-turbo + model_name: azure-model +- litellm_params: + api_base: https://my-endpoint-canada-berri992.openai.azure.com + api_key: os.environ/AZURE_CANADA_API_KEY + model: azure/gpt-35-turbo + model_name: azure-model +- litellm_params: + api_base: https://openai-france-1234.openai.azure.com + api_key: os.environ/AZURE_FRANCE_API_KEY + model: azure/gpt-turbo + model_name: azure-model + + + +litellm_settings: + guardrails: + - prompt_injection: + callbacks: [lakera_prompt_injection, detect_prompt_injection] + default_on: true + - hide_secrets: + callbacks: [hide_secrets] + default_on: true + - moderations: + callbacks: [openai_moderations] + default_on: false \ No newline at end of file diff --git a/litellm/tests/test_proxy_setting_guardrails.py b/litellm/tests/test_proxy_setting_guardrails.py new file mode 100644 index 0000000000..048951da0a --- /dev/null +++ b/litellm/tests/test_proxy_setting_guardrails.py @@ -0,0 +1,69 @@ +import json +import os +import sys +from unittest import mock + +from dotenv import load_dotenv + +load_dotenv() +import asyncio +import io +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import openai +import pytest +from fastapi import Response +from fastapi.testclient import TestClient + +import litellm +from litellm.proxy.proxy_server import ( # Replace with the actual module where your FastAPI router is defined + initialize, + router, + save_worker_config, +) + + +@pytest.fixture +def client(): + filepath = os.path.dirname(os.path.abspath(__file__)) + config_fp = f"{filepath}/test_configs/test_guardrails_config.yaml" + asyncio.run(initialize(config=config_fp)) + from litellm.proxy.proxy_server import app + + return TestClient(app) + + +# raise openai.AuthenticationError +def test_active_callbacks(client): + response = client.get("/active/callbacks") + + print("response", response) + print("response.text", response.text) + print("response.status_code", response.status_code) + + json_response = response.json() + _active_callbacks = json_response["litellm.callbacks"] + + expected_callback_names = [ + "_ENTERPRISE_lakeraAI_Moderation", + "_OPTIONAL_PromptInjectionDetectio", + "_ENTERPRISE_SecretDetection", + ] + + for callback_name in expected_callback_names: + # check if any of the callbacks have callback_name as a substring + found_match = False + for callback in _active_callbacks: + if callback_name in callback: + found_match = True + break + assert ( + found_match is True + ), f"{callback_name} not found in _active_callbacks={_active_callbacks}" + + assert not any( + "_ENTERPRISE_OpenAI_Moderation" in callback for callback in _active_callbacks + ), f"_ENTERPRISE_OpenAI_Moderation should not be in _active_callbacks={_active_callbacks}"