fix allow setting language per call to presidio

This commit is contained in:
Ishaan Jaff 2024-09-04 12:46:59 -07:00
parent d954413b14
commit 9b5164b38d
2 changed files with 60 additions and 14 deletions

View file

@ -16,6 +16,7 @@ from typing import Any, List, Optional, Tuple, Union
import aiohttp import aiohttp
from fastapi import HTTPException from fastapi import HTTPException
from pydantic import BaseModel
import litellm # noqa: E401 import litellm # noqa: E401
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
@ -32,6 +33,14 @@ from litellm.utils import (
) )
class PresidioPerRequestConfig(BaseModel):
"""
presdio params that can be controlled per request, api key
"""
language: Optional[str] = None
class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
user_api_key_cache = None user_api_key_cache = None
ad_hoc_recognizers = None ad_hoc_recognizers = None
@ -70,7 +79,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
raise Exception( raise Exception(
f"An error occurred: {str(e)}, file_path={ad_hoc_recognizers}" f"An error occurred: {str(e)}, file_path={ad_hoc_recognizers}"
) )
self.async_http_client = _get_async_httpx_client()
self.validate_environment( self.validate_environment(
presidio_analyzer_api_base=presidio_analyzer_api_base, presidio_analyzer_api_base=presidio_analyzer_api_base,
presidio_anonymizer_api_base=presidio_anonymizer_api_base, presidio_anonymizer_api_base=presidio_anonymizer_api_base,
@ -119,15 +127,12 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
"http://" + self.presidio_anonymizer_api_base "http://" + self.presidio_anonymizer_api_base
) )
def print_verbose(self, print_statement): async def check_pii(
try: self,
verbose_proxy_logger.debug(print_statement) text: str,
if litellm.set_verbose: output_parse_pii: bool,
print(print_statement) # noqa presidio_config: Optional[PresidioPerRequestConfig],
except: ) -> str:
pass
async def check_pii(self, text: str, output_parse_pii: bool) -> str:
""" """
[TODO] make this more performant for high-throughput scenario [TODO] make this more performant for high-throughput scenario
""" """
@ -137,13 +142,21 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
redacted_text = self.mock_redacted_text redacted_text = self.mock_redacted_text
else: else:
# Make the first request to /analyze # Make the first request to /analyze
# Construct Request 1
analyze_url = f"{self.presidio_analyzer_api_base}analyze" analyze_url = f"{self.presidio_analyzer_api_base}analyze"
verbose_proxy_logger.debug("Making request to: %s", analyze_url)
analyze_payload = {"text": text, "language": "en"} analyze_payload = {"text": text, "language": "en"}
if presidio_config and presidio_config.language:
analyze_payload["language"] = presidio_config.language
if self.ad_hoc_recognizers is not None: if self.ad_hoc_recognizers is not None:
analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers
redacted_text = None # End of constructing Request 1
redacted_text = None
verbose_proxy_logger.debug(
"Making request to: %s with payload: %s",
analyze_url,
analyze_payload,
)
async with session.post( async with session.post(
analyze_url, json=analyze_payload analyze_url, json=analyze_payload
) as response: ) as response:
@ -275,6 +288,8 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
if no_pii is True: # turn off pii masking if no_pii is True: # turn off pii masking
return data return data
presidio_config = self.get_presidio_settings_from_request_data(data)
if call_type == "completion": # /chat/completions requests if call_type == "completion": # /chat/completions requests
messages = data["messages"] messages = data["messages"]
tasks = [] tasks = []
@ -283,7 +298,9 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
if isinstance(m["content"], str): if isinstance(m["content"], str):
tasks.append( tasks.append(
self.check_pii( self.check_pii(
text=m["content"], output_parse_pii=output_parse_pii text=m["content"],
output_parse_pii=output_parse_pii,
presidio_config=presidio_config,
) )
) )
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
@ -317,6 +334,8 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
if messages is None: if messages is None:
return kwargs, result return kwargs, result
presidio_config = self.get_presidio_settings_from_request_data(kwargs)
for m in messages: for m in messages:
text_str = "" text_str = ""
if m["content"] is None: if m["content"] is None:
@ -324,7 +343,11 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
if isinstance(m["content"], str): if isinstance(m["content"], str):
text_str = m["content"] text_str = m["content"]
tasks.append( tasks.append(
self.check_pii(text=text_str, output_parse_pii=False) self.check_pii(
text=text_str,
output_parse_pii=False,
presidio_config=presidio_config,
)
) # need to pass separately b/c presidio has context window limits ) # need to pass separately b/c presidio has context window limits
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
for index, r in enumerate(responses): for index, r in enumerate(responses):
@ -366,3 +389,22 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
0 0
].message.content.replace(key, value) ].message.content.replace(key, value)
return response return response
def get_presidio_settings_from_request_data(
self, data: dict
) -> Optional[PresidioPerRequestConfig]:
if "metadata" in data:
_metadata = data["metadata"]
_guardrail_config = _metadata.get("guardrail_config")
_presidio_config = PresidioPerRequestConfig(**_guardrail_config)
return _presidio_config
return None
def print_verbose(self, print_statement):
try:
verbose_proxy_logger.debug(print_statement)
if litellm.set_verbose:
print(print_statement) # noqa
except:
pass

View file

@ -420,6 +420,10 @@ def move_guardrails_to_metadata(
data[_metadata_variable_name]["guardrails"] = data["guardrails"] data[_metadata_variable_name]["guardrails"] = data["guardrails"]
del data["guardrails"] del data["guardrails"]
if "guardrail_config" in data:
data[_metadata_variable_name]["guardrail_config"] = data["guardrail_config"]
del data["guardrail_config"]
def add_provider_specific_headers_to_request( def add_provider_specific_headers_to_request(
data: dict, data: dict,