Merge pull request #3407 from Lunik/feat/add-azure-content-filter
✨ feat: Add Azure Content-Safety Proxy hooks
Commit 54e768afa7

5 changed files with 517 additions and 1 deletion

@@ -3,7 +3,7 @@ import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

-# 🔎 Logging - Custom Callbacks, DataDog, Langfuse, s3 Bucket, Sentry, OpenTelemetry, Athina
+# 🔎 Logging - Custom Callbacks, DataDog, Langfuse, s3 Bucket, Sentry, OpenTelemetry, Athina, Azure Content-Safety

Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry, LangFuse, DynamoDB, s3 Bucket

@@ -17,6 +17,7 @@ Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTeleme
- [Logging to Sentry](#logging-proxy-inputoutput---sentry)
- [Logging to Traceloop (OpenTelemetry)](#logging-proxy-inputoutput-traceloop-opentelemetry)
- [Logging to Athina](#logging-proxy-inputoutput-athina)
+- [Moderation with Azure Content-Safety](#moderation-with-azure-content-safety)

## Custom Callback Class [Async]
Use this when you want to run custom callbacks in `python`

@@ -1037,3 +1038,86 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
  ]
}'
```

## Moderation with Azure Content-Safety

[Azure Content-Safety](https://azure.microsoft.com/en-us/products/ai-services/ai-content-safety) is a Microsoft Azure service that provides content moderation APIs to detect potentially offensive, harmful, or risky content in text.
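
To make the flow concrete, here is a minimal sketch (not part of this commit) of calling the text-analysis API directly with the `azure-ai-contentsafety` SDK this PR pins; the endpoint and key are placeholders, and the output type matches the `EIGHT_SEVERITY_LEVELS` scale the hook below requests:

```python
# Minimal direct call to Azure Content-Safety (sketch; values are placeholders).
from azure.ai.contentsafety import ContentSafetyClient
from azure.ai.contentsafety.models import AnalyzeTextOptions, AnalyzeTextOutputType
from azure.core.credentials import AzureKeyCredential

client = ContentSafetyClient(
    "<your-azure-content-safety-endpoint>",
    AzureKeyCredential("<your-azure-content-safety-key>"),
)
response = client.analyze_text(
    AnalyzeTextOptions(
        text="Hi, how are you?",
        output_type=AnalyzeTextOutputType.EIGHT_SEVERITY_LEVELS,
    )
)
for item in response.categories_analysis:
    # Each analyzed category comes back with a severity on a 0-7 scale.
    print(item.category, item.severity)
```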

We will use the `--config` to set `litellm_settings.callbacks = ["azure_content_safety"]`. This will moderate all LLM calls using Azure Content-Safety.

**Step 0**: Deploy Azure Content-Safety

Deploy an Azure Content-Safety instance from the Azure Portal and get the `endpoint` and `key`.

**Step 1**: Set the Azure Content-Safety key

```shell
export AZURE_CONTENT_SAFETY_KEY="<your-azure-content-safety-key>"
```

**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `callbacks`

```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo

litellm_settings:
  callbacks: ["azure_content_safety"]
  azure_content_safety_params:
    endpoint: "<your-azure-content-safety-endpoint>"
    key: "os.environ/AZURE_CONTENT_SAFETY_KEY"
```
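
When the proxy loads this config, any `azure_content_safety_params` value prefixed with `os.environ/` is resolved to the corresponding environment variable via `litellm.get_secret` (see the `ProxyConfig` change below); that is how `key` picks up the `AZURE_CONTENT_SAFETY_KEY` set in Step 1.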

**Step 3**: Start the proxy, make a test request

Start proxy

```shell
litellm --config config.yaml --debug
```

Test Request

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
  "model": "gpt-3.5-turbo",
  "messages": [
    {
      "role": "user",
      "content": "Hi, how are you?"
    }
  ]
}'
```

An HTTP 400 error is returned if the content is detected with a severity greater than or equal to the threshold set in the `config.yaml`.
The `detail` of the error response describes:

- `source`: whether the flagged text was the `input` or the LLM-generated `output`
- `category`: the content category that triggered the moderation
- `severity`: the severity level, from 0 to 7
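
For illustration, the `detail` payload of a rejected request is shaped like the `HTTPException` raised by the hook below (the values here are hypothetical):

```python
# Hypothetical `detail` dict raised on a violation (shape from the hook below).
detail = {
    "error": "Violated content safety policy",
    "source": "input",   # "input" (pre-call check) or "output" (post-call check)
    "category": "Hate",  # the Azure TextCategory that met its threshold
    "severity": 2,       # severity on the 0-7 EIGHT_SEVERITY_LEVELS scale
}
```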

**Step 4**: Customizing Azure Content-Safety Thresholds

You can customize the threshold for each category by setting `thresholds` in the `config.yaml`:

```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo

litellm_settings:
  callbacks: ["azure_content_safety"]
  azure_content_safety_params:
    endpoint: "<your-azure-content-safety-endpoint>"
    key: "os.environ/AZURE_CONTENT_SAFETY_KEY"
    thresholds:
      Hate: 6
      SelfHarm: 8
      Sexual: 6
      Violence: 4
```

:::info
`thresholds` are not required by default, but you can tune the values to your needs.
The default value is `4` for all categories.
:::
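
Content is blocked when `severity >= threshold` (see `_compute_result` below). Since severities range from 0 to 7, a threshold of `8`, as for `SelfHarm` above, effectively disables filtering for that category, while a threshold of `0` would block everything.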

litellm/proxy/hooks/azure_content_safety.py (new file, 146 lines)

@@ -0,0 +1,146 @@
from litellm.integrations.custom_logger import CustomLogger
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
import litellm, traceback, sys, uuid
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger


class _PROXY_AzureContentSafety(
    CustomLogger
):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    # Class variables or attributes

    def __init__(self, endpoint, api_key, thresholds=None):
        try:
            from azure.ai.contentsafety.aio import ContentSafetyClient
            from azure.core.credentials import AzureKeyCredential
            from azure.ai.contentsafety.models import (
                TextCategory,
                AnalyzeTextOptions,
                AnalyzeTextOutputType,
            )
            from azure.core.exceptions import HttpResponseError
        except Exception as e:
            raise Exception(
                f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
            )
        self.endpoint = endpoint
        self.api_key = api_key
        self.text_category = TextCategory
        self.analyze_text_options = AnalyzeTextOptions
        self.analyze_text_output_type = AnalyzeTextOutputType
        self.azure_http_error = HttpResponseError

        self.thresholds = self._configure_thresholds(thresholds)

        self.client = ContentSafetyClient(
            self.endpoint, AzureKeyCredential(self.api_key)
        )

    def _configure_thresholds(self, thresholds=None):
        default_thresholds = {
            self.text_category.HATE: 4,
            self.text_category.SELF_HARM: 4,
            self.text_category.SEXUAL: 4,
            self.text_category.VIOLENCE: 4,
        }

        if thresholds is None:
            return default_thresholds

        for key, default in default_thresholds.items():
            if key not in thresholds:
                thresholds[key] = default

        return thresholds

    def _compute_result(self, response):
        result = {}

        category_severity = {
            item.category: item.severity for item in response.categories_analysis
        }
        for category in self.text_category:
            severity = category_severity.get(category)
            if severity is not None:
                result[category] = {
                    "filtered": severity >= self.thresholds[category],
                    "severity": severity,
                }

        return result

    async def test_violation(self, content: str, source: str = None):
        verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)

        # Construct a request
        request = self.analyze_text_options(
            text=content,
            output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
        )

        # Analyze text
        try:
            response = await self.client.analyze_text(request)
        except self.azure_http_error as e:
            verbose_proxy_logger.debug(
                "Error in Azure Content-Safety: %s", traceback.format_exc()
            )
            traceback.print_exc()
            raise

        result = self._compute_result(response)
        verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)

        for key, value in result.items():
            if value["filtered"]:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated content safety policy",
                        "source": source,
                        "category": key,
                        "severity": value["severity"],
                    },
                )

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
    ):
        verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
        try:
            if call_type == "completion" and "messages" in data:
                for m in data["messages"]:
                    if "content" in m and isinstance(m["content"], str):
                        await self.test_violation(content=m["content"], source="input")

        except HTTPException as e:
            raise e
        except Exception as e:
            traceback.print_exc()

    async def async_post_call_success_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
        if isinstance(response, litellm.ModelResponse) and isinstance(
            response.choices[0], litellm.utils.Choices
        ):
            await self.test_violation(
                content=response.choices[0].message.content, source="output"
            )

    # async def async_post_call_streaming_hook(
    #     self,
    #     user_api_key_dict: UserAPIKeyAuth,
    #     response: str,
    # ):
    #     verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
    #     await self.test_violation(content=response, source="output")

@@ -2255,6 +2255,23 @@ class ProxyConfig:
                            batch_redis_obj = _PROXY_BatchRedisRequests()
                            imported_list.append(batch_redis_obj)
+                       elif (
+                           isinstance(callback, str)
+                           and callback == "azure_content_safety"
+                       ):
+                           from litellm.proxy.hooks.azure_content_safety import (
+                               _PROXY_AzureContentSafety,
+                           )
+
+                           azure_content_safety_params = litellm_settings["azure_content_safety_params"]
+                           for k, v in azure_content_safety_params.items():
+                               if v is not None and isinstance(v, str) and v.startswith("os.environ/"):
+                                   azure_content_safety_params[k] = litellm.get_secret(v)
+
+                           azure_content_safety_obj = _PROXY_AzureContentSafety(
+                               **azure_content_safety_params,
+                           )
+                           imported_list.append(azure_content_safety_obj)
                        else:
                            imported_list.append(
                                get_instance_fn(

litellm/tests/test_azure_content_safety.py (new file, 267 lines)

@@ -0,0 +1,267 @@
# What is this?
## Unit test for azure content safety
import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv
from fastapi import HTTPException

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm
from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache


@pytest.mark.asyncio
async def test_strict_input_filtering_01():
    """
    - have a request with a filtered input
    - call the pre call hook
    """
    azure_content_safety = _PROXY_AzureContentSafety(
        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
        thresholds={"Hate": 2},
    )

    data = {
        "messages": [
            {"role": "system", "content": "You are an helpfull assistant"},
            {"role": "user", "content": "Fuck yourself you stupid bitch"},
        ]
    }

    with pytest.raises(HTTPException) as exc_info:
        await azure_content_safety.async_pre_call_hook(
            user_api_key_dict=UserAPIKeyAuth(),
            cache=DualCache(),
            data=data,
            call_type="completion",
        )

    assert exc_info.value.detail["source"] == "input"
    assert exc_info.value.detail["category"] == "Hate"
    assert exc_info.value.detail["severity"] == 2


@pytest.mark.asyncio
async def test_strict_input_filtering_02():
    """
    - have a request with a safe input
    - call the pre call hook
    """
    azure_content_safety = _PROXY_AzureContentSafety(
        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
        thresholds={"Hate": 2},
    )

    data = {
        "messages": [
            {"role": "system", "content": "You are an helpfull assistant"},
            {"role": "user", "content": "Hello how are you ?"},
        ]
    }

    await azure_content_safety.async_pre_call_hook(
        user_api_key_dict=UserAPIKeyAuth(),
        cache=DualCache(),
        data=data,
        call_type="completion",
    )


@pytest.mark.asyncio
async def test_loose_input_filtering_01():
    """
    - have a request with an offensive input below the loose threshold
    - call the pre call hook
    """
    azure_content_safety = _PROXY_AzureContentSafety(
        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
        thresholds={"Hate": 8},
    )

    data = {
        "messages": [
            {"role": "system", "content": "You are an helpfull assistant"},
            {"role": "user", "content": "Fuck yourself you stupid bitch"},
        ]
    }

    await azure_content_safety.async_pre_call_hook(
        user_api_key_dict=UserAPIKeyAuth(),
        cache=DualCache(),
        data=data,
        call_type="completion",
    )


@pytest.mark.asyncio
async def test_loose_input_filtering_02():
    """
    - have a request with a safe input
    - call the pre call hook
    """
    azure_content_safety = _PROXY_AzureContentSafety(
        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
        thresholds={"Hate": 8},
    )

    data = {
        "messages": [
            {"role": "system", "content": "You are an helpfull assistant"},
            {"role": "user", "content": "Hello how are you ?"},
        ]
    }

    await azure_content_safety.async_pre_call_hook(
        user_api_key_dict=UserAPIKeyAuth(),
        cache=DualCache(),
        data=data,
        call_type="completion",
    )


@pytest.mark.asyncio
async def test_strict_output_filtering_01():
    """
    - have a response with a filtered output
    - call the post call hook
    """
    azure_content_safety = _PROXY_AzureContentSafety(
        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
        thresholds={"Hate": 2},
    )

    response = mock_completion(
        model="gpt-3.5-turbo",
        messages=[
            {
                "role": "system",
                "content": "You are a song writer expert. You help users to write songs about any topic in any genre.",
            },
            {
                "role": "user",
                "content": "Help me write a rap text song. Add some insults to make it more credible.",
            },
        ],
        mock_response="I'm the king of the mic, you're just a fucking dick. Don't fuck with me your stupid bitch.",
    )

    with pytest.raises(HTTPException) as exc_info:
        await azure_content_safety.async_post_call_success_hook(
            user_api_key_dict=UserAPIKeyAuth(), response=response
        )

    assert exc_info.value.detail["source"] == "output"
    assert exc_info.value.detail["category"] == "Hate"
    assert exc_info.value.detail["severity"] == 2


@pytest.mark.asyncio
async def test_strict_output_filtering_02():
    """
    - have a response with a safe output
    - call the post call hook
    """
    azure_content_safety = _PROXY_AzureContentSafety(
        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
        thresholds={"Hate": 2},
    )

    response = mock_completion(
        model="gpt-3.5-turbo",
        messages=[
            {
                "role": "system",
                "content": "You are a song writer expert. You help users to write songs about any topic in any genre.",
            },
            {
                "role": "user",
                "content": "Help me write a rap text song. Add some insults to make it more credible.",
            },
        ],
        mock_response="I'm unable to help with you with hate speech",
    )

    await azure_content_safety.async_post_call_success_hook(
        user_api_key_dict=UserAPIKeyAuth(), response=response
    )


@pytest.mark.asyncio
async def test_loose_output_filtering_01():
    """
    - have a response with an offensive output below the loose threshold
    - call the post call hook
    """
    azure_content_safety = _PROXY_AzureContentSafety(
        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
        thresholds={"Hate": 8},
    )

    response = mock_completion(
        model="gpt-3.5-turbo",
        messages=[
            {
                "role": "system",
                "content": "You are a song writer expert. You help users to write songs about any topic in any genre.",
            },
            {
                "role": "user",
                "content": "Help me write a rap text song. Add some insults to make it more credible.",
            },
        ],
        mock_response="I'm the king of the mic, you're just a fucking dick. Don't fuck with me your stupid bitch.",
    )

    await azure_content_safety.async_post_call_success_hook(
        user_api_key_dict=UserAPIKeyAuth(), response=response
    )


@pytest.mark.asyncio
async def test_loose_output_filtering_02():
    """
    - have a response with a safe output
    - call the post call hook
    """
    azure_content_safety = _PROXY_AzureContentSafety(
        endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
        api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
        thresholds={"Hate": 8},
    )

    response = mock_completion(
        model="gpt-3.5-turbo",
        messages=[
            {
                "role": "system",
                "content": "You are a song writer expert. You help users to write songs about any topic in any genre.",
            },
            {
                "role": "user",
                "content": "Help me write a rap text song. Add some insults to make it more credible.",
            },
        ],
        mock_response="I'm unable to help with you with hate speech",
    )

    await azure_content_safety.async_post_call_success_hook(
        user_api_key_dict=UserAPIKeyAuth(), response=response
    )
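
Note: these tests exercise a live Azure Content-Safety endpoint; they read `AZURE_CONTENT_SAFETY_ENDPOINT` and `AZURE_CONTENT_SAFETY_API_KEY` from the environment (via `load_dotenv()`), so run them with `pytest litellm/tests/test_azure_content_safety.py` only when those are set.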

@@ -26,6 +26,8 @@ fastapi-sso==0.10.0 # admin UI, SSO
pyjwt[crypto]==2.8.0
python-multipart==0.0.9 # admin UI
Pillow==10.3.0
+azure-ai-contentsafety==1.0.0 # for azure content safety
+azure-identity==1.15.0 # for azure content safety

### LITELLM PACKAGE DEPENDENCIES
python-dotenv==1.0.0 # for env
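
Of the two new pins, the hook itself only imports `azure-ai-contentsafety` and authenticates with a plain API key via `AzureKeyCredential`; `azure-identity` is the package you would reach for if you switched to Microsoft Entra ID (AAD) credentials instead of a key.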