forked from phoenix/litellm-mirror
feat(google_text_moderation.py): allow user to use google text moderation for content mod on proxy
This commit is contained in:
parent 73acdf3736
commit ddf0911c46
4 changed files with 122 additions and 19 deletions
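For context, a minimal sketch (not part of this commit) of exercising the new hook directly, mirroring the commented-out test at the bottom of google_text_moderation.py. The import path is taken from the ProxyConfig hunk below; running this assumes `google-cloud-language` is installed and Google application credentials are configured, and the threshold override shown is optional.

import asyncio

import litellm
from litellm.proxy.enterprise.enterprise_hooks.google_text_moderation import (
    _ENTERPRISE_GoogleTextModeration,
)

# Optional global override; the hook otherwise defaults to a 0.8 confidence threshold.
litellm.google_moderation_confidence_threshold = 0.6


async def main():
    hook = _ENTERPRISE_GoogleTextModeration()
    # Raises fastapi.HTTPException(status_code=400) if any moderation category
    # scores above its configured confidence threshold.
    checked = await hook.async_moderation_hook(
        data={"messages": [{"role": "user", "content": "Hey, how's it going?"}]}
    )
    print(checked)


asyncio.run(main())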
@@ -26,10 +26,62 @@ import aiohttp, asyncio
 
 class _ENTERPRISE_GoogleTextModeration(CustomLogger):
     user_api_key_cache = None
+    confidence_categories = [
+        "toxic",
+        "insult",
+        "profanity",
+        "derogatory",
+        "sexual",
+        "death_harm_and_tragedy",
+        "violent",
+        "firearms_and_weapons",
+        "public_safety",
+        "health",
+        "religion_and_belief",
+        "illicit_drugs",
+        "war_and_conflict",
+        "politics",
+        "finance",
+        "legal",
+    ]  # https://cloud.google.com/natural-language/docs/moderating-text#safety_attribute_confidence_scores
 
     # Class variables or attributes
-    def __init__(self, mock_testing: bool = False):
-        pass
+    def __init__(self):
+        try:
+            from google.cloud import language_v1
+        except:
+            raise Exception(
+                "Missing google.cloud package. Run `pip install --upgrade google-cloud-language`"
+            )
+
+        # Instantiates a client
+        self.client = language_v1.LanguageServiceClient()
+        self.moderate_text_request = language_v1.ModerateTextRequest
+        self.language_document = language_v1.types.Document
+        self.document_type = language_v1.types.Document.Type.PLAIN_TEXT
+
+        if hasattr(litellm, "google_moderation_confidence_threshold"):
+            default_confidence_threshold = (
+                litellm.google_moderation_confidence_threshold
+            )
+        else:
+            default_confidence_threshold = (
+                0.8  # by default require a high confidence (80%) to fail
+            )
+
+        for category in self.confidence_categories:
+            if hasattr(litellm, f"{category}_confidence_threshold"):
+                setattr(
+                    self,
+                    f"{category}_confidence_threshold",
+                    getattr(litellm, f"{category}_confidence_threshold"),
+                )
+            else:
+                setattr(
+                    self,
+                    f"{category}_confidence_threshold",
+                    default_confidence_threshold,
+                )
 
     def print_verbose(self, print_statement):
         try:
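As a usage note on the constructor above: thresholds resolve per category, falling back first to `litellm.google_moderation_confidence_threshold` and then to 0.8. A hedged sketch of overriding them follows (attribute names follow the `f"{category}_confidence_threshold"` pattern over `confidence_categories`; the values are illustrative and must be set before the hook is instantiated).

import litellm

# Global default applied to every category unless overridden (falls back to 0.8 if unset).
litellm.google_moderation_confidence_threshold = 0.7

# Per-category overrides, named f"{category}_confidence_threshold".
litellm.violent_confidence_threshold = 0.5
litellm.firearms_and_weapons_confidence_threshold = 0.4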
@@ -39,15 +91,52 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
         except:
             pass
 
-    async def async_pre_call_hook(
+    async def async_moderation_hook(
         self,
-        user_api_key_dict: UserAPIKeyAuth,
-        cache: DualCache,
         data: dict,
-        call_type: str,
     ):
         """
         - Calls Google's Text Moderation API
         - Rejects request if it fails safety check
         """
-        pass
+        if "messages" in data and isinstance(data["messages"], list):
+            text = ""
+            for m in data["messages"]:  # assume messages is a list
+                if "content" in m and isinstance(m["content"], str):
+                    text += m["content"]
+        document = self.language_document(content=text, type_=self.document_type)
+
+        request = self.moderate_text_request(
+            document=document,
+        )
+
+        # Make the request
+        response = self.client.moderate_text(request=request)
+        print(response)
+        for category in response.moderation_categories:
+            category_name = category.name
+            category_name = category_name.lower()
+            category_name = category_name.replace("&", "and")
+            category_name = category_name.replace(",", "")
+            category_name = category_name.replace(
+                " ", "_"
+            )  # e.g. go from 'Firearms & Weapons' to 'firearms_and_weapons'
+            if category.confidence > getattr(
+                self, f"{category_name}_confidence_threshold"
+            ):
+                raise HTTPException(
+                    status_code=400,
+                    detail={
+                        "error": f"Violated content safety policy. Category={category}"
+                    },
+                )
+        # Handle the response
+        return data
+
+
+# google_text_moderation_obj = _ENTERPRISE_GoogleTextModeration()
+# asyncio.run(
+#     google_text_moderation_obj.async_moderation_hook(
+#         data={"messages": [{"role": "user", "content": "Hey, how's it going?"}]}
+#     )
+# )
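A small standalone sketch (not part of the commit) of the category-name normalization used above, which maps the API's display names onto the threshold attribute names. The "Death, Harm & Tragedy" label is assumed from the `death_harm_and_tragedy` entry in `confidence_categories`; "Firearms & Weapons" is the example given in the code comment.

def to_threshold_attr(display_name: str) -> str:
    # Mirror the lower/replace chain in async_moderation_hook.
    name = display_name.lower()
    name = name.replace("&", "and").replace(",", "")
    return name.replace(" ", "_") + "_confidence_threshold"


assert to_threshold_attr("Firearms & Weapons") == "firearms_and_weapons_confidence_threshold"
assert to_threshold_attr("Death, Harm & Tragedy") == "death_harm_and_tragedy_confidence_threshold"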
@@ -54,18 +54,19 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
 
         The llama guard prompt template is applied automatically in factory.py
         """
-        safety_check_messages = data["messages"][
-            -1
-        ]  # get the last response - llama guard has a 4k token limit
-        response = await litellm.acompletion(
-            model=self.model,
-            messages=[safety_check_messages],
-            hf_model_name="meta-llama/LlamaGuard-7b",
-        )
-
-        if "unsafe" in response.choices[0].message.content:
-            raise HTTPException(
-                status_code=400, detail={"error": "Violated content safety policy"}
-            )
+        if "messages" in data:
+            safety_check_messages = data["messages"][
+                -1
+            ]  # get the last response - llama guard has a 4k token limit
+            response = await litellm.acompletion(
+                model=self.model,
+                messages=[safety_check_messages],
+                hf_model_name="meta-llama/LlamaGuard-7b",
+            )
+
+            if "unsafe" in response.choices[0].message.content:
+                raise HTTPException(
+                    status_code=400, detail={"error": "Violated content safety policy"}
+                )
 
         return data
@@ -56,6 +56,7 @@ aleph_alpha_key: Optional[str] = None
 nlp_cloud_key: Optional[str] = None
 use_client: bool = False
 llamaguard_model_name: Optional[str] = None
+google_moderation_confidence_threshold: Optional[float] = None
 logging: bool = True
 caching: bool = (
     False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
@@ -1427,6 +1427,18 @@ class ProxyConfig:
 
                        llama_guard_object = _ENTERPRISE_LlamaGuard()
                        imported_list.append(llama_guard_object)
+                    elif (
+                        isinstance(callback, str)
+                        and callback == "google_text_moderation"
+                    ):
+                        from litellm.proxy.enterprise.enterprise_hooks.google_text_moderation import (
+                            _ENTERPRISE_GoogleTextModeration,
+                        )
+
+                        google_text_moderation_obj = (
+                            _ENTERPRISE_GoogleTextModeration()
+                        )
+                        imported_list.append(google_text_moderation_obj)
                    else:
                        imported_list.append(
                            get_instance_fn(
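For reference, a hypothetical, condensed mirror of the ProxyConfig branch above (not the actual proxy code): when the configured callbacks list contains the string "google_text_moderation", the hook is instantiated and appended to the callback list. The `resolve_callbacks` helper below is illustrative only.

import litellm
from litellm.proxy.enterprise.enterprise_hooks.google_text_moderation import (
    _ENTERPRISE_GoogleTextModeration,
)


def resolve_callbacks(callback_names: list) -> list:
    imported_list = []
    for callback in callback_names:
        if isinstance(callback, str) and callback == "google_text_moderation":
            # Requires google-cloud-language; the constructor raises otherwise.
            imported_list.append(_ENTERPRISE_GoogleTextModeration())
        # Other string callbacks (e.g. the LlamaGuard branch above) and custom
        # callback imports are handled by the remaining branches in ProxyConfig.
    return imported_list


litellm.callbacks = resolve_callbacks(["google_text_moderation"])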