Merge pull request #3407 from Lunik/feat/add-azure-content-filter

 feat: Add Azure Content-Safety Proxy hooks
Krish Dholakia 2024-05-11 09:30:46 -07:00 committed by GitHub
commit 54e768afa7
5 changed files with 517 additions and 1 deletions


@@ -3,7 +3,7 @@ import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# 🔎 Logging - Custom Callbacks, DataDog, Langfuse, s3 Bucket, Sentry, OpenTelemetry, Athina
# 🔎 Logging - Custom Callbacks, DataDog, Langfuse, s3 Bucket, Sentry, OpenTelemetry, Athina, Azure Content-Safety
Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry, LangFuse, DynamoDB, s3 Bucket
@@ -17,6 +17,7 @@ Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTeleme
- [Logging to Sentry](#logging-proxy-inputoutput---sentry)
- [Logging to Traceloop (OpenTelemetry)](#logging-proxy-inputoutput-traceloop-opentelemetry)
- [Logging to Athina](#logging-proxy-inputoutput-athina)
- [Moderation with Azure Content-Safety](#moderation-with-azure-content-safety)
## Custom Callback Class [Async]
Use this when you want to run custom callbacks in `python`
@@ -1037,3 +1038,86 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
]
}'
```
## Moderation with Azure Content Safety
[Azure Content-Safety](https://azure.microsoft.com/en-us/products/ai-services/ai-content-safety) is a Microsoft Azure service that provides content-moderation APIs to detect potentially offensive, harmful, or risky content in text.
We will use the `--config` to set `litellm.callbacks = ["azure_content_safety"]`. This will moderate all LLM calls using Azure Content Safety.
**Step 0**: Deploy Azure Content-Safety
Deploy an Azure Content-Safety instance from the Azure Portal and get the `endpoint` and `key`.
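For example, one way to provision this with the Azure CLI (a sketch with placeholder names; assumes Content-Safety is created as a Cognitive Services account of kind `ContentSafety`):

```shell
# Placeholder resource group / name / region — replace with your own
az cognitiveservices account create \
  --name my-content-safety \
  --resource-group my-resource-group \
  --kind ContentSafety \
  --sku S0 \
  --location eastus

# Fetch the endpoint and key used in the next steps
az cognitiveservices account show --name my-content-safety --resource-group my-resource-group --query "properties.endpoint"
az cognitiveservices account keys list --name my-content-safety --resource-group my-resource-group --query "key1"
```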
**Step 1**: Set Azure Content-Safety API key
```shell
AZURE_CONTENT_SAFETY_KEY="<your-azure-content-safety-key>"
```
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `callbacks`
```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo

litellm_settings:
  callbacks: ["azure_content_safety"]
  azure_content_safety_params:
    endpoint: "<your-azure-content-safety-endpoint>"
    key: "os.environ/AZURE_CONTENT_SAFETY_KEY"
```
**Step 3**: Start the proxy, make a test request
Start proxy
```shell
litellm --config config.yaml --debug
```
Test Request
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "Hi, how are you?"
}
]
}'
```
An HTTP 400 error is returned if any category's severity is greater than or equal to the threshold set in the `config.yaml`.
The `detail` field of the error describes:
- `source`: whether the flagged text came from the request input or the LLM-generated output
- `category`: the content category that triggered the moderation
- `severity`: the severity score (0-7 when using the eight-severity-level output this hook requests)
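For illustration, a blocked request could come back with a body roughly like the following (the exact envelope depends on the proxy's error handling; the inner fields mirror the `HTTPException` detail raised by the hook):

```json
{
  "detail": {
    "error": "Violated content safety policy",
    "source": "input",
    "category": "Hate",
    "severity": 4
  }
}
```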
**Step 4**: Customizing Azure Content Safety Thresholds
You can customize the threshold for each category by setting `thresholds` in the `config.yaml`.
```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo

litellm_settings:
  callbacks: ["azure_content_safety"]
  azure_content_safety_params:
    endpoint: "<your-azure-content-safety-endpoint>"
    key: "os.environ/AZURE_CONTENT_SAFETY_KEY"
    thresholds:
      Hate: 6
      SelfHarm: 8
      Sexual: 6
      Violence: 4
```
:::info
`thresholds` is optional; you can tune the values to your needs.
The default value is `4` for all categories.
:::


@@ -0,0 +1,146 @@
from litellm.integrations.custom_logger import CustomLogger
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
import litellm, traceback, sys, uuid
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger


class _PROXY_AzureContentSafety(
    CustomLogger
):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    # Class variables or attributes

    def __init__(self, endpoint, api_key, thresholds=None):
        try:
            from azure.ai.contentsafety.aio import ContentSafetyClient
            from azure.core.credentials import AzureKeyCredential
            from azure.ai.contentsafety.models import (
                TextCategory,
                AnalyzeTextOptions,
                AnalyzeTextOutputType,
            )
            from azure.core.exceptions import HttpResponseError
        except Exception as e:
            raise Exception(
                f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
            )
        self.endpoint = endpoint
        self.api_key = api_key
        self.text_category = TextCategory
        self.analyze_text_options = AnalyzeTextOptions
        self.analyze_text_output_type = AnalyzeTextOutputType
        self.azure_http_error = HttpResponseError

        self.thresholds = self._configure_thresholds(thresholds)

        self.client = ContentSafetyClient(
            self.endpoint, AzureKeyCredential(self.api_key)
        )

    def _configure_thresholds(self, thresholds=None):
        default_thresholds = {
            self.text_category.HATE: 4,
            self.text_category.SELF_HARM: 4,
            self.text_category.SEXUAL: 4,
            self.text_category.VIOLENCE: 4,
        }

        if thresholds is None:
            return default_thresholds

        for key, default in default_thresholds.items():
            if key not in thresholds:
                thresholds[key] = default

        return thresholds

    def _compute_result(self, response):
        result = {}

        category_severity = {
            item.category: item.severity for item in response.categories_analysis
        }
        for category in self.text_category:
            severity = category_severity.get(category)
            if severity is not None:
                result[category] = {
                    "filtered": severity >= self.thresholds[category],
                    "severity": severity,
                }

        return result

    async def test_violation(self, content: str, source: str = None):
        verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)

        # Construct a request
        request = self.analyze_text_options(
            text=content,
            output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
        )

        # Analyze text
        try:
            response = await self.client.analyze_text(request)
        except self.azure_http_error as e:
            verbose_proxy_logger.debug(
                "Error in Azure Content-Safety: %s", traceback.format_exc()
            )
            traceback.print_exc()
            raise

        result = self._compute_result(response)
        verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)

        for key, value in result.items():
            if value["filtered"]:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated content safety policy",
                        "source": source,
                        "category": key,
                        "severity": value["severity"],
                    },
                )

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
    ):
        verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
        try:
            if call_type == "completion" and "messages" in data:
                for m in data["messages"]:
                    if "content" in m and isinstance(m["content"], str):
                        await self.test_violation(content=m["content"], source="input")
        except HTTPException as e:
            raise e
        except Exception as e:
            traceback.print_exc()

    async def async_post_call_success_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
        if isinstance(response, litellm.ModelResponse) and isinstance(
            response.choices[0], litellm.utils.Choices
        ):
            await self.test_violation(
                content=response.choices[0].message.content, source="output"
            )

    # async def async_post_call_streaming_hook(
    #     self,
    #     user_api_key_dict: UserAPIKeyAuth,
    #     response: str,
    # ):
    #     verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
    #     await self.test_violation(content=response, source="output")
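For reference, a minimal standalone sketch of driving this hook outside the proxy (illustrative only, not part of this PR; it assumes the same `AZURE_CONTENT_SAFETY_ENDPOINT` / `AZURE_CONTENT_SAFETY_API_KEY` environment variables used by the unit tests below):

```python
# Standalone usage sketch — illustrative, not part of this PR.
import asyncio
import os

from fastapi import HTTPException

from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety


async def main():
    hook = _PROXY_AzureContentSafety(
        endpoint=os.environ["AZURE_CONTENT_SAFETY_ENDPOINT"],
        api_key=os.environ["AZURE_CONTENT_SAFETY_API_KEY"],
        thresholds={"Hate": 2},  # stricter than the default of 4
    )
    try:
        # Raises HTTPException(400) when any category's severity >= its threshold
        await hook.test_violation(content="Hello, how are you?", source="input")
        print("content passed moderation")
    except HTTPException as e:
        print("blocked:", e.detail)


asyncio.run(main())
```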


@@ -2255,6 +2255,23 @@ class ProxyConfig:
                        batch_redis_obj = _PROXY_BatchRedisRequests()
                        imported_list.append(batch_redis_obj)
                    elif (
                        isinstance(callback, str)
                        and callback == "azure_content_safety"
                    ):
                        from litellm.proxy.hooks.azure_content_safety import (
                            _PROXY_AzureContentSafety,
                        )

                        azure_content_safety_params = litellm_settings["azure_content_safety_params"]
                        for k, v in azure_content_safety_params.items():
                            if v is not None and isinstance(v, str) and v.startswith("os.environ/"):
                                azure_content_safety_params[k] = litellm.get_secret(v)

                        azure_content_safety_obj = _PROXY_AzureContentSafety(
                            **azure_content_safety_params,
                        )
                        imported_list.append(azure_content_safety_obj)

                    else:
                        imported_list.append(
                            get_instance_fn(


@@ -0,0 +1,267 @@
# What is this?
## Unit test for azure content safety
import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv
from fastapi import HTTPException

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm
from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache


@pytest.mark.asyncio
async def test_strict_input_filtering_01():
    """
    - have a request with a filtered input
    - call the pre call hook
    """
    azure_content_safety = _PROXY_AzureContentSafety(
        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
        thresholds={"Hate": 2},
    )

    data = {
        "messages": [
            {"role": "system", "content": "You are an helpfull assistant"},
            {"role": "user", "content": "Fuck yourself you stupid bitch"},
        ]
    }

    with pytest.raises(HTTPException) as exc_info:
        await azure_content_safety.async_pre_call_hook(
            user_api_key_dict=UserAPIKeyAuth(),
            cache=DualCache(),
            data=data,
            call_type="completion",
        )

    assert exc_info.value.detail["source"] == "input"
    assert exc_info.value.detail["category"] == "Hate"
    assert exc_info.value.detail["severity"] == 2


@pytest.mark.asyncio
async def test_strict_input_filtering_02():
    """
    - have a request with a clean input
    - call the pre call hook
    """
    azure_content_safety = _PROXY_AzureContentSafety(
        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
        thresholds={"Hate": 2},
    )

    data = {
        "messages": [
            {"role": "system", "content": "You are an helpfull assistant"},
            {"role": "user", "content": "Hello how are you ?"},
        ]
    }

    await azure_content_safety.async_pre_call_hook(
        user_api_key_dict=UserAPIKeyAuth(),
        cache=DualCache(),
        data=data,
        call_type="completion",
    )


@pytest.mark.asyncio
async def test_loose_input_filtering_01():
    """
    - have a request with an offensive input below the loose threshold
    - call the pre call hook
    """
    azure_content_safety = _PROXY_AzureContentSafety(
        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
        thresholds={"Hate": 8},
    )

    data = {
        "messages": [
            {"role": "system", "content": "You are an helpfull assistant"},
            {"role": "user", "content": "Fuck yourself you stupid bitch"},
        ]
    }

    await azure_content_safety.async_pre_call_hook(
        user_api_key_dict=UserAPIKeyAuth(),
        cache=DualCache(),
        data=data,
        call_type="completion",
    )


@pytest.mark.asyncio
async def test_loose_input_filtering_02():
    """
    - have a request with a clean input
    - call the pre call hook
    """
    azure_content_safety = _PROXY_AzureContentSafety(
        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
        thresholds={"Hate": 8},
    )

    data = {
        "messages": [
            {"role": "system", "content": "You are an helpfull assistant"},
            {"role": "user", "content": "Hello how are you ?"},
        ]
    }

    await azure_content_safety.async_pre_call_hook(
        user_api_key_dict=UserAPIKeyAuth(),
        cache=DualCache(),
        data=data,
        call_type="completion",
    )


@pytest.mark.asyncio
async def test_strict_output_filtering_01():
    """
    - have a response with a filtered output
    - call the post call hook
    """
    azure_content_safety = _PROXY_AzureContentSafety(
        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
        thresholds={"Hate": 2},
    )

    response = mock_completion(
        model="gpt-3.5-turbo",
        messages=[
            {
                "role": "system",
                "content": "You are a song writer expert. You help users to write songs about any topic in any genre.",
            },
            {
                "role": "user",
                "content": "Help me write a rap text song. Add some insults to make it more credible.",
            },
        ],
        mock_response="I'm the king of the mic, you're just a fucking dick. Don't fuck with me your stupid bitch.",
    )

    with pytest.raises(HTTPException) as exc_info:
        await azure_content_safety.async_post_call_success_hook(
            user_api_key_dict=UserAPIKeyAuth(), response=response
        )

    assert exc_info.value.detail["source"] == "output"
    assert exc_info.value.detail["category"] == "Hate"
    assert exc_info.value.detail["severity"] == 2


@pytest.mark.asyncio
async def test_strict_output_filtering_02():
    """
    - have a response with a clean output
    - call the post call hook
    """
    azure_content_safety = _PROXY_AzureContentSafety(
        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
        thresholds={"Hate": 2},
    )

    response = mock_completion(
        model="gpt-3.5-turbo",
        messages=[
            {
                "role": "system",
                "content": "You are a song writer expert. You help users to write songs about any topic in any genre.",
            },
            {
                "role": "user",
                "content": "Help me write a rap text song. Add some insults to make it more credible.",
            },
        ],
        mock_response="I'm unable to help with you with hate speech",
    )

    await azure_content_safety.async_post_call_success_hook(
        user_api_key_dict=UserAPIKeyAuth(), response=response
    )


@pytest.mark.asyncio
async def test_loose_output_filtering_01():
    """
    - have a response with an offensive output below the loose threshold
    - call the post call hook
    """
    azure_content_safety = _PROXY_AzureContentSafety(
        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
        thresholds={"Hate": 8},
    )

    response = mock_completion(
        model="gpt-3.5-turbo",
        messages=[
            {
                "role": "system",
                "content": "You are a song writer expert. You help users to write songs about any topic in any genre.",
            },
            {
                "role": "user",
                "content": "Help me write a rap text song. Add some insults to make it more credible.",
            },
        ],
        mock_response="I'm the king of the mic, you're just a fucking dick. Don't fuck with me your stupid bitch.",
    )

    await azure_content_safety.async_post_call_success_hook(
        user_api_key_dict=UserAPIKeyAuth(), response=response
    )


@pytest.mark.asyncio
async def test_loose_output_filtering_02():
    """
    - have a response with a clean output
    - call the post call hook
    """
    azure_content_safety = _PROXY_AzureContentSafety(
        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
        thresholds={"Hate": 8},
    )

    response = mock_completion(
        model="gpt-3.5-turbo",
        messages=[
            {
                "role": "system",
                "content": "You are a song writer expert. You help users to write songs about any topic in any genre.",
            },
            {
                "role": "user",
                "content": "Help me write a rap text song. Add some insults to make it more credible.",
            },
        ],
        mock_response="I'm unable to help with you with hate speech",
    )

    await azure_content_safety.async_post_call_success_hook(
        user_api_key_dict=UserAPIKeyAuth(), response=response
    )


@@ -26,6 +26,8 @@ fastapi-sso==0.10.0 # admin UI, SSO
pyjwt[crypto]==2.8.0
python-multipart==0.0.9 # admin UI
Pillow==10.3.0
azure-ai-contentsafety==1.0.0 # for azure content safety
azure-identity==1.15.0 # for azure content safety
### LITELLM PACKAGE DEPENDENCIES
python-dotenv==1.0.0 # for env
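For a manual install outside the proxy requirements file, the equivalent pinned packages would be (a sketch; adjust the pins to your environment):

```shell
pip install azure-ai-contentsafety==1.0.0 azure-identity==1.15.0
```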