litellm-mirror/litellm/proxy/guardrails/guardrail_hooks/presidio.py
Krish Dholakia 234185ec13
LiteLLM Minor Fixes & Improvements (09/16/2024) (#5723) (#5731)
* LiteLLM Minor Fixes & Improvements (09/16/2024)  (#5723)

* coverage (#5713)

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Move (#5714)

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* fix(litellm_logging.py): fix logging client re-init (#5710)

Fixes https://github.com/BerriAI/litellm/issues/5695

* fix(presidio.py): Fix logging_hook response and add support for additional presidio variables in guardrails config

Fixes https://github.com/BerriAI/litellm/issues/5682

* feat(o1_handler.py): fake streaming for openai o1 models

Fixes https://github.com/BerriAI/litellm/issues/5694

* docs: deprecated traceloop integration in favor of native otel (#5249)

* fix: fix linting errors

* fix: fix linting errors

* fix(main.py): fix o1 import

---------

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>

* feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view (#5730)

* feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view

Supports having `MonthlyGlobalSpend` view be a material view, and exposes an endpoint to refresh it

* fix(custom_logger.py): reset calltype

* fix: fix linting errors

* fix: fix linting error

* fix: fix import

* test(test_databricks.py): fix databricks tests

---------

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>
2024-09-17 08:05:52 -07:00

341 lines
13 KiB
Python

# +-----------------------------------------------+
# | |
# | PII Masking |
# | with Microsoft Presidio |
# | https://github.com/BerriAI/litellm/issues/ |
# +-----------------------------------------------+
#
# Tell us how we can improve! - Krrish & Ishaan
import asyncio
import json
import traceback
import uuid
from typing import Any, List, Optional, Tuple, Union
import aiohttp
from fastapi import HTTPException
from pydantic import BaseModel
import litellm # noqa: E401
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.utils import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
StreamingChoices,
get_formatted_prompt,
)
class PresidioPerRequestConfig(BaseModel):
    """
    Presidio parameters that can be controlled per request / per API key.

    Attributes:
        language: language code forwarded to the Presidio analyzer.
            When unset, the guardrail falls back to its own default.
    """

    language: Optional[str] = None
class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
    """
    Guardrail that masks PII in chat messages using Microsoft Presidio.

    Text is sent to the Presidio analyzer (`/analyze`) and the analyzer results
    are then sent to the Presidio anonymizer (`/anonymize`). The anonymized text
    replaces the original message content before the LLM call
    (`async_pre_call_hook`) and before logging (`async_logging_hook`).

    When `output_parse_pii` is enabled, the tokens produced by Presidio's
    `replace` operator are remembered in `self.pii_tokens` so that
    `async_post_call_success_hook` can swap them back into the LLM response.
    """

    # Class variables or attributes
    user_api_key_cache = None
    ad_hoc_recognizers = None

    def __init__(
        self,
        mock_testing: bool = False,
        mock_redacted_text: Optional[dict] = None,
        presidio_analyzer_api_base: Optional[str] = None,
        presidio_anonymizer_api_base: Optional[str] = None,
        output_parse_pii: Optional[bool] = False,
        presidio_ad_hoc_recognizers: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            mock_testing: when True, skip recognizer loading, environment
                validation and parent initialisation (testing only).
            mock_redacted_text: canned anonymizer response; when set, no HTTP
                request is made by `check_pii`.
            presidio_analyzer_api_base: analyzer base url; falls back to the
                `PRESIDIO_ANALYZER_API_BASE` secret.
            presidio_anonymizer_api_base: anonymizer base url; falls back to
                the `PRESIDIO_ANONYMIZER_API_BASE` secret.
            output_parse_pii: when truthy, masked tokens are swapped back to
                the original values in the LLM response.
            presidio_ad_hoc_recognizers: path to a JSON file with ad-hoc
                recognizers that is forwarded to the analyzer.
            **kwargs: forwarded to `CustomGuardrail.__init__`.

        Raises:
            Exception: if the recognizers file cannot be read / parsed, or a
                Presidio api base is missing.
        """
        # mapping of PII token to original text - only used with Presidio `replace` operation
        # NOTE(review): this mapping lives on the guardrail instance, so it is
        # shared by every request handled by this instance and never cleared.
        self.pii_tokens: dict = {}
        self.mock_redacted_text = mock_redacted_text
        self.output_parse_pii = output_parse_pii or False

        if mock_testing is True:  # for testing purposes only
            # NOTE(review): the parent __init__ is intentionally not invoked
            # here to preserve the existing mock behaviour - confirm whether
            # CustomGuardrail attributes are needed in mock mode.
            return

        ad_hoc_recognizers = presidio_ad_hoc_recognizers
        if ad_hoc_recognizers is not None:
            try:
                with open(ad_hoc_recognizers, "r") as file:
                    self.ad_hoc_recognizers = json.load(file)
            except FileNotFoundError as e:
                raise Exception(
                    f"File not found. file_path={ad_hoc_recognizers}"
                ) from e
            except json.JSONDecodeError as e:
                raise Exception(
                    f"Error decoding JSON file: {str(e)}, file_path={ad_hoc_recognizers}"
                ) from e
            except Exception as e:
                raise Exception(
                    f"An error occurred: {str(e)}, file_path={ad_hoc_recognizers}"
                ) from e

        self.validate_environment(
            presidio_analyzer_api_base=presidio_analyzer_api_base,
            presidio_anonymizer_api_base=presidio_anonymizer_api_base,
        )

        super().__init__(**kwargs)

    @staticmethod
    def _normalize_api_base(api_base: str) -> str:
        """Return `api_base` with a trailing `/` and an explicit url scheme."""
        if not api_base.endswith("/"):
            api_base += "/"
        if not api_base.startswith(("http://", "https://")):
            # add http:// if unset, assume communicating over private network - e.g. render
            api_base = "http://" + api_base
        return api_base

    def validate_environment(
        self,
        presidio_analyzer_api_base: Optional[str] = None,
        presidio_anonymizer_api_base: Optional[str] = None,
    ):
        """
        Resolve and normalise the Presidio analyzer / anonymizer base urls.

        Explicit arguments win; otherwise the `PRESIDIO_ANALYZER_API_BASE` /
        `PRESIDIO_ANONYMIZER_API_BASE` secrets are used. The resolved values
        are stored on `self`.

        Raises:
            Exception: if either base url cannot be resolved.
        """
        self.presidio_analyzer_api_base: Optional[str] = (
            presidio_analyzer_api_base or get_secret("PRESIDIO_ANALYZER_API_BASE", None)  # type: ignore
        )
        self.presidio_anonymizer_api_base: Optional[str] = (
            presidio_anonymizer_api_base or get_secret("PRESIDIO_ANONYMIZER_API_BASE", None)  # type: ignore
        )

        if self.presidio_analyzer_api_base is None:
            raise Exception("Missing `PRESIDIO_ANALYZER_API_BASE` from environment")
        self.presidio_analyzer_api_base = self._normalize_api_base(
            self.presidio_analyzer_api_base
        )

        if self.presidio_anonymizer_api_base is None:
            raise Exception("Missing `PRESIDIO_ANONYMIZER_API_BASE` from environment")
        self.presidio_anonymizer_api_base = self._normalize_api_base(
            self.presidio_anonymizer_api_base
        )

    async def _request_redaction(
        self,
        text: str,
        presidio_config: Optional[PresidioPerRequestConfig],
    ) -> Any:
        """Call Presidio `/analyze` then `/anonymize` and return the anonymizer's JSON body."""
        # Construct Request 1 (/analyze)
        analyze_url = f"{self.presidio_analyzer_api_base}analyze"
        analyze_payload: dict = {"text": text, "language": "en"}
        if presidio_config and presidio_config.language:
            analyze_payload["language"] = presidio_config.language
        if self.ad_hoc_recognizers is not None:
            analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers
        anonymize_url = f"{self.presidio_anonymizer_api_base}anonymize"

        async with aiohttp.ClientSession() as session:
            verbose_proxy_logger.debug(
                "Making request to: %s with payload: %s",
                analyze_url,
                analyze_payload,
            )
            async with session.post(analyze_url, json=analyze_payload) as response:
                analyze_results = await response.json()

            # Make the second request to /anonymize
            verbose_proxy_logger.debug("Making request to: %s", anonymize_url)
            anonymize_payload = {
                "text": text,
                "analyzer_results": analyze_results,
            }
            async with session.post(
                anonymize_url, json=anonymize_payload
            ) as response:
                return await response.json()

    def _record_pii_tokens(self, text: str, items: List) -> None:
        """Store `replacement token -> replaced text` in `self.pii_tokens` for `replace` items."""
        new_text = text
        for item in items:
            start = item["start"]
            end = item["end"]
            replacement = item["text"]  # replacement token
            if item["operator"] == "replace":
                # if the token is already known, add a uuid to the replacement
                # token so it can be swapped back to the right original text
                # NOTE(review): the uuid-suffixed token only exists in the local
                # `new_text`; the text returned by `check_pii` is the anonymizer's
                # own output - confirm this is intended.
                if replacement in self.pii_tokens:
                    replacement = replacement + str(uuid.uuid4())
                self.pii_tokens[replacement] = new_text[
                    start:end
                ]  # get text it'll replace
            new_text = new_text[:start] + replacement + new_text[end:]

    async def check_pii(
        self,
        text: str,
        output_parse_pii: bool,
        presidio_config: Optional[PresidioPerRequestConfig],
    ) -> str:
        """
        Return `text` with PII anonymized by Presidio.

        [TODO] make this more performant for high-throughput scenario

        Args:
            text: the text to anonymize.
            output_parse_pii: when truthy, remember `replace` tokens in
                `self.pii_tokens` for later output parsing.
            presidio_config: optional per-request settings (e.g. language).

        Returns:
            The `text` field of the anonymizer response.

        Raises:
            Exception: if the anonymizer response is not a mapping with a
                `text` field.
        """
        if self.mock_redacted_text is not None:
            # no HTTP session is needed when the response is mocked
            redacted_text: Any = self.mock_redacted_text
        else:
            redacted_text = await self._request_redaction(
                text=text, presidio_config=presidio_config
            )

        if not isinstance(redacted_text, dict) or "text" not in redacted_text:
            raise Exception(f"Invalid anonymizer response: {redacted_text}")

        verbose_proxy_logger.debug("redacted_text: %s", redacted_text)
        if output_parse_pii:
            self._record_pii_tokens(text, redacted_text.get("items") or [])
        return redacted_text["text"]

    @staticmethod
    def _message_text(message: Any) -> Optional[str]:
        """Return the message's `content` when it is a string, else None."""
        try:
            content = message["content"]
        except (KeyError, AttributeError):
            return None
        return content if isinstance(content, str) else None

    async def _redact_messages(
        self,
        messages: List,
        output_parse_pii: bool,
        presidio_config: Optional[PresidioPerRequestConfig],
    ) -> None:
        """Replace every string `content` in `messages` (in place) with its redacted form."""
        # Remember which message each task belongs to: messages with non-string
        # content are skipped, so task position != message position.
        text_indices = [
            index
            for index, message in enumerate(messages)
            if self._message_text(message) is not None
        ]
        # each message is sent separately b/c presidio has context window limits
        redacted_texts = await asyncio.gather(
            *(
                self.check_pii(
                    text=messages[index]["content"],
                    output_parse_pii=output_parse_pii,
                    presidio_config=presidio_config,
                )
                for index in text_indices
            )
        )
        for index, redacted in zip(text_indices, redacted_texts):
            messages[index]["content"] = redacted  # replace content with redacted string

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        """
        - Check if request turned off pii
        - Check if user allowed to turn off pii (key permissions -> 'allow_pii_controls')

        - Take the request data
        - Call /analyze -> get the results
        - Call /anonymize w/ the analyze results -> get the redacted text

        For multiple messages in /chat/completions, the calls are made in parallel.

        Returns:
            `data`, with string message contents redacted in place when
            `call_type == "completion"`.
        """
        content_safety = data.get("content_safety", None)
        verbose_proxy_logger.debug("content_safety: %s", content_safety)
        presidio_config = self.get_presidio_settings_from_request_data(data)

        if call_type == "completion":  # /chat/completions requests
            await self._redact_messages(
                data["messages"],
                output_parse_pii=self.output_parse_pii,
                presidio_config=presidio_config,
            )
            verbose_proxy_logger.info(
                "Presidio PII Masking: Redacted pii message: %s", data["messages"]
            )
        return data

    async def async_logging_hook(
        self, kwargs: dict, result: Any, call_type: str
    ) -> Tuple[dict, Any]:
        """
        Masks the input before logging to langfuse, datadog, etc.

        Returns:
            `(kwargs, result)`; for completion calls `kwargs["messages"]` has
            its string contents redacted. `result` is returned unchanged.
        """
        if call_type in ("completion", "acompletion"):  # /chat/completions requests
            messages: Optional[List] = kwargs.get("messages", None)
            if messages is None:
                return kwargs, result

            presidio_config = self.get_presidio_settings_from_request_data(kwargs)
            await self._redact_messages(
                messages,
                output_parse_pii=False,
                presidio_config=presidio_config,
            )
            verbose_proxy_logger.info(
                "Presidio PII Masking: Redacted pii message: %s", messages
            )
            kwargs["messages"] = messages

        return kwargs, result

    async def async_post_call_success_hook(  # type: ignore
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
    ):
        """
        Output parse the response object to replace the masked tokens with user sent values

        Only non-streaming chat completion choices with string content are
        touched; every other response is returned unchanged.
        """
        verbose_proxy_logger.debug(
            f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}"
        )
        if self.output_parse_pii is False:
            return response

        if isinstance(response, ModelResponse):  # /chat/completions requests
            # handle every choice (n > 1), and tolerate an empty `choices` list
            for choice in response.choices:
                if isinstance(choice, StreamingChoices):
                    continue
                content = choice.message.content
                if not isinstance(content, str):
                    continue
                verbose_proxy_logger.debug(
                    "self.pii_tokens: %s; initial response: %s",
                    self.pii_tokens,
                    content,
                )
                for token, original_text in self.pii_tokens.items():
                    content = content.replace(token, original_text)
                choice.message.content = content
        return response

    def get_presidio_settings_from_request_data(
        self, data: dict
    ) -> Optional[PresidioPerRequestConfig]:
        """
        Build the per-request presidio config from `data["metadata"]["guardrail_config"]`.

        Returns:
            A `PresidioPerRequestConfig`, or None when the request carries no
            (or empty / null) metadata or guardrail config.
        """
        _metadata = data.get("metadata")
        if not _metadata:  # missing, None or empty
            return None
        _guardrail_config = _metadata.get("guardrail_config")
        if _guardrail_config:
            return PresidioPerRequestConfig(**_guardrail_config)
        return None

    def print_verbose(self, print_statement):
        """Best-effort debug output: log at debug level and print when `litellm.set_verbose` is on."""
        try:
            verbose_proxy_logger.debug(print_statement)
            if litellm.set_verbose:
                print(print_statement)  # noqa
        except Exception:
            # deliberately swallowed - diagnostics must never break the request
            pass