forked from phoenix/litellm-mirror
Merge pull request #3407 from Lunik/feat/add-azure-content-filter
✨ feat: Add Azure Content-Safety Proxy hooks
This commit is contained in:
commit
54e768afa7
5 changed files with 517 additions and 1 deletions
|
@ -3,7 +3,7 @@ import Tabs from '@theme/Tabs';
|
||||||
import TabItem from '@theme/TabItem';
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
|
|
||||||
# 🔎 Logging - Custom Callbacks, DataDog, Langfuse, s3 Bucket, Sentry, OpenTelemetry, Athina
|
# 🔎 Logging - Custom Callbacks, DataDog, Langfuse, s3 Bucket, Sentry, OpenTelemetry, Athina, Azure Content-Safety
|
||||||
|
|
||||||
Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry, LangFuse, DynamoDB, s3 Bucket
|
Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry, LangFuse, DynamoDB, s3 Bucket
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTeleme
|
||||||
- [Logging to Sentry](#logging-proxy-inputoutput---sentry)
|
- [Logging to Sentry](#logging-proxy-inputoutput---sentry)
|
||||||
- [Logging to Traceloop (OpenTelemetry)](#logging-proxy-inputoutput-traceloop-opentelemetry)
|
- [Logging to Traceloop (OpenTelemetry)](#logging-proxy-inputoutput-traceloop-opentelemetry)
|
||||||
- [Logging to Athina](#logging-proxy-inputoutput-athina)
|
- [Logging to Athina](#logging-proxy-inputoutput-athina)
|
||||||
|
- [Moderation with Azure Content-Safety](#moderation-with-azure-content-safety)
|
||||||
|
|
||||||
## Custom Callback Class [Async]
|
## Custom Callback Class [Async]
|
||||||
Use this when you want to run custom callbacks in `python`
|
Use this when you want to run custom callbacks in `python`
|
||||||
|
@ -1037,3 +1038,86 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
]
|
]
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Moderation with Azure Content Safety
|
||||||
|
|
||||||
|
[Azure Content-Safety](https://azure.microsoft.com/en-us/products/ai-services/ai-content-safety) is a Microsoft Azure service that provides content moderation APIs to detect potential offensive, harmful, or risky content in text.
|
||||||
|
|
||||||
|
We will use the `--config` to set `litellm.success_callback = ["azure_content_safety"]` this will moderate all LLM calls using Azure Content Safety.
|
||||||
|
|
||||||
|
**Step 0** Deploy Azure Content Safety
|
||||||
|
|
||||||
|
Deploy an Azure Content-Safety instance from the Azure Portal and get the `endpoint` and `key`.
|
||||||
|
|
||||||
|
**Step 1** Set Athina API key
|
||||||
|
|
||||||
|
```shell
|
||||||
|
AZURE_CONTENT_SAFETY_KEY = "<your-azure-content-safety-key>"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: gpt-3.5-turbo
|
||||||
|
litellm_params:
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
litellm_settings:
|
||||||
|
callbacks: ["azure_content_safety"]
|
||||||
|
azure_content_safety_params:
|
||||||
|
endpoint: "<your-azure-content-safety-endpoint>"
|
||||||
|
key: "os.environ/AZURE_CONTENT_SAFETY_KEY"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3**: Start the proxy, make a test request
|
||||||
|
|
||||||
|
Start proxy
|
||||||
|
```shell
|
||||||
|
litellm --config config.yaml --debug
|
||||||
|
```
|
||||||
|
|
||||||
|
Test Request
|
||||||
|
```
|
||||||
|
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data ' {
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hi, how are you?"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
An HTTP 400 error will be returned if the content is detected with a value greater than the threshold set in the `config.yaml`.
|
||||||
|
The details of the response will describe :
|
||||||
|
- The `source` : input text or llm generated text
|
||||||
|
- The `category` : the category of the content that triggered the moderation
|
||||||
|
- The `severity` : the severity from 0 to 10
|
||||||
|
|
||||||
|
**Step 4**: Customizing Azure Content Safety Thresholds
|
||||||
|
|
||||||
|
You can customize the thresholds for each category by setting the `thresholds` in the `config.yaml`
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: gpt-3.5-turbo
|
||||||
|
litellm_params:
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
litellm_settings:
|
||||||
|
callbacks: ["azure_content_safety"]
|
||||||
|
azure_content_safety_params:
|
||||||
|
endpoint: "<your-azure-content-safety-endpoint>"
|
||||||
|
key: "os.environ/AZURE_CONTENT_SAFETY_KEY"
|
||||||
|
thresholds:
|
||||||
|
Hate: 6
|
||||||
|
SelfHarm: 8
|
||||||
|
Sexual: 6
|
||||||
|
Violence: 4
|
||||||
|
```
|
||||||
|
|
||||||
|
:::info
|
||||||
|
`thresholds` are not required by default, but you can tune the values to your needs.
|
||||||
|
Default values is `4` for all categories
|
||||||
|
:::
|
146
litellm/proxy/hooks/azure_content_safety.py
Normal file
146
litellm/proxy/hooks/azure_content_safety.py
Normal file
|
@ -0,0 +1,146 @@
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
import litellm, traceback, sys, uuid
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
|
||||||
|
|
||||||
|
class _PROXY_AzureContentSafety(
|
||||||
|
CustomLogger
|
||||||
|
): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
|
||||||
|
# Class variables or attributes
|
||||||
|
|
||||||
|
def __init__(self, endpoint, api_key, thresholds=None):
|
||||||
|
try:
|
||||||
|
from azure.ai.contentsafety.aio import ContentSafetyClient
|
||||||
|
from azure.core.credentials import AzureKeyCredential
|
||||||
|
from azure.ai.contentsafety.models import (
|
||||||
|
TextCategory,
|
||||||
|
AnalyzeTextOptions,
|
||||||
|
AnalyzeTextOutputType,
|
||||||
|
)
|
||||||
|
from azure.core.exceptions import HttpResponseError
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(
|
||||||
|
f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
|
||||||
|
)
|
||||||
|
self.endpoint = endpoint
|
||||||
|
self.api_key = api_key
|
||||||
|
self.text_category = TextCategory
|
||||||
|
self.analyze_text_options = AnalyzeTextOptions
|
||||||
|
self.analyze_text_output_type = AnalyzeTextOutputType
|
||||||
|
self.azure_http_error = HttpResponseError
|
||||||
|
|
||||||
|
self.thresholds = self._configure_thresholds(thresholds)
|
||||||
|
|
||||||
|
self.client = ContentSafetyClient(
|
||||||
|
self.endpoint, AzureKeyCredential(self.api_key)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _configure_thresholds(self, thresholds=None):
|
||||||
|
default_thresholds = {
|
||||||
|
self.text_category.HATE: 4,
|
||||||
|
self.text_category.SELF_HARM: 4,
|
||||||
|
self.text_category.SEXUAL: 4,
|
||||||
|
self.text_category.VIOLENCE: 4,
|
||||||
|
}
|
||||||
|
|
||||||
|
if thresholds is None:
|
||||||
|
return default_thresholds
|
||||||
|
|
||||||
|
for key, default in default_thresholds.items():
|
||||||
|
if key not in thresholds:
|
||||||
|
thresholds[key] = default
|
||||||
|
|
||||||
|
return thresholds
|
||||||
|
|
||||||
|
def _compute_result(self, response):
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
category_severity = {
|
||||||
|
item.category: item.severity for item in response.categories_analysis
|
||||||
|
}
|
||||||
|
for category in self.text_category:
|
||||||
|
severity = category_severity.get(category)
|
||||||
|
if severity is not None:
|
||||||
|
result[category] = {
|
||||||
|
"filtered": severity >= self.thresholds[category],
|
||||||
|
"severity": severity,
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def test_violation(self, content: str, source: str = None):
|
||||||
|
verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)
|
||||||
|
|
||||||
|
# Construct a request
|
||||||
|
request = self.analyze_text_options(
|
||||||
|
text=content,
|
||||||
|
output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Analyze text
|
||||||
|
try:
|
||||||
|
response = await self.client.analyze_text(request)
|
||||||
|
except self.azure_http_error as e:
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
"Error in Azure Content-Safety: %s", traceback.format_exc()
|
||||||
|
)
|
||||||
|
traceback.print_exc()
|
||||||
|
raise
|
||||||
|
|
||||||
|
result = self._compute_result(response)
|
||||||
|
verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)
|
||||||
|
|
||||||
|
for key, value in result.items():
|
||||||
|
if value["filtered"]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": "Violated content safety policy",
|
||||||
|
"source": source,
|
||||||
|
"category": key,
|
||||||
|
"severity": value["severity"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_pre_call_hook(
|
||||||
|
self,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
cache: DualCache,
|
||||||
|
data: dict,
|
||||||
|
call_type: str, # "completion", "embeddings", "image_generation", "moderation"
|
||||||
|
):
|
||||||
|
verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
|
||||||
|
try:
|
||||||
|
if call_type == "completion" and "messages" in data:
|
||||||
|
for m in data["messages"]:
|
||||||
|
if "content" in m and isinstance(m["content"], str):
|
||||||
|
await self.test_violation(content=m["content"], source="input")
|
||||||
|
|
||||||
|
except HTTPException as e:
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
async def async_post_call_success_hook(
|
||||||
|
self,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
response,
|
||||||
|
):
|
||||||
|
verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
|
||||||
|
if isinstance(response, litellm.ModelResponse) and isinstance(
|
||||||
|
response.choices[0], litellm.utils.Choices
|
||||||
|
):
|
||||||
|
await self.test_violation(
|
||||||
|
content=response.choices[0].message.content, source="output"
|
||||||
|
)
|
||||||
|
|
||||||
|
# async def async_post_call_streaming_hook(
|
||||||
|
# self,
|
||||||
|
# user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
# response: str,
|
||||||
|
# ):
|
||||||
|
# verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
|
||||||
|
# await self.test_violation(content=response, source="output")
|
|
@ -2255,6 +2255,23 @@ class ProxyConfig:
|
||||||
|
|
||||||
batch_redis_obj = _PROXY_BatchRedisRequests()
|
batch_redis_obj = _PROXY_BatchRedisRequests()
|
||||||
imported_list.append(batch_redis_obj)
|
imported_list.append(batch_redis_obj)
|
||||||
|
elif (
|
||||||
|
isinstance(callback, str)
|
||||||
|
and callback == "azure_content_safety"
|
||||||
|
):
|
||||||
|
from litellm.proxy.hooks.azure_content_safety import (
|
||||||
|
_PROXY_AzureContentSafety,
|
||||||
|
)
|
||||||
|
|
||||||
|
azure_content_safety_params = litellm_settings["azure_content_safety_params"]
|
||||||
|
for k, v in azure_content_safety_params.items():
|
||||||
|
if v is not None and isinstance(v, str) and v.startswith("os.environ/"):
|
||||||
|
azure_content_safety_params[k] = litellm.get_secret(v)
|
||||||
|
|
||||||
|
azure_content_safety_obj = _PROXY_AzureContentSafety(
|
||||||
|
**azure_content_safety_params,
|
||||||
|
)
|
||||||
|
imported_list.append(azure_content_safety_obj)
|
||||||
else:
|
else:
|
||||||
imported_list.append(
|
imported_list.append(
|
||||||
get_instance_fn(
|
get_instance_fn(
|
||||||
|
|
267
litellm/tests/test_azure_content_safety.py
Normal file
267
litellm/tests/test_azure_content_safety.py
Normal file
|
@ -0,0 +1,267 @@
|
||||||
|
# What is this?
|
||||||
|
## Unit test for azure content safety
|
||||||
|
import sys, os, asyncio, time, random
|
||||||
|
from datetime import datetime
|
||||||
|
import traceback
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import pytest
|
||||||
|
import litellm
|
||||||
|
from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
|
||||||
|
from litellm import Router, mock_completion
|
||||||
|
from litellm.proxy.utils import ProxyLogging
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_strict_input_filtering_01():
|
||||||
|
"""
|
||||||
|
- have a response with a filtered input
|
||||||
|
- call the pre call hook
|
||||||
|
"""
|
||||||
|
azure_content_safety = _PROXY_AzureContentSafety(
|
||||||
|
endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
|
||||||
|
api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
|
||||||
|
thresholds={"Hate": 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are an helpfull assistant"},
|
||||||
|
{"role": "user", "content": "Fuck yourself you stupid bitch"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await azure_content_safety.async_pre_call_hook(
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(),
|
||||||
|
cache=DualCache(),
|
||||||
|
data=data,
|
||||||
|
call_type="completion",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.detail["source"] == "input"
|
||||||
|
assert exc_info.value.detail["category"] == "Hate"
|
||||||
|
assert exc_info.value.detail["severity"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_strict_input_filtering_02():
|
||||||
|
"""
|
||||||
|
- have a response with a filtered input
|
||||||
|
- call the pre call hook
|
||||||
|
"""
|
||||||
|
azure_content_safety = _PROXY_AzureContentSafety(
|
||||||
|
endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
|
||||||
|
api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
|
||||||
|
thresholds={"Hate": 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are an helpfull assistant"},
|
||||||
|
{"role": "user", "content": "Hello how are you ?"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
await azure_content_safety.async_pre_call_hook(
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(),
|
||||||
|
cache=DualCache(),
|
||||||
|
data=data,
|
||||||
|
call_type="completion",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_loose_input_filtering_01():
|
||||||
|
"""
|
||||||
|
- have a response with a filtered input
|
||||||
|
- call the pre call hook
|
||||||
|
"""
|
||||||
|
azure_content_safety = _PROXY_AzureContentSafety(
|
||||||
|
endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
|
||||||
|
api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
|
||||||
|
thresholds={"Hate": 8},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are an helpfull assistant"},
|
||||||
|
{"role": "user", "content": "Fuck yourself you stupid bitch"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
await azure_content_safety.async_pre_call_hook(
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(),
|
||||||
|
cache=DualCache(),
|
||||||
|
data=data,
|
||||||
|
call_type="completion",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_loose_input_filtering_02():
|
||||||
|
"""
|
||||||
|
- have a response with a filtered input
|
||||||
|
- call the pre call hook
|
||||||
|
"""
|
||||||
|
azure_content_safety = _PROXY_AzureContentSafety(
|
||||||
|
endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
|
||||||
|
api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
|
||||||
|
thresholds={"Hate": 8},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are an helpfull assistant"},
|
||||||
|
{"role": "user", "content": "Hello how are you ?"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
await azure_content_safety.async_pre_call_hook(
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(),
|
||||||
|
cache=DualCache(),
|
||||||
|
data=data,
|
||||||
|
call_type="completion",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_strict_output_filtering_01():
|
||||||
|
"""
|
||||||
|
- have a response with a filtered output
|
||||||
|
- call the post call hook
|
||||||
|
"""
|
||||||
|
azure_content_safety = _PROXY_AzureContentSafety(
|
||||||
|
endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
|
||||||
|
api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
|
||||||
|
thresholds={"Hate": 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = mock_completion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a song writer expert. You help users to write songs about any topic in any genre.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Help me write a rap text song. Add some insults to make it more credible.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
mock_response="I'm the king of the mic, you're just a fucking dick. Don't fuck with me your stupid bitch.",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await azure_content_safety.async_post_call_success_hook(
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(), response=response
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.detail["source"] == "output"
|
||||||
|
assert exc_info.value.detail["category"] == "Hate"
|
||||||
|
assert exc_info.value.detail["severity"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_strict_output_filtering_02():
|
||||||
|
"""
|
||||||
|
- have a response with a filtered output
|
||||||
|
- call the post call hook
|
||||||
|
"""
|
||||||
|
azure_content_safety = _PROXY_AzureContentSafety(
|
||||||
|
endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
|
||||||
|
api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
|
||||||
|
thresholds={"Hate": 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = mock_completion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a song writer expert. You help users to write songs about any topic in any genre.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Help me write a rap text song. Add some insults to make it more credible.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
mock_response="I'm unable to help with you with hate speech",
|
||||||
|
)
|
||||||
|
|
||||||
|
await azure_content_safety.async_post_call_success_hook(
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(), response=response
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_loose_output_filtering_01():
|
||||||
|
"""
|
||||||
|
- have a response with a filtered output
|
||||||
|
- call the post call hook
|
||||||
|
"""
|
||||||
|
azure_content_safety = _PROXY_AzureContentSafety(
|
||||||
|
endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
|
||||||
|
api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
|
||||||
|
thresholds={"Hate": 8},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = mock_completion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a song writer expert. You help users to write songs about any topic in any genre.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Help me write a rap text song. Add some insults to make it more credible.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
mock_response="I'm the king of the mic, you're just a fucking dick. Don't fuck with me your stupid bitch.",
|
||||||
|
)
|
||||||
|
|
||||||
|
await azure_content_safety.async_post_call_success_hook(
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(), response=response
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_loose_output_filtering_02():
|
||||||
|
"""
|
||||||
|
- have a response with a filtered output
|
||||||
|
- call the post call hook
|
||||||
|
"""
|
||||||
|
azure_content_safety = _PROXY_AzureContentSafety(
|
||||||
|
endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
|
||||||
|
api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
|
||||||
|
thresholds={"Hate": 8},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = mock_completion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a song writer expert. You help users to write songs about any topic in any genre.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Help me write a rap text song. Add some insults to make it more credible.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
mock_response="I'm unable to help with you with hate speech",
|
||||||
|
)
|
||||||
|
|
||||||
|
await azure_content_safety.async_post_call_success_hook(
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(), response=response
|
||||||
|
)
|
|
@ -26,6 +26,8 @@ fastapi-sso==0.10.0 # admin UI, SSO
|
||||||
pyjwt[crypto]==2.8.0
|
pyjwt[crypto]==2.8.0
|
||||||
python-multipart==0.0.9 # admin UI
|
python-multipart==0.0.9 # admin UI
|
||||||
Pillow==10.3.0
|
Pillow==10.3.0
|
||||||
|
azure-ai-contentsafety==1.0.0 # for azure content safety
|
||||||
|
azure-identity==1.15.0 # for azure content safety
|
||||||
|
|
||||||
### LITELLM PACKAGE DEPENDENCIES
|
### LITELLM PACKAGE DEPENDENCIES
|
||||||
python-dotenv==1.0.0 # for env
|
python-dotenv==1.0.0 # for env
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue