From 6cec252b076012fd3ea5a8b06d603e05c38dd789 Mon Sep 17 00:00:00 2001
From: Lunik
Date: Thu, 2 May 2024 23:12:48 +0200
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20Add=20Azure=20Content-Safet?=
 =?UTF-8?q?y=20Proxy=20hooks?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Lunik
---
 litellm/proxy/hooks/azure_content_safety.py | 157 ++++++++++++++++++++
 litellm/proxy/proxy_server.py               |  17 +++
 2 files changed, 174 insertions(+)
 create mode 100644 litellm/proxy/hooks/azure_content_safety.py

diff --git a/litellm/proxy/hooks/azure_content_safety.py b/litellm/proxy/hooks/azure_content_safety.py
new file mode 100644
index 000000000..161e35cde
--- /dev/null
+++ b/litellm/proxy/hooks/azure_content_safety.py
@@ -0,0 +1,157 @@
+from typing import Optional
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+import litellm, traceback
+from fastapi import HTTPException
+from litellm._logging import verbose_proxy_logger
+
+
+class _PROXY_AzureContentSafety(
+    CustomLogger
+):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
+
+    def __init__(self, endpoint, api_key, thresholds=None):
+        try:
+            from azure.ai.contentsafety.aio import ContentSafetyClient
+            from azure.core.credentials import AzureKeyCredential
+            from azure.ai.contentsafety.models import TextCategory
+            from azure.ai.contentsafety.models import AnalyzeTextOptions
+            from azure.core.exceptions import HttpResponseError
+        except Exception as e:
+            raise Exception(
+                f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
+            )
+        self.endpoint = endpoint
+        self.api_key = api_key
+        self.text_category = TextCategory
+        self.analyze_text_options = AnalyzeTextOptions
+        self.azure_http_error = HttpResponseError
+
+        self.thresholds = self._configure_thresholds(thresholds)
+
+        self.client = ContentSafetyClient(
+            self.endpoint, AzureKeyCredential(self.api_key)
+        )
+
+    def _configure_thresholds(self, thresholds=None):
+        default_thresholds = {
+            self.text_category.HATE: 6,
+            self.text_category.SELF_HARM: 6,
+            self.text_category.SEXUAL: 6,
+            self.text_category.VIOLENCE: 6,
+        }
+
+        if thresholds is None:
+            return default_thresholds
+
+        for key, default in default_thresholds.items():
+            if key not in thresholds:
+                thresholds[key] = default
+
+        return thresholds
+
+    def print_verbose(self, print_statement):
+        try:
+            verbose_proxy_logger.debug(print_statement)
+            if litellm.set_verbose:
+                print(print_statement)  # noqa
+        except Exception:
+            pass
+
+    def _severity(self, severity):
+        if severity >= 6:
+            return "high"
+        elif severity >= 4:
+            return "medium"
+        elif severity >= 2:
+            return "low"
+        else:
+            return "safe"
+
+    def _compute_result(self, response):
+        result = {}
+
+        category_severity = {
+            item.category: item.severity for item in response.categories_analysis
+        }
+        for category in self.text_category:
+            severity = category_severity.get(category)
+            if severity is not None:
+                result[category] = {
+                    "filtered": severity >= self.thresholds[category],
+                    "severity": self._severity(severity),
+                }
+
+        return result
+
+    async def test_violation(self, content: str, source: Optional[str] = None):
+        self.print_verbose(f"Testing Azure Content-Safety for: {content}")
+
+        # Construct a request
+        request = self.analyze_text_options(text=content)
+
+        # Analyze text
+        try:
+            response = await self.client.analyze_text(request)
+        except self.azure_http_error:
+            self.print_verbose(
+                f"Error in Azure Content-Safety: {traceback.format_exc()}"
+            )
+            traceback.print_exc()
+            raise
+
+        result = self._compute_result(response)
+        self.print_verbose(f"Azure Content-Safety Result: {result}")
+
+        for key, value in result.items():
+            if value["filtered"]:
+                raise HTTPException(
+                    status_code=400,
+                    detail={
+                        "error": "Violated content safety policy",
+                        "source": source,
+                        "category": key,
+                        "severity": value["severity"],
+                    },
+                )
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
+    ):
+        self.print_verbose("Inside Azure Content-Safety Pre-Call Hook")
+        try:
+            if call_type == "completion" and "messages" in data:
+                for m in data["messages"]:
+                    if "content" in m and isinstance(m["content"], str):
+                        await self.test_violation(content=m["content"], source="input")
+
+        except HTTPException:
+            raise
+        except Exception:
+            traceback.print_exc()
+
+    async def async_post_call_success_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response,
+    ):
+        self.print_verbose("Inside Azure Content-Safety Post-Call Hook")
+        if isinstance(response, litellm.ModelResponse) and isinstance(
+            response.choices[0], litellm.utils.Choices
+        ):
+            await self.test_violation(
+                content=response.choices[0].message.content, source="output"
+            )
+
+    async def async_post_call_streaming_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: str,
+    ):
+        self.print_verbose("Inside Azure Content-Safety Call-Stream Hook")
+        await self.test_violation(content=response, source="output")
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 9cc871966..9c74659cc 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -2235,6 +2235,23 @@ class ProxyConfig:
                         batch_redis_obj = _PROXY_BatchRedisRequests()
                         imported_list.append(batch_redis_obj)
+                    elif (
+                        isinstance(callback, str)
+                        and callback == "azure_content_safety"
+                    ):
+                        from litellm.proxy.hooks.azure_content_safety import (
+                            _PROXY_AzureContentSafety,
+                        )
+
+                        azure_content_safety_params = litellm_settings["azure_content_safety_params"]
+                        for k, v in azure_content_safety_params.items():
+                            if isinstance(v, str) and v.startswith("os.environ/"):
+                                azure_content_safety_params[k] = litellm.get_secret(v)
+
+                        azure_content_safety_obj = _PROXY_AzureContentSafety(
+                            **azure_content_safety_params,
+                        )
+                        imported_list.append(azure_content_safety_obj)
                     else:
                         imported_list.append(
                             get_instance_fn(
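
Beyond the diff itself: the new ProxyConfig branch fires when the string "azure_content_safety" appears among the proxy's configured callbacks, reads azure_content_safety_params out of litellm_settings, and resolves any "os.environ/..." values through litellm.get_secret() before constructing the hook. The sketch below reproduces that wiring as standalone Python; the config fragment in the comment and the environment-variable names are illustrative assumptions, not part of the patch.

# Minimal sketch of the wiring performed by the proxy_server.py hunk above.
# The config fragment in this comment is an assumed shape, mirroring how the
# other string callbacks in this block are enabled:
#
#   litellm_settings:
#     callbacks: ["azure_content_safety"]
#     azure_content_safety_params:
#       endpoint: "os.environ/CONTENT_SAFETY_ENDPOINT"
#       api_key: "os.environ/CONTENT_SAFETY_API_KEY"

import litellm
from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety

# What litellm_settings holds once the (hypothetical) config above is parsed.
litellm_settings = {
    "azure_content_safety_params": {
        "endpoint": "os.environ/CONTENT_SAFETY_ENDPOINT",  # env var must exist
        "api_key": "os.environ/CONTENT_SAFETY_API_KEY",  # env var must exist
    }
}

params = litellm_settings["azure_content_safety_params"]
for k, v in params.items():
    # "os.environ/<VAR>" values are swapped for the secret they point at,
    # exactly as in the loop added to ProxyConfig.
    if isinstance(v, str) and v.startswith("os.environ/"):
        params[k] = litellm.get_secret(v)

azure_content_safety_obj = _PROXY_AzureContentSafety(**params)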
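
The hook can also be exercised on its own, which is handy for tuning thresholds before putting it in front of traffic. A minimal sketch, assuming azure-ai-contentsafety is installed and the endpoint/key placeholders are swapped for real values; the lowered VIOLENCE threshold is only there to show _configure_thresholds back-filling the other categories with their default of 6:

import asyncio

from azure.ai.contentsafety.models import TextCategory
from fastapi import HTTPException

from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety


async def main() -> None:
    hook = _PROXY_AzureContentSafety(
        endpoint="https://my-resource.cognitiveservices.azure.com",  # placeholder
        api_key="my-content-safety-key",  # placeholder
        thresholds={TextCategory.VIOLENCE: 4},  # HATE/SELF_HARM/SEXUAL stay at 6
    )
    try:
        await hook.test_violation(content="example user prompt", source="input")
        print("content is below every configured threshold")
    except HTTPException as err:
        # test_violation raises a 400 whose detail names the source ("input" /
        # "output"), the offending category, and the severity bucket.
        print(err.detail)


asyncio.run(main())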