forked from phoenix/litellm-mirror

fix allow setting language per call to presidio

parent d954413b14
commit 9b5164b38d

2 changed files with 60 additions and 14 deletions
@@ -16,6 +16,7 @@ from typing import Any, List, Optional, Tuple, Union
 import aiohttp
 from fastapi import HTTPException
+from pydantic import BaseModel
 
 import litellm  # noqa: E401
 from litellm._logging import verbose_proxy_logger
@@ -32,6 +33,14 @@ from litellm.utils import (
 )
 
 
+class PresidioPerRequestConfig(BaseModel):
+    """
+    presdio params that can be controlled per request, api key
+    """
+
+    language: Optional[str] = None
+
+
 class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
     user_api_key_cache = None
     ad_hoc_recognizers = None
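For context, a minimal, self-contained sketch of how the new per-request model behaves. The sample `guardrail_config` value is hypothetical and not taken from the commit:

```python
from typing import Optional

from pydantic import BaseModel


class PresidioPerRequestConfig(BaseModel):
    """Presidio params that can be controlled per request / API key."""

    language: Optional[str] = None


# A caller-supplied guardrail_config dict (hypothetical payload) is
# validated into the per-request config; unset fields stay None so the
# guardrail falls back to its defaults.
config = PresidioPerRequestConfig(**{"language": "de"})
print(config.language)  # -> "de"
```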
@@ -70,7 +79,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
                 raise Exception(
                     f"An error occurred: {str(e)}, file_path={ad_hoc_recognizers}"
                 )
-        self.async_http_client = _get_async_httpx_client()
         self.validate_environment(
             presidio_analyzer_api_base=presidio_analyzer_api_base,
             presidio_anonymizer_api_base=presidio_anonymizer_api_base,
@@ -119,15 +127,12 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
                 "http://" + self.presidio_anonymizer_api_base
             )
 
-    def print_verbose(self, print_statement):
-        try:
-            verbose_proxy_logger.debug(print_statement)
-            if litellm.set_verbose:
-                print(print_statement)  # noqa
-        except:
-            pass
-
-    async def check_pii(self, text: str, output_parse_pii: bool) -> str:
+    async def check_pii(
+        self,
+        text: str,
+        output_parse_pii: bool,
+        presidio_config: Optional[PresidioPerRequestConfig],
+    ) -> str:
         """
         [TODO] make this more performant for high-throughput scenario
         """
@@ -137,13 +142,21 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
                     redacted_text = self.mock_redacted_text
                 else:
                     # Make the first request to /analyze
+                    # Construct Request 1
                     analyze_url = f"{self.presidio_analyzer_api_base}analyze"
-                    verbose_proxy_logger.debug("Making request to: %s", analyze_url)
                     analyze_payload = {"text": text, "language": "en"}
+                    if presidio_config and presidio_config.language:
+                        analyze_payload["language"] = presidio_config.language
                     if self.ad_hoc_recognizers is not None:
                         analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers
-                    redacted_text = None
+                    # End of constructing Request 1
+
+                    redacted_text = None
+                    verbose_proxy_logger.debug(
+                        "Making request to: %s with payload: %s",
+                        analyze_url,
+                        analyze_payload,
+                    )
                     async with session.post(
                         analyze_url, json=analyze_payload
                     ) as response:
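To make the language-override logic easy to try in isolation, here is a small sketch that mirrors the payload construction in the hunk above. The helper name `build_analyze_payload` and the sample German text are illustrative only; the real code builds this dict inline inside `check_pii`:

```python
from typing import Any, Dict, List, Optional

from pydantic import BaseModel


class PresidioPerRequestConfig(BaseModel):
    language: Optional[str] = None


def build_analyze_payload(
    text: str,
    presidio_config: Optional[PresidioPerRequestConfig],
    ad_hoc_recognizers: Optional[List[dict]] = None,
) -> Dict[str, Any]:
    # Default to English, override the language only when a per-request
    # config provides one, and attach ad-hoc recognizers if configured.
    analyze_payload: Dict[str, Any] = {"text": text, "language": "en"}
    if presidio_config and presidio_config.language:
        analyze_payload["language"] = presidio_config.language
    if ad_hoc_recognizers is not None:
        analyze_payload["ad_hoc_recognizers"] = ad_hoc_recognizers
    return analyze_payload


print(build_analyze_payload("Hallo, ich heisse Peter", PresidioPerRequestConfig(language="de")))
# {'text': 'Hallo, ich heisse Peter', 'language': 'de'}
```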
@@ -275,6 +288,8 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
         if no_pii is True:  # turn off pii masking
             return data
 
+        presidio_config = self.get_presidio_settings_from_request_data(data)
+
         if call_type == "completion":  # /chat/completions requests
             messages = data["messages"]
             tasks = []
@@ -283,7 +298,9 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
                 if isinstance(m["content"], str):
                     tasks.append(
                         self.check_pii(
-                            text=m["content"], output_parse_pii=output_parse_pii
+                            text=m["content"],
+                            output_parse_pii=output_parse_pii,
+                            presidio_config=presidio_config,
                         )
                     )
             responses = await asyncio.gather(*tasks)
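The pre-call hook now resolves `presidio_config` once per request and passes it to every `check_pii` task. A minimal sketch of that fan-out pattern, with `check_pii` replaced by a stand-in (the real method calls the Presidio analyzer and anonymizer endpoints):

```python
import asyncio
from typing import List, Optional

from pydantic import BaseModel


class PresidioPerRequestConfig(BaseModel):
    language: Optional[str] = None


async def check_pii(
    text: str,
    output_parse_pii: bool,
    presidio_config: Optional[PresidioPerRequestConfig],
) -> str:
    # Stand-in for the real guardrail call: just tag the text with the
    # language that would be sent to the Presidio analyzer.
    language = "en"
    if presidio_config and presidio_config.language:
        language = presidio_config.language
    return f"[{language}] {text}"


async def main() -> None:
    presidio_config = PresidioPerRequestConfig(language="fr")
    messages = [{"content": "Bonjour"}, {"content": "Je m'appelle Marie"}]
    # One task per string message, all sharing the per-request config.
    tasks = [
        check_pii(text=m["content"], output_parse_pii=True, presidio_config=presidio_config)
        for m in messages
        if isinstance(m["content"], str)
    ]
    responses: List[str] = await asyncio.gather(*tasks)
    print(responses)


asyncio.run(main())
```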
@@ -317,6 +334,8 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
         if messages is None:
             return kwargs, result
 
+        presidio_config = self.get_presidio_settings_from_request_data(kwargs)
+
         for m in messages:
             text_str = ""
             if m["content"] is None:
@@ -324,7 +343,11 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
             if isinstance(m["content"], str):
                 text_str = m["content"]
                 tasks.append(
-                    self.check_pii(text=text_str, output_parse_pii=False)
+                    self.check_pii(
+                        text=text_str,
+                        output_parse_pii=False,
+                        presidio_config=presidio_config,
+                    )
                 )  # need to pass separately b/c presidio has context window limits
         responses = await asyncio.gather(*tasks)
         for index, r in enumerate(responses):
@@ -366,3 +389,22 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
                         0
                     ].message.content.replace(key, value)
         return response
+
+    def get_presidio_settings_from_request_data(
+        self, data: dict
+    ) -> Optional[PresidioPerRequestConfig]:
+        if "metadata" in data:
+            _metadata = data["metadata"]
+            _guardrail_config = _metadata.get("guardrail_config")
+            _presidio_config = PresidioPerRequestConfig(**_guardrail_config)
+            return _presidio_config
+
+        return None
+
+    def print_verbose(self, print_statement):
+        try:
+            verbose_proxy_logger.debug(print_statement)
+            if litellm.set_verbose:
+                print(print_statement)  # noqa
+        except:
+            pass
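The new `get_presidio_settings_from_request_data` reads `guardrail_config` out of request metadata. The sketch below reproduces that lookup standalone and adds an extra None-guard (an assumption, not present in the diff; as written, the diff calls `PresidioPerRequestConfig(**_guardrail_config)` unconditionally, so a metadata block without `guardrail_config` would raise):

```python
from typing import Optional

from pydantic import BaseModel


class PresidioPerRequestConfig(BaseModel):
    language: Optional[str] = None


def get_presidio_settings_from_request_data(data: dict) -> Optional[PresidioPerRequestConfig]:
    # Same lookup as the new method above, with a None-guard so a request
    # whose metadata has no guardrail_config falls back to defaults.
    _metadata = data.get("metadata")
    if _metadata:
        _guardrail_config = _metadata.get("guardrail_config")
        if _guardrail_config:
            return PresidioPerRequestConfig(**_guardrail_config)
    return None


print(get_presidio_settings_from_request_data(
    {"metadata": {"guardrail_config": {"language": "es"}}}
))  # language='es'
print(get_presidio_settings_from_request_data({"metadata": {}}))  # None
```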
The second changed file adds guardrail_config handling to move_guardrails_to_metadata:

@@ -420,6 +420,10 @@ def move_guardrails_to_metadata(
         data[_metadata_variable_name]["guardrails"] = data["guardrails"]
         del data["guardrails"]
 
+    if "guardrail_config" in data:
+        data[_metadata_variable_name]["guardrail_config"] = data["guardrail_config"]
+        del data["guardrail_config"]
+
 
 def add_provider_specific_headers_to_request(
     data: dict,
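Taken together, a request can carry a top-level `guardrail_config` that the proxy moves under `metadata` before the guardrail runs. A simplified, hypothetical sketch of that flow (the function name and the `setdefault` call are illustrative; the real `move_guardrails_to_metadata` handles more fields):

```python
def move_guardrail_config_to_metadata(data: dict, _metadata_variable_name: str = "metadata") -> dict:
    # Relocate a top-level guardrail_config under metadata so the Presidio
    # hook can pick it up per request.
    data.setdefault(_metadata_variable_name, {})
    if "guardrail_config" in data:
        data[_metadata_variable_name]["guardrail_config"] = data["guardrail_config"]
        del data["guardrail_config"]
    return data


request_body = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Hola, soy Ana"}],
    "guardrail_config": {"language": "es"},
}
print(move_guardrail_config_to_metadata(request_body)["metadata"])
# {'guardrail_config': {'language': 'es'}}
```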