Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-28 15:02:37 +00:00)
Merge pull request #2 from meta-llama/revert-1-fix_llama_guard
Revert "Update llama guard file to latest version"
Commit 8cd2e4164c
1 changed file with 5 additions and 4 deletions
@@ -1,4 +1,3 @@
-import re
 
 from string import Template
 from typing import List, Optional
@@ -6,6 +5,7 @@ from typing import List, Optional
 import torch
 from llama_models.llama3_1.api.datatypes import Message
+from termcolor import cprint
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
 
@@ -177,10 +177,10 @@ class LlamaGuardShield(ShieldBase):
         categories = self.get_safety_categories()
         categories_str = "\n".join(categories)
         conversations_str = "\n\n".join(
-            [f"{m.role.name.capitalize()}: {m.content}" for m in messages]
+            [f"{m.role.capitalize()}: {m.content}" for m in messages]
         )
         return PROMPT_TEMPLATE.substitute(
-            agent_type=messages[-1].role.name.capitalize(),
+            agent_type=messages[-1].role.capitalize(),
             categories=categories_str,
             conversations=conversations_str,
         )
@@ -230,7 +230,7 @@ class LlamaGuardShield(ShieldBase):
         prompt_len = input_ids.shape[1]
         output = self.model.generate(
             input_ids=input_ids,
-            max_new_tokens=50,
+            max_new_tokens=20,
             output_scores=True,
             return_dict_in_generate=True,
             pad_token_id=0,
@@ -244,4 +244,5 @@ class LlamaGuardShield(ShieldBase):
         response = response.strip()
         shield_response = self.get_shield_response(response)
 
+        cprint(f"Final Llama Guard response {shield_response}", color="magenta")
         return shield_response
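Aside from the generation limit (max_new_tokens 50 vs. 20) and the debug cprint, the substantive change in this revert is how a message role is rendered when the Llama Guard prompt is built: m.role.name.capitalize() assumes the role is an enum member, while m.role.capitalize() assumes a plain string. The sketch below illustrates the two shapes; Role, EnumMessage, StrMessage, and PROMPT_TEMPLATE are hypothetical stand-ins, not the actual llama_models datatypes or the shield's real prompt template.

from dataclasses import dataclass
from enum import Enum
from string import Template

# Hypothetical stand-ins for the message datatypes the shield consumes; the
# real classes live in llama_models.llama3_1.api.datatypes and may differ.
class Role(Enum):
    user = "user"
    assistant = "assistant"

@dataclass
class EnumMessage:
    role: Role   # enum-valued role: rendering needs m.role.name.capitalize()
    content: str

@dataclass
class StrMessage:
    role: str    # plain-string role: m.role.capitalize() is enough
    content: str

# Placeholder template, not the shield's actual Llama Guard prompt.
PROMPT_TEMPLATE = Template(
    "Task: check the $agent_type turn.\n\n$categories\n\n$conversations"
)

def role_name(message) -> str:
    # Accept either shape: an Enum member exposes .name, a plain str capitalizes directly.
    role = message.role
    return (role.name if isinstance(role, Enum) else role).capitalize()

def build_prompt(messages, categories) -> str:
    # Same structure as the build_prompt shown in the diff: join the category
    # strings, render each message as "Role: content", substitute into the template.
    categories_str = "\n".join(categories)
    conversations_str = "\n\n".join(
        f"{role_name(m)}: {m.content}" for m in messages
    )
    return PROMPT_TEMPLATE.substitute(
        agent_type=role_name(messages[-1]),
        categories=categories_str,
        conversations=conversations_str,
    )

if __name__ == "__main__":
    print(build_prompt([StrMessage(role="user", content="hello")], ["S1: Violent Crimes."]))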
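For the second half of the diff, prompt_len is recorded so that only the tokens generated after the prompt are decoded before being handed to get_shield_response. A rough sketch of that pattern with Hugging Face transformers follows; run_guard and its arguments are illustrative names rather than the shield's actual method, and an already-loaded model/tokenizer pair is assumed.

import torch

def run_guard(model, tokenizer, prompt: str, max_new_tokens: int = 20) -> str:
    # model and tokenizer are assumed to be a loaded Hugging Face causal-LM pair.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    prompt_len = input_ids.shape[1]          # number of prompt tokens
    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,   # 20 after the revert, 50 before it
            output_scores=True,
            return_dict_in_generate=True,
            pad_token_id=0,
        )
    # Decode only the tokens produced after the prompt, then strip whitespace,
    # mirroring the prompt_len / response.strip() lines in the diff.
    generated = output.sequences[0][prompt_len:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()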