From 9b5164b38d5ebed1fff1febe2f78fd84cd7dc66e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 4 Sep 2024 12:46:59 -0700 Subject: [PATCH] fix allow setting language per call to presidio --- .../guardrails/guardrail_hooks/presidio.py | 70 +++++++++++++++---- litellm/proxy/litellm_pre_call_utils.py | 4 ++ 2 files changed, 60 insertions(+), 14 deletions(-) diff --git a/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/litellm/proxy/guardrails/guardrail_hooks/presidio.py index cb1d3df1b..a44bdaa9f 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/presidio.py +++ b/litellm/proxy/guardrails/guardrail_hooks/presidio.py @@ -16,6 +16,7 @@ from typing import Any, List, Optional, Tuple, Union import aiohttp from fastapi import HTTPException +from pydantic import BaseModel import litellm # noqa: E401 from litellm._logging import verbose_proxy_logger @@ -32,6 +33,14 @@ from litellm.utils import ( ) +class PresidioPerRequestConfig(BaseModel): + """ + presdio params that can be controlled per request, api key + """ + + language: Optional[str] = None + + class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): user_api_key_cache = None ad_hoc_recognizers = None @@ -70,7 +79,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): raise Exception( f"An error occurred: {str(e)}, file_path={ad_hoc_recognizers}" ) - self.async_http_client = _get_async_httpx_client() self.validate_environment( presidio_analyzer_api_base=presidio_analyzer_api_base, presidio_anonymizer_api_base=presidio_anonymizer_api_base, @@ -119,15 +127,12 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): "http://" + self.presidio_anonymizer_api_base ) - def print_verbose(self, print_statement): - try: - verbose_proxy_logger.debug(print_statement) - if litellm.set_verbose: - print(print_statement) # noqa - except: - pass - - async def check_pii(self, text: str, output_parse_pii: bool) -> str: + async def check_pii( + self, + text: str, + output_parse_pii: bool, + presidio_config: Optional[PresidioPerRequestConfig], + ) -> str: """ [TODO] make this more performant for high-throughput scenario """ @@ -137,13 +142,21 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): redacted_text = self.mock_redacted_text else: # Make the first request to /analyze + # Construct Request 1 analyze_url = f"{self.presidio_analyzer_api_base}analyze" - verbose_proxy_logger.debug("Making request to: %s", analyze_url) analyze_payload = {"text": text, "language": "en"} + if presidio_config and presidio_config.language: + analyze_payload["language"] = presidio_config.language if self.ad_hoc_recognizers is not None: analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers - redacted_text = None + # End of constructing Request 1 + redacted_text = None + verbose_proxy_logger.debug( + "Making request to: %s with payload: %s", + analyze_url, + analyze_payload, + ) async with session.post( analyze_url, json=analyze_payload ) as response: @@ -275,6 +288,8 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): if no_pii is True: # turn off pii masking return data + presidio_config = self.get_presidio_settings_from_request_data(data) + if call_type == "completion": # /chat/completions requests messages = data["messages"] tasks = [] @@ -283,7 +298,9 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): if isinstance(m["content"], str): tasks.append( self.check_pii( - text=m["content"], output_parse_pii=output_parse_pii + text=m["content"], + output_parse_pii=output_parse_pii, + presidio_config=presidio_config, ) ) responses = await asyncio.gather(*tasks) @@ -317,6 +334,8 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): if messages is None: return kwargs, result + presidio_config = self.get_presidio_settings_from_request_data(kwargs) + for m in messages: text_str = "" if m["content"] is None: @@ -324,7 +343,11 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): if isinstance(m["content"], str): text_str = m["content"] tasks.append( - self.check_pii(text=text_str, output_parse_pii=False) + self.check_pii( + text=text_str, + output_parse_pii=False, + presidio_config=presidio_config, + ) ) # need to pass separately b/c presidio has context window limits responses = await asyncio.gather(*tasks) for index, r in enumerate(responses): @@ -366,3 +389,22 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): 0 ].message.content.replace(key, value) return response + + def get_presidio_settings_from_request_data( + self, data: dict + ) -> Optional[PresidioPerRequestConfig]: + if "metadata" in data: + _metadata = data["metadata"] + _guardrail_config = _metadata.get("guardrail_config") + _presidio_config = PresidioPerRequestConfig(**_guardrail_config) + return _presidio_config + + return None + + def print_verbose(self, print_statement): + try: + verbose_proxy_logger.debug(print_statement) + if litellm.set_verbose: + print(print_statement) # noqa + except: + pass diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 60052bc27..d41aae50f 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -420,6 +420,10 @@ def move_guardrails_to_metadata( data[_metadata_variable_name]["guardrails"] = data["guardrails"] del data["guardrails"] + if "guardrail_config" in data: + data[_metadata_variable_name]["guardrail_config"] = data["guardrail_config"] + del data["guardrail_config"] + def add_provider_specific_headers_to_request( data: dict,