forked from phoenix/litellm-mirror
✨ feat: Add Azure Content-Safety Proxy hooks
Signed-off-by: Lunik <lunik@tiwabbit.fr>
This commit is contained in:
parent
7ffe410097
commit
6cec252b07
2 changed files with 174 additions and 0 deletions
157
litellm/proxy/hooks/azure_content_safety.py
Normal file
157
litellm/proxy/hooks/azure_content_safety.py
Normal file
|
@ -0,0 +1,157 @@
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
import litellm, traceback, sys, uuid
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
|
||||||
|
|
||||||
|
class _PROXY_AzureContentSafety(
|
||||||
|
CustomLogger
|
||||||
|
): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
|
||||||
|
# Class variables or attributes
|
||||||
|
|
||||||
|
def __init__(self, endpoint, api_key, thresholds=None):
|
||||||
|
try:
|
||||||
|
from azure.ai.contentsafety.aio import ContentSafetyClient
|
||||||
|
from azure.core.credentials import AzureKeyCredential
|
||||||
|
from azure.ai.contentsafety.models import TextCategory
|
||||||
|
from azure.ai.contentsafety.models import AnalyzeTextOptions
|
||||||
|
from azure.core.exceptions import HttpResponseError
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(
|
||||||
|
f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
|
||||||
|
)
|
||||||
|
self.endpoint = endpoint
|
||||||
|
self.api_key = api_key
|
||||||
|
self.text_category = TextCategory
|
||||||
|
self.analyze_text_options = AnalyzeTextOptions
|
||||||
|
self.azure_http_error = HttpResponseError
|
||||||
|
|
||||||
|
self.thresholds = self._configure_thresholds(thresholds)
|
||||||
|
|
||||||
|
self.client = ContentSafetyClient(
|
||||||
|
self.endpoint, AzureKeyCredential(self.api_key)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _configure_thresholds(self, thresholds=None):
|
||||||
|
default_thresholds = {
|
||||||
|
self.text_category.HATE: 6,
|
||||||
|
self.text_category.SELF_HARM: 6,
|
||||||
|
self.text_category.SEXUAL: 6,
|
||||||
|
self.text_category.VIOLENCE: 6,
|
||||||
|
}
|
||||||
|
|
||||||
|
if thresholds is None:
|
||||||
|
return default_thresholds
|
||||||
|
|
||||||
|
for key, default in default_thresholds.items():
|
||||||
|
if key not in thresholds:
|
||||||
|
thresholds[key] = default
|
||||||
|
|
||||||
|
return thresholds
|
||||||
|
|
||||||
|
def print_verbose(self, print_statement):
|
||||||
|
try:
|
||||||
|
verbose_proxy_logger.debug(print_statement)
|
||||||
|
if litellm.set_verbose:
|
||||||
|
print(print_statement) # noqa
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _severity(self, severity):
|
||||||
|
if severity >= 6:
|
||||||
|
return "high"
|
||||||
|
elif severity >= 4:
|
||||||
|
return "medium"
|
||||||
|
elif severity >= 2:
|
||||||
|
return "low"
|
||||||
|
else:
|
||||||
|
return "safe"
|
||||||
|
|
||||||
|
def _compute_result(self, response):
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
category_severity = {
|
||||||
|
item.category: item.severity for item in response.categories_analysis
|
||||||
|
}
|
||||||
|
for category in self.text_category:
|
||||||
|
severity = category_severity.get(category)
|
||||||
|
if severity is not None:
|
||||||
|
result[category] = {
|
||||||
|
"filtered": severity >= self.thresholds[category],
|
||||||
|
"severity": self._severity(severity),
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def test_violation(self, content: str, source: str = None):
|
||||||
|
self.print_verbose(f"Testing Azure Content-Safety for: {content}")
|
||||||
|
|
||||||
|
# Construct a request
|
||||||
|
request = self.analyze_text_options(text=content)
|
||||||
|
|
||||||
|
# Analyze text
|
||||||
|
try:
|
||||||
|
response = await self.client.analyze_text(request)
|
||||||
|
except self.azure_http_error as e:
|
||||||
|
self.print_verbose(
|
||||||
|
f"Error in Azure Content-Safety: {traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
traceback.print_exc()
|
||||||
|
raise
|
||||||
|
|
||||||
|
result = self._compute_result(response)
|
||||||
|
self.print_verbose(f"Azure Content-Safety Result: {result}")
|
||||||
|
|
||||||
|
for key, value in result.items():
|
||||||
|
if value["filtered"]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": "Violated content safety policy",
|
||||||
|
"source": source,
|
||||||
|
"category": key,
|
||||||
|
"severity": value["severity"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_pre_call_hook(
|
||||||
|
self,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
cache: DualCache,
|
||||||
|
data: dict,
|
||||||
|
call_type: str, # "completion", "embeddings", "image_generation", "moderation"
|
||||||
|
):
|
||||||
|
self.print_verbose(f"Inside Azure Content-Safety Pre-Call Hook")
|
||||||
|
try:
|
||||||
|
if call_type == "completion" and "messages" in data:
|
||||||
|
for m in data["messages"]:
|
||||||
|
if "content" in m and isinstance(m["content"], str):
|
||||||
|
await self.test_violation(content=m["content"], source="input")
|
||||||
|
|
||||||
|
except HTTPException as e:
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
async def async_post_call_success_hook(
|
||||||
|
self,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
response,
|
||||||
|
):
|
||||||
|
self.print_verbose(f"Inside Azure Content-Safety Post-Call Hook")
|
||||||
|
if isinstance(response, litellm.ModelResponse) and isinstance(
|
||||||
|
response.choices[0], litellm.utils.Choices
|
||||||
|
):
|
||||||
|
await self.test_violation(
|
||||||
|
content=response.choices[0].message.content, source="output"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_post_call_streaming_hook(
|
||||||
|
self,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
response: str,
|
||||||
|
):
|
||||||
|
self.print_verbose(f"Inside Azure Content-Safety Call-Stream Hook")
|
||||||
|
await self.test_violation(content=response, source="output")
|
|
@ -2235,6 +2235,23 @@ class ProxyConfig:
|
||||||
|
|
||||||
batch_redis_obj = _PROXY_BatchRedisRequests()
|
batch_redis_obj = _PROXY_BatchRedisRequests()
|
||||||
imported_list.append(batch_redis_obj)
|
imported_list.append(batch_redis_obj)
|
||||||
|
elif (
|
||||||
|
isinstance(callback, str)
|
||||||
|
and callback == "azure_content_safety"
|
||||||
|
):
|
||||||
|
from litellm.proxy.hooks.azure_content_safety import (
|
||||||
|
_PROXY_AzureContentSafety,
|
||||||
|
)
|
||||||
|
|
||||||
|
azure_content_safety_params = litellm_settings["azure_content_safety_params"]
|
||||||
|
for k, v in azure_content_safety_params.items():
|
||||||
|
if v is not None and isinstance(v, str) and v.startswith("os.environ/"):
|
||||||
|
azure_content_safety_params[k] = litellm.get_secret(v)
|
||||||
|
|
||||||
|
azure_content_safety_obj = _PROXY_AzureContentSafety(
|
||||||
|
**azure_content_safety_params,
|
||||||
|
)
|
||||||
|
imported_list.append(azure_content_safety_obj)
|
||||||
else:
|
else:
|
||||||
imported_list.append(
|
imported_list.append(
|
||||||
get_instance_fn(
|
get_instance_fn(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue