diff --git a/docs/my-website/docs/proxy/logging.md b/docs/my-website/docs/proxy/logging.md
index 932ea4f57..2c03cf728 100644
--- a/docs/my-website/docs/proxy/logging.md
+++ b/docs/my-website/docs/proxy/logging.md
@@ -3,7 +3,7 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';
 
-# 🔎 Logging - Custom Callbacks, DataDog, Langfuse, s3 Bucket, Sentry, OpenTelemetry, Athina
+# 🔎 Logging - Custom Callbacks, DataDog, Langfuse, s3 Bucket, Sentry, OpenTelemetry, Athina, Azure Content-Safety
 
 Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry, LangFuse, DynamoDB, s3 Bucket
 
@@ -17,6 +17,7 @@ Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTeleme
 - [Logging to Sentry](#logging-proxy-inputoutput---sentry)
 - [Logging to Traceloop (OpenTelemetry)](#logging-proxy-inputoutput-traceloop-opentelemetry)
 - [Logging to Athina](#logging-proxy-inputoutput-athina)
+- [Moderation with Azure Content-Safety](#moderation-with-azure-content-safety)
 
 ## Custom Callback Class [Async]
 Use this when you want to run custom callbacks in `python`
@@ -1037,3 +1038,86 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
     ]
 }'
 ```
+
+## Moderation with Azure Content Safety
+
+[Azure Content Safety](https://azure.microsoft.com/en-us/products/ai-services/ai-content-safety) is a Microsoft Azure service that provides content-moderation APIs to detect potentially offensive, harmful, or risky content in text.
+
+We will use the `--config` to set `litellm_settings.callbacks = ["azure_content_safety"]`. This moderates all LLM calls with Azure Content Safety.
+
+**Step 0** Deploy Azure Content Safety
+
+Deploy an Azure Content Safety instance from the Azure Portal and get the `endpoint` and `key`.
+
+**Step 1** Set the Azure Content Safety key
+
+```shell
+AZURE_CONTENT_SAFETY_KEY=""
+```
+
+**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `callbacks`
+
+```yaml
+model_list:
+  - model_name: gpt-3.5-turbo
+    litellm_params:
+      model: gpt-3.5-turbo
+litellm_settings:
+  callbacks: ["azure_content_safety"]
+  azure_content_safety_params:
+    endpoint: ""
+    key: "os.environ/AZURE_CONTENT_SAFETY_KEY"
+```
+
+**Step 3**: Start the proxy, make a test request
+
+Start proxy
+
+```shell
+litellm --config config.yaml --debug
+```
+
+Test Request
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+    --header 'Content-Type: application/json' \
+    --data ' {
+    "model": "gpt-3.5-turbo",
+    "messages": [
+        {
+        "role": "user",
+        "content": "Hi, how are you?"
+        }
+    ]
+    }'
+```
+
+An HTTP 400 error is returned if any message is flagged with a severity greater than or equal to the threshold set in the `config.yaml`.
+The error detail describes:
+- The `source`: whether the flagged text was the input or the LLM-generated output
+- The `category`: the content category that triggered the moderation
+- The `severity`: the severity level, from 0 to 7 (the hook requests eight severity levels)
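+
+For example, a client can branch on these fields. Below is a minimal sketch using Python's `requests`; the `detail` envelope assumes FastAPI's default `HTTPException` serialization, so adjust the parsing to whatever error shape your deployment actually returns:
+
+```python
+import requests
+
+resp = requests.post(
+    "http://0.0.0.0:4000/chat/completions",
+    json={
+        "model": "gpt-3.5-turbo",
+        "messages": [{"role": "user", "content": "Hi, how are you?"}],
+    },
+)
+if resp.status_code == 400:
+    # e.g. {"error": "Violated content safety policy", "source": "input",
+    #       "category": "Hate", "severity": 4}
+    detail = resp.json()["detail"]
+    print(detail["source"], detail["category"], detail["severity"])
+```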
+
+**Step 4**: Customizing Azure Content Safety Thresholds
+
+You can customize the threshold for each category by setting `thresholds` in the `config.yaml`
+
+```yaml
+model_list:
+  - model_name: gpt-3.5-turbo
+    litellm_params:
+      model: gpt-3.5-turbo
+litellm_settings:
+  callbacks: ["azure_content_safety"]
+  azure_content_safety_params:
+    endpoint: ""
+    key: "os.environ/AZURE_CONTENT_SAFETY_KEY"
+    thresholds:
+      Hate: 6
+      SelfHarm: 8
+      Sexual: 6
+      Violence: 4
+```
+
+:::info
+`thresholds` is optional; you can tune the values to your needs. The default value is `4` for every category.
+:::
\ No newline at end of file
diff --git a/litellm/proxy/hooks/azure_content_safety.py b/litellm/proxy/hooks/azure_content_safety.py
new file mode 100644
index 000000000..433571c15
--- /dev/null
+++ b/litellm/proxy/hooks/azure_content_safety.py
@@ -0,0 +1,146 @@
+from typing import Optional
+
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+import litellm, traceback
+from fastapi import HTTPException
+from litellm._logging import verbose_proxy_logger
+
+
+class _PROXY_AzureContentSafety(
+    CustomLogger
+):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
+    def __init__(self, endpoint, api_key, thresholds=None):
+        try:
+            from azure.ai.contentsafety.aio import ContentSafetyClient
+            from azure.core.credentials import AzureKeyCredential
+            from azure.ai.contentsafety.models import (
+                TextCategory,
+                AnalyzeTextOptions,
+                AnalyzeTextOutputType,
+            )
+            from azure.core.exceptions import HttpResponseError
+        except Exception as e:
+            raise Exception(
+                f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
+            )
+        self.endpoint = endpoint
+        self.api_key = api_key
+        # Keep references to the lazily imported Azure types
+        self.text_category = TextCategory
+        self.analyze_text_options = AnalyzeTextOptions
+        self.analyze_text_output_type = AnalyzeTextOutputType
+        self.azure_http_error = HttpResponseError
+
+        self.thresholds = self._configure_thresholds(thresholds)
+
+        self.client = ContentSafetyClient(
+            self.endpoint, AzureKeyCredential(self.api_key)
+        )
+
+    def _configure_thresholds(self, thresholds=None):
+        default_thresholds = {
+            self.text_category.HATE: 4,
+            self.text_category.SELF_HARM: 4,
+            self.text_category.SEXUAL: 4,
+            self.text_category.VIOLENCE: 4,
+        }
+
+        if thresholds is None:
+            return default_thresholds
+
+        # Fill in the default for any category the user did not configure
+        for key, default in default_thresholds.items():
+            if key not in thresholds:
+                thresholds[key] = default
+
+        return thresholds
+
+    def _compute_result(self, response):
+        result = {}
+
+        category_severity = {
+            item.category: item.severity for item in response.categories_analysis
+        }
+        for category in self.text_category:
+            severity = category_severity.get(category)
+            if severity is not None:
+                # A category is filtered when its severity meets or exceeds its threshold
+                result[category] = {
+                    "filtered": severity >= self.thresholds[category],
+                    "severity": severity,
+                }
+
+        return result
+
+    # Worked example (illustrative values): with thresholds {"Hate": 4} and an
+    # analysis where Hate comes back at severity 4, _compute_result returns
+    # {"Hate": {"filtered": True, "severity": 4}}, so test_violation below
+    # raises an HTTP 400.
+    async def test_violation(self, content: str, source: Optional[str] = None):
+        verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)
+
+        # Construct a request
+        request = self.analyze_text_options(
+            text=content,
+            output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
+        )
+
+        # Analyze text
+        try:
+            response = await self.client.analyze_text(request)
+        except self.azure_http_error:
+            verbose_proxy_logger.debug(
+                "Error in Azure Content-Safety: %s", traceback.format_exc()
+            )
+            raise
+
+        result = self._compute_result(response)
+        verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)
+
+        for key, value in result.items():
+            if value["filtered"]:
+                raise HTTPException(
+                    status_code=400,
+                    detail={
+                        "error": "Violated content safety policy",
+                        "source": source,
+                        "category": key,
+                        "severity": value["severity"],
+                    },
+                )
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
+    ):
+        verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
+        try:
+            if call_type == "completion" and "messages" in data:
+                for m in data["messages"]:
+                    if "content" in m and isinstance(m["content"], str):
+                        await self.test_violation(content=m["content"], source="input")
+
+        except HTTPException:
+            raise
+        except Exception:
+            # Fail open: unexpected errors are logged but do not block the request
+            traceback.print_exc()
+
+    async def async_post_call_success_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response,
+    ):
+        verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
+        if isinstance(response, litellm.ModelResponse) and isinstance(
+            response.choices[0], litellm.utils.Choices
+        ):
+            await self.test_violation(
+                content=response.choices[0].message.content, source="output"
+            )
+
+    # async def async_post_call_streaming_hook(
+    #     self,
+    #     user_api_key_dict: UserAPIKeyAuth,
+    #     response: str,
+    # ):
+    #     verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
+    #     await self.test_violation(content=response, source="output")
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 46c132773..779df7800 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -2255,6 +2255,23 @@ class ProxyConfig:
                         batch_redis_obj = _PROXY_BatchRedisRequests()
                         imported_list.append(batch_redis_obj)
+                    elif (
+                        isinstance(callback, str)
+                        and callback == "azure_content_safety"
+                    ):
+                        from litellm.proxy.hooks.azure_content_safety import (
+                            _PROXY_AzureContentSafety,
+                        )
+
+                        azure_content_safety_params = litellm_settings[
+                            "azure_content_safety_params"
+                        ]
+                        # Resolve "os.environ/<VAR>" references to their secret values
+                        for k, v in azure_content_safety_params.items():
+                            if isinstance(v, str) and v.startswith("os.environ/"):
+                                azure_content_safety_params[k] = litellm.get_secret(v)
+
+                        # Example resolved params (illustrative):
+                        #   {"endpoint": "https://<resource>.cognitiveservices.azure.com/",
+                        #    "key": "<resolved secret>",
+                        #    "thresholds": {"Hate": 6, "SelfHarm": 8, "Sexual": 6, "Violence": 4}}
+                        azure_content_safety_obj = _PROXY_AzureContentSafety(
+                            **azure_content_safety_params,
+                        )
+                        imported_list.append(azure_content_safety_obj)
                     else:
                         imported_list.append(
                             get_instance_fn(
diff --git a/litellm/tests/test_azure_content_safety.py b/litellm/tests/test_azure_content_safety.py
new file mode 100644
index 000000000..3cc31003a
--- /dev/null
+++ b/litellm/tests/test_azure_content_safety.py
@@ -0,0 +1,267 @@
+# What is this?
+## Unit tests for the Azure Content Safety moderation hook
+import sys, os, asyncio, time, random
+from datetime import datetime
+import traceback
+from dotenv import load_dotenv
+from fastapi import HTTPException
+
+load_dotenv()
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+import litellm
+from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
+from litellm import Router, mock_completion
+from litellm.proxy.utils import ProxyLogging
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.caching import DualCache
+
+
+@pytest.mark.asyncio
+async def test_strict_input_filtering_01():
+    """
+    - an offensive input with a strict threshold (Hate: 2)
+    - the pre-call hook must raise an HTTP 400
+    """
+    azure_content_safety = _PROXY_AzureContentSafety(
+        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
+        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
+        thresholds={"Hate": 2},
+    )
+
+    data = {
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant"},
+            {"role": "user", "content": "Fuck yourself you stupid bitch"},
+        ]
+    }
+
+    with pytest.raises(HTTPException) as exc_info:
+        await azure_content_safety.async_pre_call_hook(
+            user_api_key_dict=UserAPIKeyAuth(),
+            cache=DualCache(),
+            data=data,
+            call_type="completion",
+        )
+
+    assert exc_info.value.detail["source"] == "input"
+    assert exc_info.value.detail["category"] == "Hate"
+    assert exc_info.value.detail["severity"] == 2
+
+
+@pytest.mark.asyncio
+async def test_strict_input_filtering_02():
+    """
+    - a benign input with a strict threshold (Hate: 2)
+    - the pre-call hook must not raise
+    """
+    azure_content_safety = _PROXY_AzureContentSafety(
+        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
+        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
+        thresholds={"Hate": 2},
+    )
+
+    data = {
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant"},
+            {"role": "user", "content": "Hello, how are you?"},
+        ]
+    }
+
+    await azure_content_safety.async_pre_call_hook(
+        user_api_key_dict=UserAPIKeyAuth(),
+        cache=DualCache(),
+        data=data,
+        call_type="completion",
+    )
+
+
+@pytest.mark.asyncio
+async def test_loose_input_filtering_01():
+    """
+    - an offensive input with a loose threshold (Hate: 8)
+    - the pre-call hook must not raise
+    """
+    azure_content_safety = _PROXY_AzureContentSafety(
+        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
+        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
+        thresholds={"Hate": 8},
+    )
+
+    data = {
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant"},
+            {"role": "user", "content": "Fuck yourself you stupid bitch"},
+        ]
+    }
+
+    await azure_content_safety.async_pre_call_hook(
+        user_api_key_dict=UserAPIKeyAuth(),
+        cache=DualCache(),
+        data=data,
+        call_type="completion",
+    )
+
+
+@pytest.mark.asyncio
+async def test_loose_input_filtering_02():
+    """
+    - a benign input with a loose threshold (Hate: 8)
+    - the pre-call hook must not raise
+    """
+    azure_content_safety = _PROXY_AzureContentSafety(
+        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
+        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
+        thresholds={"Hate": 8},
+    )
+
+    data = {
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant"},
+            {"role": "user", "content": "Hello, how are you?"},
+        ]
+    }
+
+    await azure_content_safety.async_pre_call_hook(
+        user_api_key_dict=UserAPIKeyAuth(),
+        cache=DualCache(),
+        data=data,
+        call_type="completion",
+    )
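+
+
+# NOTE: these tests call a live Azure Content Safety resource; they assume
+# AZURE_CONTENT_SAFETY_ENDPOINT and AZURE_CONTENT_SAFETY_API_KEY are set.
+# To run only the input-filtering tests above, for example:
+#   pytest litellm/tests/test_azure_content_safety.py -k input_filtering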
+
+
+@pytest.mark.asyncio
+async def test_strict_output_filtering_01():
+    """
+    - a mocked offensive LLM output with a strict threshold (Hate: 2)
+    - the post-call hook must raise an HTTP 400
+    """
+    azure_content_safety = _PROXY_AzureContentSafety(
+        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
+        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
+        thresholds={"Hate": 2},
+    )
+
+    response = mock_completion(
+        model="gpt-3.5-turbo",
+        messages=[
+            {
+                "role": "system",
+                "content": "You are an expert songwriter. You help users write songs about any topic in any genre.",
+            },
+            {
+                "role": "user",
+                "content": "Help me write a rap text song. Add some insults to make it more credible.",
+            },
+        ],
+        mock_response="I'm the king of the mic, you're just a fucking dick. Don't fuck with me your stupid bitch.",
+    )
+
+    with pytest.raises(HTTPException) as exc_info:
+        await azure_content_safety.async_post_call_success_hook(
+            user_api_key_dict=UserAPIKeyAuth(), response=response
+        )
+
+    assert exc_info.value.detail["source"] == "output"
+    assert exc_info.value.detail["category"] == "Hate"
+    assert exc_info.value.detail["severity"] == 2
+
+
+@pytest.mark.asyncio
+async def test_strict_output_filtering_02():
+    """
+    - a mocked benign LLM output with a strict threshold (Hate: 2)
+    - the post-call hook must not raise
+    """
+    azure_content_safety = _PROXY_AzureContentSafety(
+        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
+        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
+        thresholds={"Hate": 2},
+    )
+
+    response = mock_completion(
+        model="gpt-3.5-turbo",
+        messages=[
+            {
+                "role": "system",
+                "content": "You are an expert songwriter. You help users write songs about any topic in any genre.",
+            },
+            {
+                "role": "user",
+                "content": "Help me write a rap text song. Add some insults to make it more credible.",
+            },
+        ],
+        mock_response="I'm unable to help you with hate speech",
+    )
+
+    await azure_content_safety.async_post_call_success_hook(
+        user_api_key_dict=UserAPIKeyAuth(), response=response
+    )
+
+
+@pytest.mark.asyncio
+async def test_loose_output_filtering_01():
+    """
+    - a mocked offensive LLM output with a loose threshold (Hate: 8)
+    - the post-call hook must not raise
+    """
+    azure_content_safety = _PROXY_AzureContentSafety(
+        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
+        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
+        thresholds={"Hate": 8},
+    )
+
+    response = mock_completion(
+        model="gpt-3.5-turbo",
+        messages=[
+            {
+                "role": "system",
+                "content": "You are an expert songwriter. You help users write songs about any topic in any genre.",
+            },
+            {
+                "role": "user",
+                "content": "Help me write a rap text song. Add some insults to make it more credible.",
+            },
+        ],
+        mock_response="I'm the king of the mic, you're just a fucking dick. Don't fuck with me your stupid bitch.",
+    )
+
+    await azure_content_safety.async_post_call_success_hook(
+        user_api_key_dict=UserAPIKeyAuth(), response=response
+    )
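+
+
+# NOTE: with EIGHT_SEVERITY_LEVELS the API returns severities from 0 to 7, so a
+# threshold of 8 (as in the loose_* tests) can never be met and the hook is
+# expected to let everything through.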
+
+
+@pytest.mark.asyncio
+async def test_loose_output_filtering_02():
+    """
+    - a mocked benign LLM output with a loose threshold (Hate: 8)
+    - the post-call hook must not raise
+    """
+    azure_content_safety = _PROXY_AzureContentSafety(
+        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
+        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
+        thresholds={"Hate": 8},
+    )
+
+    response = mock_completion(
+        model="gpt-3.5-turbo",
+        messages=[
+            {
+                "role": "system",
+                "content": "You are an expert songwriter. You help users write songs about any topic in any genre.",
+            },
+            {
+                "role": "user",
+                "content": "Help me write a rap text song. Add some insults to make it more credible.",
+            },
+        ],
+        mock_response="I'm unable to help you with hate speech",
+    )
+
+    await azure_content_safety.async_post_call_success_hook(
+        user_api_key_dict=UserAPIKeyAuth(), response=response
+    )
diff --git a/requirements.txt b/requirements.txt
index 7740e0752..d16f0453f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -26,6 +26,8 @@ fastapi-sso==0.10.0 # admin UI, SSO
 pyjwt[crypto]==2.8.0
 python-multipart==0.0.9 # admin UI
 Pillow==10.3.0
+azure-ai-contentsafety==1.0.0 # for azure content safety
+azure-identity==1.15.0 # for azure content safety
 ### LITELLM PACKAGE DEPENDENCIES
 python-dotenv==1.0.0 # for env