From ddf0911c4632342cbf1ec18f086e44bb8d0b3ae7 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 17 Feb 2024 18:36:29 -0800
Subject: [PATCH] feat(google_text_moderation.py): allow user to use google
 text moderation for content mod on proxy

---
 .../google_text_moderation.py              | 103 ++++++++++++++++--
 enterprise/enterprise_hooks/llama_guard.py |  25 +++--
 litellm/__init__.py                        |   1 +
 litellm/proxy/proxy_server.py              |  12 ++
 4 files changed, 122 insertions(+), 19 deletions(-)

diff --git a/enterprise/enterprise_hooks/google_text_moderation.py b/enterprise/enterprise_hooks/google_text_moderation.py
index b79a32805..a6a48a385 100644
--- a/enterprise/enterprise_hooks/google_text_moderation.py
+++ b/enterprise/enterprise_hooks/google_text_moderation.py
@@ -26,10 +26,62 @@ import aiohttp, asyncio

 class _ENTERPRISE_GoogleTextModeration(CustomLogger):
     user_api_key_cache = None
+    confidence_categories = [
+        "toxic",
+        "insult",
+        "profanity",
+        "derogatory",
+        "sexual",
+        "death_harm_and_tragedy",
+        "violent",
+        "firearms_and_weapons",
+        "public_safety",
+        "health",
+        "religion_and_belief",
+        "illicit_drugs",
+        "war_and_conflict",
+        "politics",
+        "finance",
+        "legal",
+    ]  # https://cloud.google.com/natural-language/docs/moderating-text#safety_attribute_confidence_scores

     # Class variables or attributes
-    def __init__(self, mock_testing: bool = False):
-        pass
+    def __init__(self):
+        try:
+            from google.cloud import language_v1
+        except ImportError:
+            raise Exception(
+                "Missing google.cloud package. Run `pip install --upgrade google-cloud-language`"
+            )
+
+        # Instantiates a client
+        self.client = language_v1.LanguageServiceClient()
+        self.moderate_text_request = language_v1.ModerateTextRequest
+        self.language_document = language_v1.types.Document
+        self.document_type = language_v1.types.Document.Type.PLAIN_TEXT
+
+        # note: the attribute always exists (it defaults to None in
+        # litellm/__init__.py), so check its value rather than hasattr
+        if litellm.google_moderation_confidence_threshold is not None:
+            default_confidence_threshold = (
+                litellm.google_moderation_confidence_threshold
+            )
+        else:
+            default_confidence_threshold = (
+                0.8  # by default, require high confidence (80%) before failing
+            )
+
+        for category in self.confidence_categories:
+            if getattr(litellm, f"{category}_confidence_threshold", None) is not None:
+                setattr(
+                    self,
+                    f"{category}_confidence_threshold",
+                    getattr(litellm, f"{category}_confidence_threshold"),
+                )
+            else:
+                setattr(
+                    self,
+                    f"{category}_confidence_threshold",
+                    default_confidence_threshold,
+                )

     def print_verbose(self, print_statement):
         try:
@@ -39,15 +91,52 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
         except:
             pass

-    async def async_pre_call_hook(
+    async def async_moderation_hook(
         self,
-        user_api_key_dict: UserAPIKeyAuth,
-        cache: DualCache,
         data: dict,
-        call_type: str,
     ):
         """
         - Calls Google's Text Moderation API
         - Rejects request if it fails safety check
         """
-        pass
+        text = ""  # initialize before the check so it's always bound
+        if "messages" in data and isinstance(data["messages"], list):
+            for m in data["messages"]:  # assume messages is a list
+                if "content" in m and isinstance(m["content"], str):
+                    text += m["content"]
+        document = self.language_document(content=text, type_=self.document_type)
+
+        request = self.moderate_text_request(
+            document=document,
+        )
+
+        # Make the request
+        response = self.client.moderate_text(request=request)
+        self.print_verbose(f"google text moderation response: {response}")
+        for category in response.moderation_categories:
+            category_name = category.name
+            category_name = category_name.lower()
+            category_name = category_name.replace("&", "and")
+            category_name = category_name.replace(",", "")
+            category_name = category_name.replace(
+                " ", "_"
+            )  # e.g. go from 'Firearms & Weapons' to 'firearms_and_weapons'
+            if category.confidence > getattr(
+                self, f"{category_name}_confidence_threshold"
+            ):
+                raise HTTPException(
+                    status_code=400,
+                    detail={
+                        "error": f"Violated content safety policy. Category={category_name}"
+                    },
+                )
+        # No category exceeded its threshold - allow the request through
+        return data
+
+
+# google_text_moderation_obj = _ENTERPRISE_GoogleTextModeration()
+# asyncio.run(
+#     google_text_moderation_obj.async_moderation_hook(
+#         data={"messages": [{"role": "user", "content": "Hey, how's it going?"}]}
+#     )
+# )
diff --git a/enterprise/enterprise_hooks/llama_guard.py b/enterprise/enterprise_hooks/llama_guard.py
index c4f45909e..50cd64c9e 100644
--- a/enterprise/enterprise_hooks/llama_guard.py
+++ b/enterprise/enterprise_hooks/llama_guard.py
@@ -54,18 +54,19 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):

         The llama guard prompt template is applied automatically in factory.py
         """
-        safety_check_messages = data["messages"][
-            -1
-        ]  # get the last response - llama guard has a 4k token limit
-        response = await litellm.acompletion(
-            model=self.model,
-            messages=[safety_check_messages],
-            hf_model_name="meta-llama/LlamaGuard-7b",
-        )
-
-        if "unsafe" in response.choices[0].message.content:
-            raise HTTPException(
-                status_code=400, detail={"error": "Violated content safety policy"}
+        if "messages" in data:
+            safety_check_messages = data["messages"][
+                -1
+            ]  # only scan the last message - llama guard has a 4k token limit
+            response = await litellm.acompletion(
+                model=self.model,
+                messages=[safety_check_messages],
+                hf_model_name="meta-llama/LlamaGuard-7b",
             )

+            if "unsafe" in response.choices[0].message.content:
+                raise HTTPException(
+                    status_code=400, detail={"error": "Violated content safety policy"}
+                )
+
+        return data
diff --git a/litellm/__init__.py b/litellm/__init__.py
index c263c8e8e..1b1ddf280 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -56,6 +56,7 @@ aleph_alpha_key: Optional[str] = None
 nlp_cloud_key: Optional[str] = None
 use_client: bool = False
 llamaguard_model_name: Optional[str] = None
+google_moderation_confidence_threshold: Optional[float] = None
 logging: bool = True
 caching: bool = (
     False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 0f4f3d5ef..e28b76aba 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1427,6 +1427,18 @@ class ProxyConfig:

                             llama_guard_object = _ENTERPRISE_LlamaGuard()
                             imported_list.append(llama_guard_object)
+                        elif (
+                            isinstance(callback, str)
+                            and callback == "google_text_moderation"
+                        ):
+                            from litellm.proxy.enterprise.enterprise_hooks.google_text_moderation import (
+                                _ENTERPRISE_GoogleTextModeration,
+                            )
+
+                            google_text_moderation_obj = (
+                                _ENTERPRISE_GoogleTextModeration()
+                            )
+                            imported_list.append(google_text_moderation_obj)
                         else:
                             imported_list.append(
                                 get_instance_fn(
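
Usage sketch for reviewers (not part of the patch): it mirrors the commented-out
example at the bottom of google_text_moderation.py. The import path and the
threshold values below are assumptions (repo root on PYTHONPATH, arbitrary
example thresholds); it also assumes `google-cloud-language` is installed and
Google application-default credentials are configured.

    import asyncio

    import litellm
    from enterprise.enterprise_hooks.google_text_moderation import (
        _ENTERPRISE_GoogleTextModeration,
    )

    # Optional overrides - __init__ reads these off the litellm module.
    litellm.google_moderation_confidence_threshold = 0.6  # global default (else 0.8)
    litellm.toxic_confidence_threshold = 0.4  # per-category override (hypothetical value)

    google_text_moderation_obj = _ENTERPRISE_GoogleTextModeration()

    # Returns `data` unchanged when every category's confidence is at or below
    # its threshold; raises fastapi.HTTPException(status_code=400) otherwise.
    asyncio.run(
        google_text_moderation_obj.async_moderation_hook(
            data={"messages": [{"role": "user", "content": "Hey, how's it going?"}]}
        )
    )

On the proxy, no direct instantiation is needed: the new ProxyConfig branch
constructs the same object when "google_text_moderation" appears in the
configured callbacks. One design note: the hook calls the synchronous
client.moderate_text() inside an async method, so the API call blocks the
event loop for its duration; the async language client could avoid that.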