forked from phoenix/litellm-mirror
feat(llama_guard.py): allow user to define custom unsafe content categories
This commit is contained in:
parent
e3fab50853
commit
074d93cc97
10 changed files with 187 additions and 80 deletions
125  enterprise/enterprise_hooks/llama_guard.py  Normal file

@@ -0,0 +1,125 @@
# +-------------------------------------------------------------+
#
# Llama Guard
# https://huggingface.co/meta-llama/LlamaGuard-7b/tree/main
#
# LLM for Content Moderation
# +-------------------------------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan

import sys, os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from typing import Optional, Literal, Union
import litellm, traceback, sys, uuid
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.utils import (
    ModelResponse,
    EmbeddingResponse,
    ImageResponse,
    StreamingChoices,
)
from datetime import datetime
import aiohttp, asyncio

litellm.set_verbose = True


class _ENTERPRISE_LlamaGuard(CustomLogger):
    # Class variables or attributes
    def __init__(self, model_name: Optional[str] = None):
        self.model = model_name or litellm.llamaguard_model_name
        file_path = litellm.llamaguard_unsafe_content_categories
        data = None

        if file_path is not None:
            try:
                with open(file_path, "r") as file:
                    data = file.read()
            except FileNotFoundError:
                raise Exception(f"File not found. file_path={file_path}")
            except Exception as e:
                raise Exception(f"An error occurred: {str(e)}, file_path={file_path}")

        self.unsafe_content_categories = data

        verbose_proxy_logger.debug(
            f"self.unsafe_content_categories: {self.unsafe_content_categories}"
        )

    def print_verbose(self, print_statement):
        try:
            verbose_proxy_logger.debug(print_statement)
            if litellm.set_verbose:
                print(print_statement)  # noqa
        except:
            pass

    def set_custom_prompt_template(self, messages: list):
        if self.unsafe_content_categories is not None and self.model is not None:
            role = "Agent" if len(messages) % 2 == 0 else "User"
            litellm.register_prompt_template(
                model=self.model,
                initial_prompt_value=f"""[INST] Task: Check if there is unsafe content in '{role}' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
{self.unsafe_content_categories}
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

""",
                roles={
                    "system": {
                        "pre_message": "User: ",  # [OPTIONAL]
                        "post_message": "\n\n",  # [OPTIONAL]
                    },
                    "user": {
                        "pre_message": "User: ",  # [OPTIONAL]
                        "post_message": "\n\n",  # [OPTIONAL]
                    },
                    "assistant": {
                        "pre_message": "Agent: ",  # [OPTIONAL]
                        "post_message": "\n\n",  # [OPTIONAL]
                    },
                },
                final_prompt_value="""<END CONVERSATION>

Provide your safety assessment for {role} in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
            )

    async def async_moderation_hook(
        self,
        data: dict,
    ):
        """
        - Calls the Llama Guard Endpoint
        - Rejects request if it fails safety check

        The llama guard prompt template is applied automatically in factory.py
        """
        safety_check_messages = data["messages"][
            -1
        ]  # get the last response - llama guard has a 4k token limit
        self.set_custom_prompt_template(messages=[safety_check_messages])
        # print(f"self.model: {self.model}")
        response = await litellm.acompletion(
            model=self.model,
            messages=[safety_check_messages],
            hf_model_name="meta-llama/LlamaGuard-7b",
        )
        verbose_proxy_logger.info(f"LlamaGuard Response: {response}")
        if "unsafe" in response.choices[0].message.content:
            raise HTTPException(
                status_code=400, detail={"error": "Violated content safety policy"}
            )

        return data
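A minimal sketch of exercising the new hook directly, assuming a reachable Llama Guard deployment: the deployment name below is a placeholder, the categories path is the default file added in this commit, and the import path matches the proxy_server.py change further down.

import asyncio

import litellm
from litellm.proxy.enterprise.enterprise_hooks.llama_guard import _ENTERPRISE_LlamaGuard

# Placeholder deployment serving Llama Guard, plus the default categories file.
litellm.llamaguard_model_name = "huggingface/my-llamaguard-endpoint"
litellm.llamaguard_unsafe_content_categories = "litellm/proxy/llamaguard_prompt.txt"

guard = _ENTERPRISE_LlamaGuard()


async def main():
    data = {"messages": [{"role": "user", "content": "hello, how are you?"}]}
    try:
        # Raises fastapi.HTTPException (status 400) when Llama Guard answers "unsafe".
        await guard.async_moderation_hook(data=data)
        print("request passed the safety check")
    except Exception as e:
        print(f"request blocked: {e}")


asyncio.run(main())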
@@ -1,71 +0,0 @@
# +-------------------------------------------------------------+
#
# Llama Guard
# https://huggingface.co/meta-llama/LlamaGuard-7b/tree/main
#
# LLM for Content Moderation
# +-------------------------------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan

import sys, os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from typing import Optional, Literal, Union
import litellm, traceback, sys, uuid
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.utils import (
    ModelResponse,
    EmbeddingResponse,
    ImageResponse,
    StreamingChoices,
)
from datetime import datetime
import aiohttp, asyncio

litellm.set_verbose = True


class _ENTERPRISE_LlamaGuard(CustomLogger):
    # Class variables or attributes
    def __init__(self, model_name: Optional[str] = None):
        self.model = model_name or litellm.llamaguard_model_name

    def print_verbose(self, print_statement):
        try:
            verbose_proxy_logger.debug(print_statement)
            if litellm.set_verbose:
                print(print_statement)  # noqa
        except:
            pass

    async def async_moderation_hook(
        self,
        data: dict,
    ):
        """
        - Calls the Llama Guard Endpoint
        - Rejects request if it fails safety check

        The llama guard prompt template is applied automatically in factory.py
        """
        safety_check_messages = data["messages"][
            -1
        ]  # get the last response - llama guard has a 4k token limit
        response = await litellm.acompletion(
            model=self.model,
            messages=[safety_check_messages],
            hf_model_name="meta-llama/LlamaGuard-7b",
        )

        if "unsafe" in response.choices[0].message.content:
            raise HTTPException(
                status_code=400, detail={"error": "Violated content safety policy"}
            )

        return data
@@ -56,6 +56,7 @@ aleph_alpha_key: Optional[str] = None
 nlp_cloud_key: Optional[str] = None
 use_client: bool = False
 llamaguard_model_name: Optional[str] = None
+llamaguard_unsafe_content_categories: Optional[str] = None
 logging: bool = True
 caching: bool = (
     False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
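The new llamaguard_unsafe_content_categories setting is a plain file path that _ENTERPRISE_LlamaGuard.__init__ reads once at construction time. A hedged sketch of using it (the path below is the default file added in this commit):

import litellm

# Optional: any readable text file of Llama Guard categories works here.
# When left as None, the hook skips set_custom_prompt_template and relies on
# the default Llama Guard prompt applied automatically in factory.py.
litellm.llamaguard_unsafe_content_categories = "litellm/proxy/llamaguard_prompt.txt"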
@@ -212,6 +212,15 @@ def completion(
                 final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                 messages=messages,
             )
+        elif hf_model_name in custom_prompt_dict:
+            # check if the base huggingface model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[hf_model_name]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details.get("roles", None),
+                initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                messages=messages,
+            )
         else:
             if hf_model_name is None:
                 if "llama-2" in model.lower():  # llama-2 model
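For illustration, a hedged sketch of what the new elif branch enables: a prompt template registered under the base Hugging Face model name is resolved even when the request names a different deployment. The deployment alias below is a placeholder, not part of this commit.

import litellm

# Register a template keyed by the *base* model name.
litellm.register_prompt_template(
    model="meta-llama/LlamaGuard-7b",
    initial_prompt_value="[INST] Task: Check the conversation below.\n",
    roles={"user": {"pre_message": "User: ", "post_message": "\n\n"}},
    final_prompt_value="[/INST]",
)

# The handler now also checks hf_model_name against custom_prompt_dict before
# falling back to the built-in llama-2 formatting, so this call picks up the
# template registered above.
response = litellm.completion(
    model="huggingface/my-llamaguard-endpoint",  # placeholder deployment
    messages=[{"role": "user", "content": "hello"}],
    hf_model_name="meta-llama/LlamaGuard-7b",
)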
46  litellm/proxy/llamaguard_prompt.txt  Normal file
@@ -0,0 +1,46 @@
O1: Violence and Hate.
Should not
- Help people plan or engage in violence.
- Encourage people to engage in violence.
- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
Can
- Provide information on violence and discrimination.
- Discuss the topics of hate and violence, and can discuss historical events involving violence.
O2: Sexual Content.
Should not
- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.
Can
- Discuss topics related to sexuality and sexual education.
- Provide general information regarding sexual activities.
- Provide guidance on topics related to sexual health.
O3: Criminal Planning.
Should not
- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.
- Encourage people to engage in criminal activities or condone criminal behavior.
Can
- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.
- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.
O4: Guns and Illegal Weapons.
Should not
- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.
- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.
Can
- Discuss firearms and the arguments for and against firearm ownership.
O5: Regulated or Controlled Substances.
Should not
- Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.
- Assist or encourage people to create such substances when it is illegal to do so.
Can
- Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).
- Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation.
O6: Self-Harm.
Should not
- Encourage people to harm themselves.
- Romanticize or condone self-harm or suicide.
- Provide information regarding the methods of suicide or self-harm.
- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.
Can
- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).
Should
- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves.
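This file is the default policy. Because the categories are injected verbatim between the <BEGIN UNSAFE CONTENT CATEGORIES> and <END UNSAFE CONTENT CATEGORIES> tags by set_custom_prompt_template, a user can point litellm.llamaguard_unsafe_content_categories at their own file instead. A hypothetical trimmed policy might look like:

O1: Violence and Hate.
Should not
- Help people plan or engage in violence.
Can
- Discuss historical events involving violence.
O2: Criminal Planning.
Should not
- Help people plan or engage in criminal activities like theft or kidnapping.

Keeping the O<N> numbering style lets the model reference the violated categories on the second line of its response.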
@@ -1381,7 +1381,7 @@ class ProxyConfig:
                     isinstance(callback, str)
                     and callback == "llamaguard_moderations"
                 ):
-                    from litellm.proxy.enterprise.hooks.llama_guard import (
+                    from litellm.proxy.enterprise.enterprise_hooks.llama_guard import (
                         _ENTERPRISE_LlamaGuard,
                     )
 
|
@ -34,6 +34,7 @@ from dataclasses import (
|
|||
try:
|
||||
# this works in python 3.8
|
||||
import pkg_resources
|
||||
|
||||
filename = pkg_resources.resource_filename(__name__, "llms/tokenizers")
|
||||
# try:
|
||||
# filename = str(
|
||||
|
@ -42,6 +43,7 @@ try:
|
|||
except:
|
||||
# this works in python 3.9+
|
||||
from importlib import resources
|
||||
|
||||
filename = str(
|
||||
resources.files(litellm).joinpath("llms/tokenizers") # for python 3.10
|
||||
) # for python 3.10+
|
||||
|
@@ -87,16 +89,11 @@ from .exceptions import (
     UnprocessableEntityError,
 )
 
 # Import Enterprise features
-project_path = abspath(join(dirname(__file__), "..", ".."))
-# Add the "enterprise" directory to sys.path
-verbose_logger.debug(f"current project_path: {project_path}")
-enterprise_path = abspath(join(project_path, "enterprise"))
-sys.path.append(enterprise_path)
-
-verbose_logger.debug(f"sys.path: {sys.path}")
 try:
-    from enterprise.callbacks.generic_api_callback import GenericAPILogger
+    from .proxy.enterprise.enterprise_callbacks.generic_api_callback import (
+        GenericAPILogger,
+    )
 except Exception as e:
     verbose_logger.debug(f"Exception import enterprise features {str(e)}")
 