feat(llama_guard.py): allow user to define custom unsafe content categories

Krrish Dholakia 2024-02-17 17:42:47 -08:00
parent e3fab50853
commit 074d93cc97
10 changed files with 187 additions and 80 deletions

@@ -212,6 +212,15 @@ def completion(
             final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
             messages=messages,
         )
+    elif hf_model_name in custom_prompt_dict:
+        # check if the base huggingface model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[hf_model_name]
+        prompt = custom_prompt(
+            role_dict=model_prompt_details.get("roles", None),
+            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+            messages=messages,
+        )
     else:
         if hf_model_name is None:
             if "llama-2" in model.lower():  # llama-2 model