forked from phoenix/litellm-mirror
feat(guardrails.py): allow setting logging_only
in guardrails_config for presidio pii masking integration
This commit is contained in:
parent
f2522867ed
commit
6b78e39600
7 changed files with 71 additions and 18 deletions
|
@ -17,7 +17,6 @@ from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.proxy.guardrails.init_guardrails import all_guardrails
|
|
||||||
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
|
@ -125,7 +125,7 @@ llamaguard_unsafe_content_categories: Optional[str] = None
|
||||||
blocked_user_list: Optional[Union[str, List]] = None
|
blocked_user_list: Optional[Union[str, List]] = None
|
||||||
banned_keywords_list: Optional[Union[str, List]] = None
|
banned_keywords_list: Optional[Union[str, List]] = None
|
||||||
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
|
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
|
||||||
guardrail_name_config_map: Optional[Dict[str, GuardrailItem]] = None
|
guardrail_name_config_map: Dict[str, GuardrailItem] = {}
|
||||||
##################
|
##################
|
||||||
### PREVIEW FEATURES ###
|
### PREVIEW FEATURES ###
|
||||||
enable_preview_features: bool = False
|
enable_preview_features: bool = False
|
||||||
|
|
|
@ -14,6 +14,7 @@ def initialize_callbacks_on_proxy(
|
||||||
premium_user: bool,
|
premium_user: bool,
|
||||||
config_file_path: str,
|
config_file_path: str,
|
||||||
litellm_settings: dict,
|
litellm_settings: dict,
|
||||||
|
callback_specific_params: dict = {},
|
||||||
):
|
):
|
||||||
from litellm.proxy.proxy_server import prisma_client
|
from litellm.proxy.proxy_server import prisma_client
|
||||||
|
|
||||||
|
@ -25,7 +26,6 @@ def initialize_callbacks_on_proxy(
|
||||||
known_compatible_callbacks = list(
|
known_compatible_callbacks = list(
|
||||||
get_args(litellm._custom_logger_compatible_callbacks_literal)
|
get_args(litellm._custom_logger_compatible_callbacks_literal)
|
||||||
)
|
)
|
||||||
|
|
||||||
for callback in value: # ["presidio", <my-custom-callback>]
|
for callback in value: # ["presidio", <my-custom-callback>]
|
||||||
if isinstance(callback, str) and callback in known_compatible_callbacks:
|
if isinstance(callback, str) and callback in known_compatible_callbacks:
|
||||||
imported_list.append(callback)
|
imported_list.append(callback)
|
||||||
|
@ -54,9 +54,11 @@ def initialize_callbacks_on_proxy(
|
||||||
presidio_logging_only
|
presidio_logging_only
|
||||||
) # validate boolean given
|
) # validate boolean given
|
||||||
|
|
||||||
pii_masking_object = _OPTIONAL_PresidioPIIMasking(
|
params = {
|
||||||
logging_only=presidio_logging_only
|
"logging_only": presidio_logging_only,
|
||||||
)
|
**callback_specific_params,
|
||||||
|
}
|
||||||
|
pii_masking_object = _OPTIONAL_PresidioPIIMasking(**params)
|
||||||
imported_list.append(pii_masking_object)
|
imported_list.append(pii_masking_object)
|
||||||
elif isinstance(callback, str) and callback == "llamaguard_moderations":
|
elif isinstance(callback, str) and callback == "llamaguard_moderations":
|
||||||
from enterprise.enterprise_hooks.llama_guard import (
|
from enterprise.enterprise_hooks.llama_guard import (
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.proxy.guardrails.init_guardrails import guardrail_name_config_map
|
|
||||||
from litellm.proxy.proxy_server import UserAPIKeyAuth
|
from litellm.proxy.proxy_server import UserAPIKeyAuth
|
||||||
from litellm.types.guardrails import *
|
from litellm.types.guardrails import *
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> b
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# lookup the guardrail in guardrail_name_config_map
|
# lookup the guardrail in guardrail_name_config_map
|
||||||
guardrail_item: GuardrailItem = guardrail_name_config_map[
|
guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[
|
||||||
_guardrail_name
|
_guardrail_name
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -80,7 +80,9 @@ async def should_proceed_based_on_api_key(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# lookup the guardrail in guardrail_name_config_map
|
# lookup the guardrail in guardrail_name_config_map
|
||||||
guardrail_item: GuardrailItem = guardrail_name_config_map[_guardrail_name]
|
guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[
|
||||||
|
_guardrail_name
|
||||||
|
]
|
||||||
|
|
||||||
guardrail_callbacks = guardrail_item.callbacks
|
guardrail_callbacks = guardrail_item.callbacks
|
||||||
if guardrail_name in guardrail_callbacks:
|
if guardrail_name in guardrail_callbacks:
|
||||||
|
|
|
@ -6,15 +6,13 @@ from pydantic import BaseModel, RootModel
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
|
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
|
||||||
from litellm.types.guardrails import GuardrailItem
|
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
|
||||||
|
|
||||||
all_guardrails: List[GuardrailItem] = []
|
all_guardrails: List[GuardrailItem] = []
|
||||||
|
|
||||||
guardrail_name_config_map: Dict[str, GuardrailItem] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_guardrails(
|
def initialize_guardrails(
|
||||||
guardrails_config: list,
|
guardrails_config: List[Dict[str, GuardrailItemSpec]],
|
||||||
premium_user: bool,
|
premium_user: bool,
|
||||||
config_file_path: str,
|
config_file_path: str,
|
||||||
litellm_settings: dict,
|
litellm_settings: dict,
|
||||||
|
@ -28,14 +26,14 @@ def initialize_guardrails(
|
||||||
|
|
||||||
{'prompt_injection': {'callbacks': ['lakera_prompt_injection', 'prompt_injection_api_2'], 'default_on': True}}
|
{'prompt_injection': {'callbacks': ['lakera_prompt_injection', 'prompt_injection_api_2'], 'default_on': True}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for k, v in item.items():
|
for k, v in item.items():
|
||||||
guardrail_item = GuardrailItem(**v, guardrail_name=k)
|
guardrail_item = GuardrailItem(**v, guardrail_name=k)
|
||||||
all_guardrails.append(guardrail_item)
|
all_guardrails.append(guardrail_item)
|
||||||
guardrail_name_config_map[k] = guardrail_item
|
litellm.guardrail_name_config_map[k] = guardrail_item
|
||||||
|
|
||||||
# set appropriate callbacks if they are default on
|
# set appropriate callbacks if they are default on
|
||||||
default_on_callbacks = set()
|
default_on_callbacks = set()
|
||||||
|
callback_specific_params = {}
|
||||||
for guardrail in all_guardrails:
|
for guardrail in all_guardrails:
|
||||||
verbose_proxy_logger.debug(guardrail.guardrail_name)
|
verbose_proxy_logger.debug(guardrail.guardrail_name)
|
||||||
verbose_proxy_logger.debug(guardrail.default_on)
|
verbose_proxy_logger.debug(guardrail.default_on)
|
||||||
|
@ -46,6 +44,10 @@ def initialize_guardrails(
|
||||||
if callback not in litellm.callbacks:
|
if callback not in litellm.callbacks:
|
||||||
default_on_callbacks.add(callback)
|
default_on_callbacks.add(callback)
|
||||||
|
|
||||||
|
if guardrail.logging_only is True:
|
||||||
|
if callback == "presidio":
|
||||||
|
callback_specific_params["logging_only"] = True
|
||||||
|
|
||||||
default_on_callbacks_list = list(default_on_callbacks)
|
default_on_callbacks_list = list(default_on_callbacks)
|
||||||
if len(default_on_callbacks_list) > 0:
|
if len(default_on_callbacks_list) > 0:
|
||||||
initialize_callbacks_on_proxy(
|
initialize_callbacks_on_proxy(
|
||||||
|
@ -53,9 +55,10 @@ def initialize_guardrails(
|
||||||
premium_user=premium_user,
|
premium_user=premium_user,
|
||||||
config_file_path=config_file_path,
|
config_file_path=config_file_path,
|
||||||
litellm_settings=litellm_settings,
|
litellm_settings=litellm_settings,
|
||||||
|
callback_specific_params=callback_specific_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
return guardrail_name_config_map
|
return litellm.guardrail_name_config_map
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.error(
|
verbose_proxy_logger.error(
|
||||||
"error initializing guardrails {}\n{}".format(
|
"error initializing guardrails {}\n{}".format(
|
||||||
|
|
|
@ -263,3 +263,43 @@ async def test_presidio_pii_masking_logging_output_only_logged_response():
|
||||||
mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"]
|
mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"]
|
||||||
== "My name is <PERSON>, who are you? Say my name in your response"
|
== "My name is <PERSON>, who are you? Say my name in your response"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_presidio_pii_masking_logging_output_only_logged_response_guardrails_config():
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
|
||||||
|
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
|
||||||
|
|
||||||
|
guardrails_config: List[Dict[str, GuardrailItemSpec]] = [
|
||||||
|
{
|
||||||
|
"pii_masking": {
|
||||||
|
"callbacks": ["presidio"],
|
||||||
|
"default_on": True,
|
||||||
|
"logging_only": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
litellm_settings = {"guardrails": guardrails_config}
|
||||||
|
|
||||||
|
assert len(litellm.guardrail_name_config_map) == 0
|
||||||
|
initialize_guardrails(
|
||||||
|
guardrails_config=guardrails_config,
|
||||||
|
premium_user=True,
|
||||||
|
config_file_path="",
|
||||||
|
litellm_settings=litellm_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(litellm.guardrail_name_config_map) == 1
|
||||||
|
|
||||||
|
pii_masking_obj: Optional[_OPTIONAL_PresidioPIIMasking] = None
|
||||||
|
for callback in litellm.callbacks:
|
||||||
|
if isinstance(callback, _OPTIONAL_PresidioPIIMasking):
|
||||||
|
pii_masking_obj = callback
|
||||||
|
|
||||||
|
assert pii_masking_obj is not None
|
||||||
|
|
||||||
|
assert hasattr(pii_masking_obj, "logging_only")
|
||||||
|
assert pii_masking_obj.logging_only is True
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from typing import Dict, List, Optional, TypedDict, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, RootModel
|
from pydantic import BaseModel, RootModel
|
||||||
|
from typing_extensions import Required, TypedDict, override
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Pydantic object defining how to set guardrails on litellm proxy
|
Pydantic object defining how to set guardrails on litellm proxy
|
||||||
|
@ -16,6 +17,12 @@ litellm_settings:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class GuardrailItemSpec(TypedDict, total=False):
|
||||||
|
callbacks: Required[List[str]]
|
||||||
|
default_on: bool
|
||||||
|
logging_only: Optional[bool]
|
||||||
|
|
||||||
|
|
||||||
class GuardrailItem(BaseModel):
|
class GuardrailItem(BaseModel):
|
||||||
callbacks: List[str]
|
callbacks: List[str]
|
||||||
default_on: bool
|
default_on: bool
|
||||||
|
@ -25,8 +32,8 @@ class GuardrailItem(BaseModel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
callbacks: List[str],
|
callbacks: List[str],
|
||||||
default_on: bool,
|
|
||||||
guardrail_name: str,
|
guardrail_name: str,
|
||||||
|
default_on: bool = False,
|
||||||
logging_only: Optional[bool] = None,
|
logging_only: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue