mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
added a dynamic configuration in addition to the static 'litellm_params' configuration for CustomGuardrail
This commit is contained in:
parent
b0fa934fe3
commit
8f01f3eb87
3 changed files with 123 additions and 5 deletions
|
@ -99,7 +99,6 @@ def init_guardrails_v2(
|
|||
}
|
||||
|
||||
litellm_params = LitellmParams(**_litellm_params_kwargs) # type: ignore
|
||||
|
||||
if (
|
||||
"category_thresholds" in litellm_params_data
|
||||
and litellm_params_data["category_thresholds"]
|
||||
|
@ -152,11 +151,19 @@ def init_guardrails_v2(
|
|||
spec.loader.exec_module(module) # type: ignore
|
||||
_guardrail_class = getattr(module, _class_name)
|
||||
|
||||
# Split params into known and additional parameters
|
||||
known_params = {k: litellm_params_data[k] for k in LitellmParams.__annotations__.keys() if k in litellm_params_data}
|
||||
additional_params = {k: v for k, v in litellm_params_data.items() if k not in LitellmParams.__annotations__.keys()}
|
||||
|
||||
# Initialize with known parameters
|
||||
_guardrail_callback = _guardrail_class(
|
||||
guardrail_name=guardrail["guardrail_name"],
|
||||
event_hook=litellm_params["mode"],
|
||||
default_on=litellm_params["default_on"],
|
||||
**known_params
|
||||
)
|
||||
|
||||
# Update optional parameters while preserving existing ones
|
||||
if not hasattr(_guardrail_callback, 'optional_params'):
|
||||
_guardrail_callback.optional_params = {}
|
||||
_guardrail_callback.optional_params.update(additional_params)
|
||||
litellm.logging_callback_manager.add_litellm_callback(_guardrail_callback) # type: ignore
|
||||
else:
|
||||
raise ValueError(f"Unsupported guardrail: {guardrail_type}")
|
||||
|
|
|
@ -80,7 +80,8 @@ class LakeraCategoryThresholds(TypedDict, total=False):
|
|||
jailbreak: float
|
||||
|
||||
|
||||
class LitellmParams(TypedDict):
|
||||
class LitellmParams(TypedDict, total=False):
|
||||
"""TypedDict for Litellm parameters with support for both static and dynamic fields"""
|
||||
guardrail: str
|
||||
mode: str
|
||||
api_key: Optional[str]
|
||||
|
@ -105,6 +106,11 @@ class LitellmParams(TypedDict):
|
|||
guard_name: Optional[str]
|
||||
default_on: Optional[bool]
|
||||
|
||||
# Support for dynamic parameters
|
||||
def __class_getitem__(cls, key: str) -> Any:
|
||||
"""Enable dictionary-style access to dynamic fields"""
|
||||
return Dict[str, Any].__class_getitem__(key)
|
||||
|
||||
|
||||
class Guardrail(TypedDict, total=False):
|
||||
guardrail_name: str
|
||||
|
|
105
tests/local_testing/test_custom_guardrails.py
Normal file
105
tests/local_testing/test_custom_guardrails.py
Normal file
|
@ -0,0 +1,105 @@
|
|||
from typing import Dict, Literal, Optional, Union
|
||||
|
||||
import pytest
|
||||
from litellm import DualCache
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.proxy.guardrails.init_guardrails import init_guardrails_v2
|
||||
from litellm.proxy.proxy_server import UserAPIKeyAuth
|
||||
|
||||
# Test Constants
|
||||
TEST_API_BASE = "http://127.0.0.1:8000/api/scan"
|
||||
TEST_API_JWT = "token"
|
||||
TEST_THRESHOLD = 1
|
||||
|
||||
class CustomGuardrailMock(CustomGuardrail):
|
||||
"""Mock implementation of CustomGuardrail for testing purposes"""
|
||||
def __init__(self, **kwargs) -> None:
|
||||
# Initialize with message_logging=True for parent class
|
||||
super().__init__(message_logging=True)
|
||||
# Store all kwargs as optional_params
|
||||
self.optional_params = kwargs
|
||||
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: Dict,
|
||||
call_type: Literal["completion", "text_completion", "embeddings"],
|
||||
) -> Optional[Union[Exception, str, Dict]]:
|
||||
"""Mock pre-call hook that always succeeds"""
|
||||
return None
|
||||
|
||||
async def async_moderation_hook(
|
||||
self,
|
||||
data: Dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal["completion", "embeddings", "image_generation", "moderation", "audio_transcription"],
|
||||
) -> None:
|
||||
"""Mock moderation hook that always succeeds"""
|
||||
return None
|
||||
|
||||
class TestCustomGuardrails:
|
||||
"""Test suite for custom guardrails functionality"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self) -> None:
|
||||
"""Setup test environment before each test"""
|
||||
import litellm
|
||||
litellm.set_verbose = True
|
||||
yield
|
||||
# Reset callbacks after each test
|
||||
litellm.callbacks = []
|
||||
|
||||
def get_test_guardrail_config(self, guardrail_class: str = "test_custom_guardrails.CustomGuardrailMock") -> list[Dict]:
|
||||
"""Helper method to generate test guardrail configuration"""
|
||||
return [{
|
||||
"guardrail_name": "custom_guardrail",
|
||||
"litellm_params": {
|
||||
"guardrail": guardrail_class,
|
||||
"guard_name": "custom_guard",
|
||||
"mode": "pre_call",
|
||||
"api_base": TEST_API_BASE,
|
||||
"api_jwt": TEST_API_JWT,
|
||||
"threshold": TEST_THRESHOLD,
|
||||
},
|
||||
}]
|
||||
|
||||
def test_unsupported_guardrail(self) -> None:
|
||||
"""Test initialization with unsupported guardrail class"""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
init_guardrails_v2(
|
||||
all_guardrails=self.get_test_guardrail_config("FakeCustomGuardrail"),
|
||||
config_file_path="test_config.yml",
|
||||
)
|
||||
assert "Unsupported guardrail" in str(exc_info.value)
|
||||
|
||||
def test_missing_config_file(self) -> None:
|
||||
"""Test initialization with missing config file"""
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
init_guardrails_v2(
|
||||
all_guardrails=self.get_test_guardrail_config(),
|
||||
config_file_path="",
|
||||
)
|
||||
assert "GuardrailsAIException - Please pass the config_file_path" in str(exc_info.value)
|
||||
|
||||
def test_successful_initialization(self) -> None:
|
||||
"""Test successful guardrail initialization and configuration"""
|
||||
import litellm
|
||||
|
||||
init_guardrails_v2(
|
||||
all_guardrails=self.get_test_guardrail_config(),
|
||||
config_file_path="local_testing/test_custom_guardrails.py",
|
||||
)
|
||||
|
||||
# Verify guardrail was properly initialized
|
||||
custom_guardrails = [
|
||||
callback for callback in litellm.callbacks
|
||||
if isinstance(callback, CustomGuardrail)
|
||||
]
|
||||
assert len(custom_guardrails) == 1
|
||||
|
||||
# Verify configuration was properly set
|
||||
custom_guardrail = custom_guardrails[0]
|
||||
assert custom_guardrail.optional_params.get("api_base") == TEST_API_BASE
|
||||
assert custom_guardrail.optional_params.get("api_jwt") == TEST_API_JWT
|
||||
assert custom_guardrail.optional_params.get("threshold") == TEST_THRESHOLD
|
Loading…
Add table
Add a link
Reference in a new issue