feat(google_text_moderation.py): allow user to use google text moderation for content mod on proxy

Krrish Dholakia 2024-02-17 18:36:29 -08:00
parent 73acdf3736
commit ddf0911c46
4 changed files with 122 additions and 19 deletions


@@ -26,10 +26,62 @@ import aiohttp, asyncio
 class _ENTERPRISE_GoogleTextModeration(CustomLogger):
     user_api_key_cache = None
+    confidence_categories = [
+        "toxic",
+        "insult",
+        "profanity",
+        "derogatory",
+        "sexual",
+        "death_harm_and_tragedy",
+        "violent",
+        "firearms_and_weapons",
+        "public_safety",
+        "health",
+        "religion_and_belief",
+        "illicit_drugs",
+        "war_and_conflict",
+        "politics",
+        "finance",
+        "legal",
+    ]  # https://cloud.google.com/natural-language/docs/moderating-text#safety_attribute_confidence_scores
 
     # Class variables or attributes
-    def __init__(self, mock_testing: bool = False):
-        pass
+    def __init__(self):
+        try:
+            from google.cloud import language_v1
+        except:
+            raise Exception(
+                "Missing google.cloud package. Run `pip install --upgrade google-cloud-language`"
+            )
+
+        # Instantiates a client
+        self.client = language_v1.LanguageServiceClient()
+        self.moderate_text_request = language_v1.ModerateTextRequest
+        self.language_document = language_v1.types.Document
+        self.document_type = language_v1.types.Document.Type.PLAIN_TEXT
+
+        if hasattr(litellm, "google_moderation_confidence_threshold"):
+            default_confidence_threshold = (
+                litellm.google_moderation_confidence_threshold
+            )
+        else:
+            default_confidence_threshold = (
+                0.8  # by default require a high confidence (80%) to fail
+            )
+
+        for category in self.confidence_categories:
+            if hasattr(litellm, f"{category}_confidence_threshold"):
+                setattr(
+                    self,
+                    f"{category}_confidence_threshold",
+                    getattr(litellm, f"{category}_confidence_threshold"),
+                )
+            else:
+                setattr(
+                    self,
+                    f"{category}_confidence_threshold",
+                    default_confidence_threshold,
+                )
+
     def print_verbose(self, print_statement):
         try:
@@ -39,15 +91,52 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
         except:
             pass
 
-    async def async_pre_call_hook(
+    async def async_moderation_hook(
         self,
         user_api_key_dict: UserAPIKeyAuth,
         cache: DualCache,
         data: dict,
         call_type: str,
     ):
         """
         - Calls Google's Text Moderation API
         - Rejects request if it fails safety check
         """
-        pass
+        if "messages" in data and isinstance(data["messages"], list):
+            text = ""
+            for m in data["messages"]:  # assume messages is a list
+                if "content" in m and isinstance(m["content"], str):
+                    text += m["content"]
+
+            document = self.language_document(content=text, type_=self.document_type)
+            request = self.moderate_text_request(
+                document=document,
+            )
+
+            # Make the request
+            response = self.client.moderate_text(request=request)
+            print(response)
+            for category in response.moderation_categories:
+                category_name = category.name
+                category_name = category_name.lower()
+                category_name = category_name.replace("&", "and")
+                category_name = category_name.replace(",", "")
+                category_name = category_name.replace(
+                    " ", "_"
+                )  # e.g. go from 'Firearms & Weapons' to 'firearms_and_weapons'
+                if category.confidence > getattr(
+                    self, f"{category_name}_confidence_threshold"
+                ):
+                    raise HTTPException(
+                        status_code=400,
+                        detail={
+                            "error": f"Violated content safety policy. Category={category}"
+                        },
+                    )
+        # Handle the response
+        return data
+
+
+# google_text_moderation_obj = _ENTERPRISE_GoogleTextModeration()
+# asyncio.run(
+#     google_text_moderation_obj.async_moderation_hook(
+#         data={"messages": [{"role": "user", "content": "Hey, how's it going?"}]}
+#     )
+# )
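
For reference, a small standalone sketch (not part of the commit) of the category-name normalization the new hook performs before it looks up a per-category threshold. The helper name and the sample category strings are illustrative only:

# Illustrative helper mirroring the normalization in async_moderation_hook;
# "normalize_category_name" is a made-up name for this sketch.
def normalize_category_name(name: str) -> str:
    name = name.lower()
    name = name.replace("&", "and")
    name = name.replace(",", "")
    return name.replace(" ", "_")

assert normalize_category_name("Firearms & Weapons") == "firearms_and_weapons"
assert normalize_category_name("Death, Harm & Tragedy") == "death_harm_and_tragedy"

# The hook then rejects the request when the returned confidence exceeds the
# matching threshold attribute:
#   if category.confidence > getattr(self, f"{category_name}_confidence_threshold"):
#       raise HTTPException(status_code=400, ...)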


@@ -54,18 +54,19 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
         The llama guard prompt template is applied automatically in factory.py
         """
-        safety_check_messages = data["messages"][
-            -1
-        ]  # get the last response - llama guard has a 4k token limit
-        response = await litellm.acompletion(
-            model=self.model,
-            messages=[safety_check_messages],
-            hf_model_name="meta-llama/LlamaGuard-7b",
-        )
-        if "unsafe" in response.choices[0].message.content:
-            raise HTTPException(
-                status_code=400, detail={"error": "Violated content safety policy"}
+        if "messages" in data:
+            safety_check_messages = data["messages"][
+                -1
+            ]  # get the last response - llama guard has a 4k token limit
+            response = await litellm.acompletion(
+                model=self.model,
+                messages=[safety_check_messages],
+                hf_model_name="meta-llama/LlamaGuard-7b",
+            )
+            if "unsafe" in response.choices[0].message.content:
+                raise HTTPException(
+                    status_code=400, detail={"error": "Violated content safety policy"}
                 )
 
         return data
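
The only behavioral change in this hunk is the new `if "messages" in data:` guard: request bodies without a "messages" key now skip the Llama Guard check instead of raising a KeyError. A tiny illustration with made-up payloads:

chat_payload = {"messages": [{"role": "user", "content": "hi"}]}  # screened by Llama Guard
embedding_payload = {"input": ["some text"], "model": "text-embedding-ada-002"}  # passes through

# Previously, embedding_payload["messages"][-1] would raise KeyError;
# with the guard, async_moderation_hook simply returns the data unchanged.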


@@ -56,6 +56,7 @@ aleph_alpha_key: Optional[str] = None
 nlp_cloud_key: Optional[str] = None
 use_client: bool = False
 llamaguard_model_name: Optional[str] = None
+google_moderation_confidence_threshold: Optional[float] = None
 logging: bool = True
 caching: bool = (
     False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
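
The hunk above adds only the module-level default; the per-category overrides that _ENTERPRISE_GoogleTextModeration.__init__ looks up via hasattr/getattr are not declared here. A minimal sketch of how a user could set both, assuming plain attributes on the litellm module are sufficient:

import litellm

# Global default read by the hook's __init__.
litellm.google_moderation_confidence_threshold = 0.6

# Per-category overrides; these attribute names are derived from the
# confidence_categories list and are assumptions, not values added by this diff.
litellm.toxic_confidence_threshold = 0.3
litellm.violent_confidence_threshold = 0.3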


@@ -1427,6 +1427,18 @@ class ProxyConfig:
                                 llama_guard_object = _ENTERPRISE_LlamaGuard()
                                 imported_list.append(llama_guard_object)
+                            elif (
+                                isinstance(callback, str)
+                                and callback == "google_text_moderation"
+                            ):
+                                from litellm.proxy.enterprise.enterprise_hooks.google_text_moderation import (
+                                    _ENTERPRISE_GoogleTextModeration,
+                                )
+
+                                google_text_moderation_obj = (
+                                    _ENTERPRISE_GoogleTextModeration()
+                                )
+                                imported_list.append(google_text_moderation_obj)
                             else:
                                 imported_list.append(
                                     get_instance_fn(
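
Taken together with the callback loader above, enabling the hook would look roughly like the following proxy configuration, shown here as the parsed Python dict rather than YAML. The "litellm_settings"/"callbacks" keys follow the existing llama guard callback path; treat the exact keys as assumptions, since they are not part of this diff:

# Sketch only: the dict a proxy config.yaml would parse into.
config = {
    "litellm_settings": {
        "callbacks": ["google_text_moderation"],
        # assumed to be applied onto the litellm module by the settings loader
        "google_moderation_confidence_threshold": 0.7,
    }
}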