mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
* remove unused imports * fix AmazonConverseConfig * fix test * fix import * ruff check fixes * test fixes * fix testing * fix imports
156 lines
5.5 KiB
Python
156 lines
5.5 KiB
Python
import traceback
|
|
from typing import Optional
|
|
|
|
from fastapi import HTTPException
|
|
|
|
import litellm
|
|
from litellm._logging import verbose_proxy_logger
|
|
from litellm.caching.caching import DualCache
|
|
from litellm.integrations.custom_logger import CustomLogger
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
|
|
|
|
|
class _PROXY_AzureContentSafety(
    CustomLogger
):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    """Screen proxy traffic with Azure AI Content Safety.

    Chat ``messages`` are analyzed before a completion call (``source="input"``)
    and non-streaming model responses after it (``source="output"``). If any
    category's reported severity is greater than or equal to its configured
    threshold, an ``HTTPException`` with status 400 is raised.
    """

    def __init__(self, endpoint, api_key, thresholds=None):
        """Create the async Azure Content Safety client.

        Args:
            endpoint: Content Safety endpoint passed to ``ContentSafetyClient``.
            api_key: Key wrapped in ``AzureKeyCredential`` for authentication.
            thresholds: Optional mapping of ``TextCategory`` -> minimum severity
                that triggers filtering. HATE, SELF_HARM, SEXUAL and VIOLENCE
                default to 4 when absent. The mapping is copied, not modified.

        Raises:
            Exception: if the Azure Content Safety SDK cannot be imported.
        """
        # Imported here rather than at module level, so the SDK is only
        # needed when this class is actually instantiated.
        try:
            from azure.ai.contentsafety.aio import ContentSafetyClient
            from azure.ai.contentsafety.models import (
                AnalyzeTextOptions,
                AnalyzeTextOutputType,
                TextCategory,
            )
            from azure.core.credentials import AzureKeyCredential
            from azure.core.exceptions import HttpResponseError
        except Exception as e:
            raise Exception(
                f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
            ) from e

        self.endpoint = endpoint
        self.api_key = api_key
        # Keep the lazily imported SDK types on the instance so the other
        # methods can use them without re-importing.
        self.text_category = TextCategory
        self.analyze_text_options = AnalyzeTextOptions
        self.analyze_text_output_type = AnalyzeTextOutputType
        self.azure_http_error = HttpResponseError

        self.thresholds = self._configure_thresholds(thresholds)

        self.client = ContentSafetyClient(
            self.endpoint, AzureKeyCredential(self.api_key)
        )

    def _configure_thresholds(self, thresholds=None):
        """Return per-category severity thresholds with defaults filled in.

        Entries in ``thresholds`` override the default of 4 for HATE,
        SELF_HARM, SEXUAL and VIOLENCE; any extra keys are kept. A new dict is
        returned so the caller's mapping (e.g. shared config) is never mutated.
        """
        default_thresholds = {
            self.text_category.HATE: 4,
            self.text_category.SELF_HARM: 4,
            self.text_category.SEXUAL: 4,
            self.text_category.VIOLENCE: 4,
        }

        if thresholds is None:
            return default_thresholds

        merged = dict(default_thresholds)
        merged.update(thresholds)
        return merged

    def _compute_result(self, response):
        """Map each reported category to its severity and filter decision.

        Returns ``{category: {"filtered": bool, "severity": severity}}`` for
        every ``TextCategory`` that the response reported with a non-None
        severity, in ``TextCategory`` iteration order.
        """
        category_severity = {
            item.category: item.severity for item in response.categories_analysis
        }

        result = {}
        for category in self.text_category:
            severity = category_severity.get(category)
            if severity is None:
                continue
            # NOTE(review): raises KeyError if the SDK reports a category that
            # has no configured threshold; only the four defaults above are
            # guaranteed to be present.
            result[category] = {
                "filtered": severity >= self.thresholds[category],
                "severity": severity,
            }

        return result

    async def test_violation(self, content: str, source: Optional[str] = None):
        """Analyze ``content`` and reject it if any category is over threshold.

        Args:
            content: Text sent to Azure Content Safety for analysis.
            source: Label echoed in the error detail ("input" or "output" as
                used by the hooks in this class).

        Raises:
            HTTPException: status 400 for the first filtered category, with
                ``error``, ``source``, ``category`` and ``severity`` in ``detail``.
            azure.core.exceptions.HttpResponseError: re-raised from the Azure
                call after being logged at debug level.
        """
        verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)

        request = self.analyze_text_options(
            text=content,
            output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
        )

        try:
            response = await self.client.analyze_text(request)
        except self.azure_http_error:
            # Log once, then let the caller decide how to handle the failure.
            verbose_proxy_logger.debug(
                "Error in Azure Content-Safety: %s", traceback.format_exc()
            )
            raise

        result = self._compute_result(response)
        verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)

        for category, verdict in result.items():
            if verdict["filtered"]:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated content safety policy",
                        "source": source,
                        "category": category,
                        "severity": verdict["severity"],
                    },
                )

    @staticmethod
    def _message_texts(message):
        """Return the list of text segments contained in one chat message.

        Plain string ``content`` yields a single segment. List ``content`` (the
        OpenAI content-parts form) yields the ``text`` of every part shaped like
        ``{"type": "text", "text": "..."}``; non-text parts are ignored. Any
        other shape, or a message without ``content``, yields no segments.
        """
        if "content" not in message:
            return []
        content = message["content"]
        if isinstance(content, str):
            return [content]
        if isinstance(content, list):
            return [
                part["text"]
                for part in content
                if isinstance(part, dict)
                and part.get("type") == "text"
                and isinstance(part.get("text"), str)
            ]
        return []

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
    ):
        """Screen every text segment of ``data["messages"]`` for completion calls.

        A policy violation (``HTTPException``) propagates and rejects the
        request. Any other error is logged and swallowed, so the request then
        proceeds without screening (fail-open).
        """
        verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
        try:
            if call_type == "completion" and "messages" in data:
                for m in data["messages"]:
                    # Covers both string content and list-of-parts content so
                    # multimodal-style messages cannot bypass the filter.
                    for text in self._message_texts(m):
                        await self.test_violation(content=text, source="input")

        except HTTPException:
            raise
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_proxy_logger.debug(traceback.format_exc())

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """Screen the message content of every choice in a ``ModelResponse``.

        Other response types are ignored. ``None`` message content is analyzed
        as an empty string.

        Raises:
            HTTPException: propagated from ``test_violation`` on a violation.
        """
        verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
        if not isinstance(response, litellm.ModelResponse):
            return

        # Check every choice, not just the first, so n > 1 responses cannot
        # return an unscreened completion; an empty choices list is a no-op.
        for choice in response.choices or []:
            if isinstance(choice, litellm.utils.Choices):
                await self.test_violation(
                    content=choice.message.content or "", source="output"
                )

    # NOTE: streaming responses are not screened by this class; the hook
    # below is disabled and kept for reference only.
    # async def async_post_call_streaming_hook(
    #     self,
    #     user_api_key_dict: UserAPIKeyAuth,
    #     response: str,
    # ):
    #     verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
    #     await self.test_violation(content=response, source="output")