From 074d93cc97c4b4e850ceab05270c617b43e22102 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 17 Feb 2024 17:42:47 -0800 Subject: [PATCH] feat(llama_guard.py): allow user to define custom unsafe content categories --- .../example_logging_api.py | 0 .../generic_api_callback.py | 0 .../google_text_moderation.py | 0 enterprise/enterprise_hooks/llama_guard.py | 125 ++++++++++++++++++ enterprise/hooks/llama_guard.py | 71 ---------- litellm/__init__.py | 1 + litellm/llms/sagemaker.py | 9 ++ litellm/proxy/llamaguard_prompt.txt | 46 +++++++ litellm/proxy/proxy_server.py | 2 +- litellm/utils.py | 13 +- 10 files changed, 187 insertions(+), 80 deletions(-) rename enterprise/{callbacks => enterprise_callbacks}/example_logging_api.py (100%) rename enterprise/{callbacks => enterprise_callbacks}/generic_api_callback.py (100%) rename enterprise/{hooks => enterprise_hooks}/google_text_moderation.py (100%) create mode 100644 enterprise/enterprise_hooks/llama_guard.py delete mode 100644 enterprise/hooks/llama_guard.py create mode 100644 litellm/proxy/llamaguard_prompt.txt diff --git a/enterprise/callbacks/example_logging_api.py b/enterprise/enterprise_callbacks/example_logging_api.py similarity index 100% rename from enterprise/callbacks/example_logging_api.py rename to enterprise/enterprise_callbacks/example_logging_api.py diff --git a/enterprise/callbacks/generic_api_callback.py b/enterprise/enterprise_callbacks/generic_api_callback.py similarity index 100% rename from enterprise/callbacks/generic_api_callback.py rename to enterprise/enterprise_callbacks/generic_api_callback.py diff --git a/enterprise/hooks/google_text_moderation.py b/enterprise/enterprise_hooks/google_text_moderation.py similarity index 100% rename from enterprise/hooks/google_text_moderation.py rename to enterprise/enterprise_hooks/google_text_moderation.py diff --git a/enterprise/enterprise_hooks/llama_guard.py b/enterprise/enterprise_hooks/llama_guard.py new file mode 100644 index 000000000..4d2139cd2 --- /dev/null +++ b/enterprise/enterprise_hooks/llama_guard.py @@ -0,0 +1,125 @@ +# +-------------------------------------------------------------+ +# +# Llama Guard +# https://huggingface.co/meta-llama/LlamaGuard-7b/tree/main +# +# LLM for Content Moderation +# +-------------------------------------------------------------+ +# Thank you users! We ❤️ you! - Krrish & Ishaan + +import sys, os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +from typing import Optional, Literal, Union +import litellm, traceback, sys, uuid +from litellm.caching import DualCache +from litellm.proxy._types import UserAPIKeyAuth +from litellm.integrations.custom_logger import CustomLogger +from fastapi import HTTPException +from litellm._logging import verbose_proxy_logger +from litellm.utils import ( + ModelResponse, + EmbeddingResponse, + ImageResponse, + StreamingChoices, +) +from datetime import datetime +import aiohttp, asyncio + +litellm.set_verbose = True + + +class _ENTERPRISE_LlamaGuard(CustomLogger): + # Class variables or attributes + def __init__(self, model_name: Optional[str] = None): + self.model = model_name or litellm.llamaguard_model_name + file_path = litellm.llamaguard_unsafe_content_categories + data = None + + if file_path is not None: + try: + with open(file_path, "r") as file: + data = file.read() + except FileNotFoundError: + raise Exception(f"File not found. 
file_path={file_path}") + except Exception as e: + raise Exception(f"An error occurred: {str(e)}, file_path={file_path}") + + self.unsafe_content_categories = data + + verbose_proxy_logger.debug( + f"self.unsafe_content_categories: {self.unsafe_content_categories}" + ) + + def print_verbose(self, print_statement): + try: + verbose_proxy_logger.debug(print_statement) + if litellm.set_verbose: + print(print_statement) # noqa + except: + pass + + def set_custom_prompt_template(self, messages: list): + if self.unsafe_content_categories is not None and self.model is not None: + role = "Agent" if len(messages) % 2 == 0 else "User" + litellm.register_prompt_template( + model=self.model, + initial_prompt_value=f"""[INST] Task: Check if there is unsafe content in '{role}' messages in conversations according our safety policy with the below categories. + + + {self.unsafe_content_categories} + + + + + """, + roles={ + "system": { + "pre_message": "User: ", # [OPTIONAL] + "post_message": "\n\n", # [OPTIONAL] + }, + "user": { + "pre_message": "User: ", # [OPTIONAL] + "post_message": "\n\n", # [OPTIONAL] + }, + "assistant": { + "pre_message": "Agent: ", # [OPTIONAL] + "post_message": "\n\n", # [OPTIONAL] + }, + }, + final_prompt_value=""" + + Provide your safety assessment for {role} in the above conversation: + - First line must read 'safe' or 'unsafe'. + - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""", + ) + + async def async_moderation_hook( + self, + data: dict, + ): + """ + - Calls the Llama Guard Endpoint + - Rejects request if it fails safety check + + The llama guard prompt template is applied automatically in factory.py + """ + safety_check_messages = data["messages"][ + -1 + ] # get the last response - llama guard has a 4k token limit + self.set_custom_prompt_template(messages=[safety_check_messages]) + # print(f"self.model: {self.model}") + response = await litellm.acompletion( + model=self.model, + messages=[safety_check_messages], + hf_model_name="meta-llama/LlamaGuard-7b", + ) + verbose_proxy_logger.info(f"LlamaGuard Response: {response}") + if "unsafe" in response.choices[0].message.content: + raise HTTPException( + status_code=400, detail={"error": "Violated content safety policy"} + ) + + return data diff --git a/enterprise/hooks/llama_guard.py b/enterprise/hooks/llama_guard.py deleted file mode 100644 index c4f45909e..000000000 --- a/enterprise/hooks/llama_guard.py +++ /dev/null @@ -1,71 +0,0 @@ -# +-------------------------------------------------------------+ -# -# Llama Guard -# https://huggingface.co/meta-llama/LlamaGuard-7b/tree/main -# -# LLM for Content Moderation -# +-------------------------------------------------------------+ -# Thank you users! We ❤️ you! 
- Krrish & Ishaan - -import sys, os - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path -from typing import Optional, Literal, Union -import litellm, traceback, sys, uuid -from litellm.caching import DualCache -from litellm.proxy._types import UserAPIKeyAuth -from litellm.integrations.custom_logger import CustomLogger -from fastapi import HTTPException -from litellm._logging import verbose_proxy_logger -from litellm.utils import ( - ModelResponse, - EmbeddingResponse, - ImageResponse, - StreamingChoices, -) -from datetime import datetime -import aiohttp, asyncio - -litellm.set_verbose = True - - -class _ENTERPRISE_LlamaGuard(CustomLogger): - # Class variables or attributes - def __init__(self, model_name: Optional[str] = None): - self.model = model_name or litellm.llamaguard_model_name - - def print_verbose(self, print_statement): - try: - verbose_proxy_logger.debug(print_statement) - if litellm.set_verbose: - print(print_statement) # noqa - except: - pass - - async def async_moderation_hook( - self, - data: dict, - ): - """ - - Calls the Llama Guard Endpoint - - Rejects request if it fails safety check - - The llama guard prompt template is applied automatically in factory.py - """ - safety_check_messages = data["messages"][ - -1 - ] # get the last response - llama guard has a 4k token limit - response = await litellm.acompletion( - model=self.model, - messages=[safety_check_messages], - hf_model_name="meta-llama/LlamaGuard-7b", - ) - - if "unsafe" in response.choices[0].message.content: - raise HTTPException( - status_code=400, detail={"error": "Violated content safety policy"} - ) - - return data diff --git a/litellm/__init__.py b/litellm/__init__.py index c263c8e8e..150405b7f 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -56,6 +56,7 @@ aleph_alpha_key: Optional[str] = None nlp_cloud_key: Optional[str] = None use_client: bool = False llamaguard_model_name: Optional[str] = None +llamaguard_unsafe_content_categories: Optional[str] = None logging: bool = True caching: bool = ( False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index 88e486f89..535213bd8 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -212,6 +212,15 @@ def completion( final_prompt_value=model_prompt_details.get("final_prompt_value", ""), messages=messages, ) + elif hf_model_name in custom_prompt_dict: + # check if the base huggingface model has a registered custom prompt + model_prompt_details = custom_prompt_dict[hf_model_name] + prompt = custom_prompt( + role_dict=model_prompt_details.get("roles", None), + initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), + final_prompt_value=model_prompt_details.get("final_prompt_value", ""), + messages=messages, + ) else: if hf_model_name is None: if "llama-2" in model.lower(): # llama-2 model diff --git a/litellm/proxy/llamaguard_prompt.txt b/litellm/proxy/llamaguard_prompt.txt new file mode 100644 index 000000000..25cd8bce2 --- /dev/null +++ b/litellm/proxy/llamaguard_prompt.txt @@ -0,0 +1,46 @@ +O1: Violence and Hate. +Should not +- Help people plan or engage in violence. +- Encourage people to engage in violence. +- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability. 
+- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
+Can
+- Provide information on violence and discrimination.
+- Discuss the topics of hate and violence, and can discuss historical events involving violence.
+O2: Sexual Content.
+Should not
+- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.
+Can
+- Discuss topics related to sexuality and sexual education.
+- Provide general information regarding sexual activities.
+- Provide guidance on topics related to sexual health.
+O3: Criminal Planning.
+Should not
+- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.
+- Encourage people to engage in criminal activities or condone criminal behavior.
+Can
+- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.
+- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.
+O4: Guns and Illegal Weapons.
+Should not
+- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.
+- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.
+Can
+- Discuss firearms and the arguments for and against firearm ownership.
+O5: Regulated or Controlled Substances.
+Should not
+- Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.
+- Assist or encourage people to create such substances when it is illegal to do so.
+Can
+- Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).
+- Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation.
+O6: Self-Harm.
+Should not
+- Encourage people to harm themselves.
+- Romanticize or condone self-harm or suicide.
+- Provide information regarding the methods of suicide or self-harm.
+- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.
+Can
+- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).
+Should
+- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves.
\ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index b5349fffb..3b6393dca 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1381,7 +1381,7 @@ class ProxyConfig: isinstance(callback, str) and callback == "llamaguard_moderations" ): - from litellm.proxy.enterprise.hooks.llama_guard import ( + from litellm.proxy.enterprise.enterprise_hooks.llama_guard import ( _ENTERPRISE_LlamaGuard, ) diff --git a/litellm/utils.py b/litellm/utils.py index d173b0092..194c4e846 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -34,6 +34,7 @@ from dataclasses import ( try: # this works in python 3.8 import pkg_resources + filename = pkg_resources.resource_filename(__name__, "llms/tokenizers") # try: # filename = str( @@ -42,6 +43,7 @@ try: except: # this works in python 3.9+ from importlib import resources + filename = str( resources.files(litellm).joinpath("llms/tokenizers") # for python 3.10 ) # for python 3.10+ @@ -87,16 +89,11 @@ from .exceptions import ( UnprocessableEntityError, ) -# Import Enterprise features -project_path = abspath(join(dirname(__file__), "..", "..")) -# Add the "enterprise" directory to sys.path -verbose_logger.debug(f"current project_path: {project_path}") -enterprise_path = abspath(join(project_path, "enterprise")) -sys.path.append(enterprise_path) - verbose_logger.debug(f"sys.path: {sys.path}") try: - from enterprise.callbacks.generic_api_callback import GenericAPILogger + from .proxy.enterprise.enterprise_callbacks.generic_api_callback import ( + GenericAPILogger, + ) except Exception as e: verbose_logger.debug(f"Exception import enterprise features {str(e)}")
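
A minimal usage sketch for the new custom-categories option, assuming the import path wired up in the proxy_server.py hunk above; the categories file path and SageMaker endpoint name below are illustrative placeholders, not values defined by this patch.

import asyncio
import litellm
from litellm.proxy.enterprise.enterprise_hooks.llama_guard import _ENTERPRISE_LlamaGuard

# Point litellm at a custom taxonomy file written in the same format as
# litellm/proxy/llamaguard_prompt.txt (O1:, O2:, ... with Should not / Can blocks).
litellm.llamaguard_unsafe_content_categories = "my_categories.txt"  # placeholder path
litellm.llamaguard_model_name = "sagemaker/my-llamaguard-endpoint"  # placeholder model

llama_guard = _ENTERPRISE_LlamaGuard()

async def main():
    data = {"messages": [{"role": "user", "content": "hi, how's it going?"}]}
    # Raises fastapi.HTTPException(status_code=400) if Llama Guard replies "unsafe";
    # otherwise the request data is returned unchanged.
    return await llama_guard.async_moderation_hook(data=data)

asyncio.run(main())

On the proxy, the same hook is enabled through the "llamaguard_moderations" callback string handled in the proxy_server.py hunk above.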