mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
143 lines
5 KiB
Python
143 lines
5 KiB
Python
# +-----------------------------------------------+
|
|
#
|
|
# Google Text Moderation
|
|
# https://cloud.google.com/natural-language/docs/moderating-text
|
|
#
|
|
# +-----------------------------------------------+
|
|
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
|
|
|
|
|
from typing import Literal
|
|
import litellm
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
|
from litellm.integrations.custom_logger import CustomLogger
|
|
from fastapi import HTTPException
|
|
from litellm._logging import verbose_proxy_logger
|
|
|
|
|
|
class _ENTERPRISE_GoogleTextModeration(CustomLogger):
|
|
user_api_key_cache = None
|
|
confidence_categories = [
|
|
"toxic",
|
|
"insult",
|
|
"profanity",
|
|
"derogatory",
|
|
"sexual",
|
|
"death_harm_and_tragedy",
|
|
"violent",
|
|
"firearms_and_weapons",
|
|
"public_safety",
|
|
"health",
|
|
"religion_and_belief",
|
|
"illicit_drugs",
|
|
"war_and_conflict",
|
|
"politics",
|
|
"finance",
|
|
"legal",
|
|
] # https://cloud.google.com/natural-language/docs/moderating-text#safety_attribute_confidence_scores
|
|
|
|
# Class variables or attributes
|
|
def __init__(self):
|
|
try:
|
|
from google.cloud import language_v1 # type: ignore
|
|
except Exception:
|
|
raise Exception(
|
|
"Missing google.cloud package. Run `pip install --upgrade google-cloud-language`"
|
|
)
|
|
|
|
# Instantiates a client
|
|
self.client = language_v1.LanguageServiceClient()
|
|
self.moderate_text_request = language_v1.ModerateTextRequest
|
|
self.language_document = language_v1.types.Document # type: ignore
|
|
self.document_type = language_v1.types.Document.Type.PLAIN_TEXT # type: ignore
|
|
|
|
default_confidence_threshold = (
|
|
litellm.google_moderation_confidence_threshold or 0.8
|
|
) # by default require a high confidence (80%) to fail
|
|
|
|
for category in self.confidence_categories:
|
|
if hasattr(litellm, f"{category}_confidence_threshold"):
|
|
setattr(
|
|
self,
|
|
f"{category}_confidence_threshold",
|
|
getattr(litellm, f"{category}_confidence_threshold"),
|
|
)
|
|
else:
|
|
setattr(
|
|
self,
|
|
f"{category}_confidence_threshold",
|
|
default_confidence_threshold,
|
|
)
|
|
set_confidence_value = getattr(
|
|
self,
|
|
f"{category}_confidence_threshold",
|
|
)
|
|
verbose_proxy_logger.info(
|
|
f"Google Text Moderation: {category}_confidence_threshold: {set_confidence_value}"
|
|
)
|
|
|
|
def print_verbose(self, print_statement):
|
|
try:
|
|
verbose_proxy_logger.debug(print_statement)
|
|
if litellm.set_verbose:
|
|
print(print_statement) # noqa
|
|
except Exception:
|
|
pass
|
|
|
|
async def async_moderation_hook(
|
|
self,
|
|
data: dict,
|
|
user_api_key_dict: UserAPIKeyAuth,
|
|
call_type: Literal[
|
|
"completion",
|
|
"embeddings",
|
|
"image_generation",
|
|
"moderation",
|
|
"audio_transcription",
|
|
"responses",
|
|
],
|
|
):
|
|
"""
|
|
- Calls Google's Text Moderation API
|
|
- Rejects request if it fails safety check
|
|
"""
|
|
if "messages" in data and isinstance(data["messages"], list):
|
|
text = ""
|
|
for m in data["messages"]: # assume messages is a list
|
|
if "content" in m and isinstance(m["content"], str):
|
|
text += m["content"]
|
|
document = self.language_document(content=text, type_=self.document_type)
|
|
|
|
request = self.moderate_text_request(
|
|
document=document,
|
|
)
|
|
|
|
# Make the request
|
|
response = self.client.moderate_text(request=request)
|
|
for category in response.moderation_categories:
|
|
category_name = category.name
|
|
category_name = category_name.lower()
|
|
category_name = category_name.replace("&", "and")
|
|
category_name = category_name.replace(",", "")
|
|
category_name = category_name.replace(
|
|
" ", "_"
|
|
) # e.g. go from 'Firearms & Weapons' to 'firearms_and_weapons'
|
|
if category.confidence > getattr(
|
|
self, f"{category_name}_confidence_threshold"
|
|
):
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": f"Violated content safety policy. Category={category}"
|
|
},
|
|
)
|
|
# Handle the response
|
|
return data
|
|
|
|
|
|
# google_text_moderation_obj = _ENTERPRISE_GoogleTextModeration()
|
|
# asyncio.run(
|
|
# google_text_moderation_obj.async_moderation_hook(
|
|
# data={"messages": [{"role": "user", "content": "Hey, how's it going?"}]}
|
|
# )
|
|
# )
|