forked from phoenix-oss/llama-stack-mirror
		
	Add provider deprecation support; change directory structure (#397)
* Add provider deprecation support; change directory structure * fix a couple dangling imports * move the meta_reference safety dir also
This commit is contained in:
		
							parent
							
								
									36e2538eb0
								
							
						
					
					
						commit
						694c142b89
					
				
					 58 changed files with 61 additions and 120 deletions
				
			
		|  | @ -0,0 +1,17 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| from .config import LlamaGuardShieldConfig, SafetyConfig  # noqa: F401 | ||||
| 
 | ||||
| 
 | ||||
| async def get_provider_impl(config: SafetyConfig, deps): | ||||
|     from .safety import MetaReferenceSafetyImpl | ||||
| 
 | ||||
|     assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}" | ||||
| 
 | ||||
|     impl = MetaReferenceSafetyImpl(config, deps) | ||||
|     await impl.initialize() | ||||
|     return impl | ||||
							
								
								
									
										57
									
								
								llama_stack/providers/inline/safety/meta_reference/base.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								llama_stack/providers/inline/safety/meta_reference/base.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,57 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| from abc import ABC, abstractmethod | ||||
| from typing import List | ||||
| 
 | ||||
| from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message | ||||
| from pydantic import BaseModel | ||||
| from llama_stack.apis.safety import *  # noqa: F403 | ||||
| 
 | ||||
| CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" | ||||
| 
 | ||||
| 
 | ||||
| # TODO: clean this up; just remove this type completely | ||||
| class ShieldResponse(BaseModel): | ||||
|     is_violation: bool | ||||
|     violation_type: Optional[str] = None | ||||
|     violation_return_message: Optional[str] = None | ||||
| 
 | ||||
| 
 | ||||
| # TODO: this is a caller / agent concern | ||||
| class OnViolationAction(Enum): | ||||
|     IGNORE = 0 | ||||
|     WARN = 1 | ||||
|     RAISE = 2 | ||||
| 
 | ||||
| 
 | ||||
| class ShieldBase(ABC): | ||||
|     def __init__( | ||||
|         self, | ||||
|         on_violation_action: OnViolationAction = OnViolationAction.RAISE, | ||||
|     ): | ||||
|         self.on_violation_action = on_violation_action | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     async def run(self, messages: List[Message]) -> ShieldResponse: | ||||
|         raise NotImplementedError() | ||||
| 
 | ||||
| 
 | ||||
| def message_content_as_str(message: Message) -> str: | ||||
|     return interleaved_text_media_as_str(message.content) | ||||
| 
 | ||||
| 
 | ||||
| class TextShield(ShieldBase): | ||||
|     def convert_messages_to_text(self, messages: List[Message]) -> str: | ||||
|         return "\n".join([message_content_as_str(m) for m in messages]) | ||||
| 
 | ||||
|     async def run(self, messages: List[Message]) -> ShieldResponse: | ||||
|         text = self.convert_messages_to_text(messages) | ||||
|         return await self.run_impl(text) | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     async def run_impl(self, text: str) -> ShieldResponse: | ||||
|         raise NotImplementedError() | ||||
							
								
								
									
										48
									
								
								llama_stack/providers/inline/safety/meta_reference/config.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								llama_stack/providers/inline/safety/meta_reference/config.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,48 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| from enum import Enum | ||||
| from typing import List, Optional | ||||
| 
 | ||||
| from llama_models.sku_list import CoreModelId, safety_models | ||||
| 
 | ||||
| from pydantic import BaseModel, field_validator | ||||
| 
 | ||||
| 
 | ||||
| class PromptGuardType(Enum): | ||||
|     injection = "injection" | ||||
|     jailbreak = "jailbreak" | ||||
| 
 | ||||
| 
 | ||||
| class LlamaGuardShieldConfig(BaseModel): | ||||
|     model: str = "Llama-Guard-3-1B" | ||||
|     excluded_categories: List[str] = [] | ||||
| 
 | ||||
|     @field_validator("model") | ||||
|     @classmethod | ||||
|     def validate_model(cls, model: str) -> str: | ||||
|         permitted_models = [ | ||||
|             m.descriptor() | ||||
|             for m in safety_models() | ||||
|             if ( | ||||
|                 m.core_model_id | ||||
|                 in { | ||||
|                     CoreModelId.llama_guard_3_8b, | ||||
|                     CoreModelId.llama_guard_3_1b, | ||||
|                     CoreModelId.llama_guard_3_11b_vision, | ||||
|                 } | ||||
|             ) | ||||
|         ] | ||||
|         if model not in permitted_models: | ||||
|             raise ValueError( | ||||
|                 f"Invalid model: {model}. Must be one of {permitted_models}" | ||||
|             ) | ||||
|         return model | ||||
| 
 | ||||
| 
 | ||||
| class SafetyConfig(BaseModel): | ||||
|     llama_guard_shield: Optional[LlamaGuardShieldConfig] = None | ||||
|     enable_prompt_guard: Optional[bool] = False | ||||
|  | @ -0,0 +1,268 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| import re | ||||
| 
 | ||||
| from string import Template | ||||
| from typing import List, Optional | ||||
| 
 | ||||
| from llama_models.llama3.api.datatypes import *  # noqa: F403 | ||||
| from llama_stack.apis.inference import *  # noqa: F403 | ||||
| 
 | ||||
| from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse | ||||
| 
 | ||||
| 
 | ||||
| SAFE_RESPONSE = "safe" | ||||
| _INSTANCE = None | ||||
| 
 | ||||
| CAT_VIOLENT_CRIMES = "Violent Crimes" | ||||
| CAT_NON_VIOLENT_CRIMES = "Non-Violent Crimes" | ||||
| CAT_SEX_CRIMES = "Sex Crimes" | ||||
| CAT_CHILD_EXPLOITATION = "Child Exploitation" | ||||
| CAT_DEFAMATION = "Defamation" | ||||
| CAT_SPECIALIZED_ADVICE = "Specialized Advice" | ||||
| CAT_PRIVACY = "Privacy" | ||||
| CAT_INTELLECTUAL_PROPERTY = "Intellectual Property" | ||||
| CAT_INDISCRIMINATE_WEAPONS = "Indiscriminate Weapons" | ||||
| CAT_HATE = "Hate" | ||||
| CAT_SELF_HARM = "Self-Harm" | ||||
| CAT_SEXUAL_CONTENT = "Sexual Content" | ||||
| CAT_ELECTIONS = "Elections" | ||||
| CAT_CODE_INTERPRETER_ABUSE = "Code Interpreter Abuse" | ||||
| 
 | ||||
| 
 | ||||
| SAFETY_CATEGORIES_TO_CODE_MAP = { | ||||
|     CAT_VIOLENT_CRIMES: "S1", | ||||
|     CAT_NON_VIOLENT_CRIMES: "S2", | ||||
|     CAT_SEX_CRIMES: "S3", | ||||
|     CAT_CHILD_EXPLOITATION: "S4", | ||||
|     CAT_DEFAMATION: "S5", | ||||
|     CAT_SPECIALIZED_ADVICE: "S6", | ||||
|     CAT_PRIVACY: "S7", | ||||
|     CAT_INTELLECTUAL_PROPERTY: "S8", | ||||
|     CAT_INDISCRIMINATE_WEAPONS: "S9", | ||||
|     CAT_HATE: "S10", | ||||
|     CAT_SELF_HARM: "S11", | ||||
|     CAT_SEXUAL_CONTENT: "S12", | ||||
|     CAT_ELECTIONS: "S13", | ||||
|     CAT_CODE_INTERPRETER_ABUSE: "S14", | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| DEFAULT_LG_V3_SAFETY_CATEGORIES = [ | ||||
|     CAT_VIOLENT_CRIMES, | ||||
|     CAT_NON_VIOLENT_CRIMES, | ||||
|     CAT_SEX_CRIMES, | ||||
|     CAT_CHILD_EXPLOITATION, | ||||
|     CAT_DEFAMATION, | ||||
|     CAT_SPECIALIZED_ADVICE, | ||||
|     CAT_PRIVACY, | ||||
|     CAT_INTELLECTUAL_PROPERTY, | ||||
|     CAT_INDISCRIMINATE_WEAPONS, | ||||
|     CAT_HATE, | ||||
|     CAT_SELF_HARM, | ||||
|     CAT_SEXUAL_CONTENT, | ||||
|     CAT_ELECTIONS, | ||||
| ] | ||||
| 
 | ||||
| 
 | ||||
| MODEL_TO_SAFETY_CATEGORIES_MAP = { | ||||
|     CoreModelId.llama_guard_3_8b.value: ( | ||||
|         DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE] | ||||
|     ), | ||||
|     CoreModelId.llama_guard_3_1b.value: DEFAULT_LG_V3_SAFETY_CATEGORIES, | ||||
|     CoreModelId.llama_guard_3_11b_vision.value: DEFAULT_LG_V3_SAFETY_CATEGORIES, | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories." | ||||
| 
 | ||||
| SAFETY_CATEGORIES = """ | ||||
| 
 | ||||
| <BEGIN UNSAFE CONTENT CATEGORIES> | ||||
| $categories | ||||
| <END UNSAFE CONTENT CATEGORIES> | ||||
| """ | ||||
| 
 | ||||
| PROMPT_CONVERSATION = """ | ||||
| <BEGIN CONVERSATION> | ||||
| 
 | ||||
| $conversations | ||||
| 
 | ||||
| <END CONVERSATION> | ||||
| """ | ||||
| 
 | ||||
| 
 | ||||
| PROMPT_INSTRUCTIONS = """ | ||||
| Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation: | ||||
|  - First line must read 'safe' or 'unsafe'. | ||||
|  - If unsafe, a second line must include a comma-separated list of violated categories.""" | ||||
| 
 | ||||
| 
 | ||||
| PROMPT_TEMPLATE = Template( | ||||
|     f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}" | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| class LlamaGuardShield(ShieldBase): | ||||
|     def __init__( | ||||
|         self, | ||||
|         model: str, | ||||
|         inference_api: Inference, | ||||
|         excluded_categories: List[str] = None, | ||||
|         on_violation_action: OnViolationAction = OnViolationAction.RAISE, | ||||
|     ): | ||||
|         super().__init__(on_violation_action) | ||||
| 
 | ||||
|         if excluded_categories is None: | ||||
|             excluded_categories = [] | ||||
| 
 | ||||
|         assert len(excluded_categories) == 0 or all( | ||||
|             x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories | ||||
|         ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]" | ||||
| 
 | ||||
|         if model not in MODEL_TO_SAFETY_CATEGORIES_MAP: | ||||
|             raise ValueError(f"Unsupported model: {model}") | ||||
| 
 | ||||
|         self.model = model | ||||
|         self.inference_api = inference_api | ||||
|         self.excluded_categories = excluded_categories | ||||
| 
 | ||||
|     def check_unsafe_response(self, response: str) -> Optional[str]: | ||||
|         match = re.match(r"^unsafe\n(.*)$", response) | ||||
|         if match: | ||||
|             # extracts the unsafe code | ||||
|             extracted = match.group(1) | ||||
|             return extracted | ||||
| 
 | ||||
|         return None | ||||
| 
 | ||||
|     def get_safety_categories(self) -> List[str]: | ||||
|         excluded_categories = self.excluded_categories | ||||
|         if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()): | ||||
|             excluded_categories = [] | ||||
| 
 | ||||
|         final_categories = [] | ||||
| 
 | ||||
|         all_categories = MODEL_TO_SAFETY_CATEGORIES_MAP[self.model] | ||||
|         for cat in all_categories: | ||||
|             cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat] | ||||
|             if cat_code in excluded_categories: | ||||
|                 continue | ||||
|             final_categories.append(f"{cat_code}: {cat}.") | ||||
| 
 | ||||
|         return final_categories | ||||
| 
 | ||||
|     def validate_messages(self, messages: List[Message]) -> None: | ||||
|         if len(messages) == 0: | ||||
|             raise ValueError("Messages must not be empty") | ||||
|         if messages[0].role != Role.user.value: | ||||
|             raise ValueError("Messages must start with user") | ||||
| 
 | ||||
|         if len(messages) >= 2 and ( | ||||
|             messages[0].role == Role.user.value and messages[1].role == Role.user.value | ||||
|         ): | ||||
|             messages = messages[1:] | ||||
| 
 | ||||
|         for i in range(1, len(messages)): | ||||
|             if messages[i].role == messages[i - 1].role: | ||||
|                 raise ValueError( | ||||
|                     f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}" | ||||
|                 ) | ||||
|         return messages | ||||
| 
 | ||||
|     async def run(self, messages: List[Message]) -> ShieldResponse: | ||||
|         messages = self.validate_messages(messages) | ||||
| 
 | ||||
|         if self.model == CoreModelId.llama_guard_3_11b_vision.value: | ||||
|             shield_input_message = self.build_vision_shield_input(messages) | ||||
|         else: | ||||
|             shield_input_message = self.build_text_shield_input(messages) | ||||
| 
 | ||||
|         # TODO: llama-stack inference protocol has issues with non-streaming inference code | ||||
|         content = "" | ||||
|         async for chunk in await self.inference_api.chat_completion( | ||||
|             model=self.model, | ||||
|             messages=[shield_input_message], | ||||
|             stream=True, | ||||
|         ): | ||||
|             event = chunk.event | ||||
|             if event.event_type == ChatCompletionResponseEventType.progress: | ||||
|                 assert isinstance(event.delta, str) | ||||
|                 content += event.delta | ||||
| 
 | ||||
|         content = content.strip() | ||||
|         shield_response = self.get_shield_response(content) | ||||
|         return shield_response | ||||
| 
 | ||||
|     def build_text_shield_input(self, messages: List[Message]) -> UserMessage: | ||||
|         return UserMessage(content=self.build_prompt(messages)) | ||||
| 
 | ||||
|     def build_vision_shield_input(self, messages: List[Message]) -> UserMessage: | ||||
|         conversation = [] | ||||
|         most_recent_img = None | ||||
| 
 | ||||
|         for m in messages[::-1]: | ||||
|             if isinstance(m.content, str): | ||||
|                 conversation.append(m) | ||||
|             elif isinstance(m.content, ImageMedia): | ||||
|                 if most_recent_img is None and m.role == Role.user.value: | ||||
|                     most_recent_img = m.content | ||||
|                     conversation.append(m) | ||||
|             elif isinstance(m.content, list): | ||||
|                 content = [] | ||||
|                 for c in m.content: | ||||
|                     if isinstance(c, str): | ||||
|                         content.append(c) | ||||
|                     elif isinstance(c, ImageMedia): | ||||
|                         if most_recent_img is None and m.role == Role.user.value: | ||||
|                             most_recent_img = c | ||||
|                             content.append(c) | ||||
|                     else: | ||||
|                         raise ValueError(f"Unknown content type: {c}") | ||||
| 
 | ||||
|                 conversation.append(UserMessage(content=content)) | ||||
|             else: | ||||
|                 raise ValueError(f"Unknown content type: {m.content}") | ||||
| 
 | ||||
|         prompt = [] | ||||
|         if most_recent_img is not None: | ||||
|             prompt.append(most_recent_img) | ||||
|         prompt.append(self.build_prompt(conversation[::-1])) | ||||
| 
 | ||||
|         return UserMessage(content=prompt) | ||||
| 
 | ||||
|     def build_prompt(self, messages: List[Message]) -> str: | ||||
|         categories = self.get_safety_categories() | ||||
|         categories_str = "\n".join(categories) | ||||
|         conversations_str = "\n\n".join( | ||||
|             [ | ||||
|                 f"{m.role.capitalize()}: {interleaved_text_media_as_str(m.content)}" | ||||
|                 for m in messages | ||||
|             ] | ||||
|         ) | ||||
|         return PROMPT_TEMPLATE.substitute( | ||||
|             agent_type=messages[-1].role.capitalize(), | ||||
|             categories=categories_str, | ||||
|             conversations=conversations_str, | ||||
|         ) | ||||
| 
 | ||||
|     def get_shield_response(self, response: str) -> ShieldResponse: | ||||
|         response = response.strip() | ||||
|         if response == SAFE_RESPONSE: | ||||
|             return ShieldResponse(is_violation=False) | ||||
|         unsafe_code = self.check_unsafe_response(response) | ||||
|         if unsafe_code: | ||||
|             unsafe_code_list = unsafe_code.split(",") | ||||
|             if set(unsafe_code_list).issubset(set(self.excluded_categories)): | ||||
|                 return ShieldResponse(is_violation=False) | ||||
|             return ShieldResponse( | ||||
|                 is_violation=True, | ||||
|                 violation_type=unsafe_code, | ||||
|                 violation_return_message=CANNED_RESPONSE_TEXT, | ||||
|             ) | ||||
| 
 | ||||
|         raise ValueError(f"Unexpected response: {response}") | ||||
|  | @ -0,0 +1,145 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| from enum import auto, Enum | ||||
| from typing import List | ||||
| 
 | ||||
| import torch | ||||
| 
 | ||||
| from llama_models.llama3.api.datatypes import Message | ||||
| from termcolor import cprint | ||||
| 
 | ||||
| from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield | ||||
| 
 | ||||
| 
 | ||||
| class PromptGuardShield(TextShield): | ||||
|     class Mode(Enum): | ||||
|         INJECTION = auto() | ||||
|         JAILBREAK = auto() | ||||
| 
 | ||||
|     _instances = {} | ||||
|     _model_cache = None | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def instance( | ||||
|         model_dir: str, | ||||
|         threshold: float = 0.9, | ||||
|         temperature: float = 1.0, | ||||
|         mode: "PromptGuardShield.Mode" = Mode.JAILBREAK, | ||||
|         on_violation_action=OnViolationAction.RAISE, | ||||
|     ) -> "PromptGuardShield": | ||||
|         action_value = on_violation_action.value | ||||
|         key = (model_dir, threshold, temperature, mode, action_value) | ||||
|         if key not in PromptGuardShield._instances: | ||||
|             PromptGuardShield._instances[key] = PromptGuardShield( | ||||
|                 model_dir=model_dir, | ||||
|                 threshold=threshold, | ||||
|                 temperature=temperature, | ||||
|                 mode=mode, | ||||
|                 on_violation_action=on_violation_action, | ||||
|             ) | ||||
|         return PromptGuardShield._instances[key] | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         model_dir: str, | ||||
|         threshold: float = 0.9, | ||||
|         temperature: float = 1.0, | ||||
|         mode: "PromptGuardShield.Mode" = Mode.JAILBREAK, | ||||
|         on_violation_action: OnViolationAction = OnViolationAction.RAISE, | ||||
|     ): | ||||
|         super().__init__(on_violation_action) | ||||
|         assert ( | ||||
|             model_dir is not None | ||||
|         ), "Must provide a model directory for prompt injection shield" | ||||
|         if temperature <= 0: | ||||
|             raise ValueError("Temperature must be greater than 0") | ||||
|         self.device = "cuda" | ||||
|         if PromptGuardShield._model_cache is None: | ||||
|             from transformers import AutoModelForSequenceClassification, AutoTokenizer | ||||
| 
 | ||||
|             # load model and tokenizer | ||||
|             tokenizer = AutoTokenizer.from_pretrained(model_dir) | ||||
|             model = AutoModelForSequenceClassification.from_pretrained( | ||||
|                 model_dir, device_map=self.device | ||||
|             ) | ||||
|             PromptGuardShield._model_cache = (tokenizer, model) | ||||
| 
 | ||||
|         self.tokenizer, self.model = PromptGuardShield._model_cache | ||||
|         self.temperature = temperature | ||||
|         self.threshold = threshold | ||||
|         self.mode = mode | ||||
| 
 | ||||
|     def convert_messages_to_text(self, messages: List[Message]) -> str: | ||||
|         return message_content_as_str(messages[-1]) | ||||
| 
 | ||||
|     async def run_impl(self, text: str) -> ShieldResponse: | ||||
|         # run model on messages and return response | ||||
|         inputs = self.tokenizer(text, return_tensors="pt") | ||||
|         inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()} | ||||
|         with torch.no_grad(): | ||||
|             outputs = self.model(**inputs) | ||||
|         logits = outputs[0] | ||||
|         probabilities = torch.softmax(logits / self.temperature, dim=-1) | ||||
|         score_embedded = probabilities[0, 1].item() | ||||
|         score_malicious = probabilities[0, 2].item() | ||||
|         cprint( | ||||
|             f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}", | ||||
|             color="magenta", | ||||
|         ) | ||||
| 
 | ||||
|         if self.mode == self.Mode.INJECTION and ( | ||||
|             score_embedded + score_malicious > self.threshold | ||||
|         ): | ||||
|             return ShieldResponse( | ||||
|                 is_violation=True, | ||||
|                 violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}", | ||||
|                 violation_return_message="Sorry, I cannot do this.", | ||||
|             ) | ||||
|         elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold: | ||||
|             return ShieldResponse( | ||||
|                 is_violation=True, | ||||
|                 violation_type=f"prompt_injection:malicious={score_malicious}", | ||||
|                 violation_return_message="Sorry, I cannot do this.", | ||||
|             ) | ||||
| 
 | ||||
|         return ShieldResponse( | ||||
|             is_violation=False, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class JailbreakShield(PromptGuardShield): | ||||
|     def __init__( | ||||
|         self, | ||||
|         model_dir: str, | ||||
|         threshold: float = 0.9, | ||||
|         temperature: float = 1.0, | ||||
|         on_violation_action: OnViolationAction = OnViolationAction.RAISE, | ||||
|     ): | ||||
|         super().__init__( | ||||
|             model_dir=model_dir, | ||||
|             threshold=threshold, | ||||
|             temperature=temperature, | ||||
|             mode=PromptGuardShield.Mode.JAILBREAK, | ||||
|             on_violation_action=on_violation_action, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class InjectionShield(PromptGuardShield): | ||||
|     def __init__( | ||||
|         self, | ||||
|         model_dir: str, | ||||
|         threshold: float = 0.9, | ||||
|         temperature: float = 1.0, | ||||
|         on_violation_action: OnViolationAction = OnViolationAction.RAISE, | ||||
|     ): | ||||
|         super().__init__( | ||||
|             model_dir=model_dir, | ||||
|             threshold=threshold, | ||||
|             temperature=temperature, | ||||
|             mode=PromptGuardShield.Mode.INJECTION, | ||||
|             on_violation_action=on_violation_action, | ||||
|         ) | ||||
							
								
								
									
										112
									
								
								llama_stack/providers/inline/safety/meta_reference/safety.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										112
									
								
								llama_stack/providers/inline/safety/meta_reference/safety.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,112 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| from typing import Any, Dict, List | ||||
| 
 | ||||
| from llama_stack.distribution.utils.model_utils import model_local_dir | ||||
| from llama_stack.apis.inference import *  # noqa: F403 | ||||
| from llama_stack.apis.safety import *  # noqa: F403 | ||||
| from llama_models.llama3.api.datatypes import *  # noqa: F403 | ||||
| from llama_stack.distribution.datatypes import Api | ||||
| 
 | ||||
| from llama_stack.providers.datatypes import ShieldsProtocolPrivate | ||||
| 
 | ||||
| from .base import OnViolationAction, ShieldBase | ||||
| from .config import SafetyConfig | ||||
| from .llama_guard import LlamaGuardShield | ||||
| from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield | ||||
| 
 | ||||
| 
 | ||||
| PROMPT_GUARD_MODEL = "Prompt-Guard-86M" | ||||
| 
 | ||||
| 
 | ||||
| class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate): | ||||
|     def __init__(self, config: SafetyConfig, deps) -> None: | ||||
|         self.config = config | ||||
|         self.inference_api = deps[Api.inference] | ||||
| 
 | ||||
|         self.available_shields = [] | ||||
|         if config.llama_guard_shield: | ||||
|             self.available_shields.append(ShieldType.llama_guard.value) | ||||
|         if config.enable_prompt_guard: | ||||
|             self.available_shields.append(ShieldType.prompt_guard.value) | ||||
| 
 | ||||
|     async def initialize(self) -> None: | ||||
|         if self.config.enable_prompt_guard: | ||||
|             model_dir = model_local_dir(PROMPT_GUARD_MODEL) | ||||
|             _ = PromptGuardShield.instance(model_dir) | ||||
| 
 | ||||
|     async def shutdown(self) -> None: | ||||
|         pass | ||||
| 
 | ||||
|     async def register_shield(self, shield: ShieldDef) -> None: | ||||
|         raise ValueError("Registering dynamic shields is not supported") | ||||
| 
 | ||||
|     async def list_shields(self) -> List[ShieldDef]: | ||||
|         return [ | ||||
|             ShieldDef( | ||||
|                 identifier=shield_type, | ||||
|                 shield_type=shield_type, | ||||
|                 params={}, | ||||
|             ) | ||||
|             for shield_type in self.available_shields | ||||
|         ] | ||||
| 
 | ||||
|     async def run_shield( | ||||
|         self, | ||||
|         identifier: str, | ||||
|         messages: List[Message], | ||||
|         params: Dict[str, Any] = None, | ||||
|     ) -> RunShieldResponse: | ||||
|         shield_def = await self.shield_store.get_shield(identifier) | ||||
|         if not shield_def: | ||||
|             raise ValueError(f"Unknown shield {identifier}") | ||||
| 
 | ||||
|         shield = self.get_shield_impl(shield_def) | ||||
| 
 | ||||
|         messages = messages.copy() | ||||
|         # some shields like llama-guard require the first message to be a user message | ||||
|         # since this might be a tool call, first role might not be user | ||||
|         if len(messages) > 0 and messages[0].role != Role.user.value: | ||||
|             messages[0] = UserMessage(content=messages[0].content) | ||||
| 
 | ||||
|         # TODO: we can refactor ShieldBase, etc. to be inline with the API types | ||||
|         res = await shield.run(messages) | ||||
|         violation = None | ||||
|         if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE: | ||||
|             violation = SafetyViolation( | ||||
|                 violation_level=( | ||||
|                     ViolationLevel.ERROR | ||||
|                     if shield.on_violation_action == OnViolationAction.RAISE | ||||
|                     else ViolationLevel.WARN | ||||
|                 ), | ||||
|                 user_message=res.violation_return_message, | ||||
|                 metadata={ | ||||
|                     "violation_type": res.violation_type, | ||||
|                 }, | ||||
|             ) | ||||
| 
 | ||||
|         return RunShieldResponse(violation=violation) | ||||
| 
 | ||||
|     def get_shield_impl(self, shield: ShieldDef) -> ShieldBase: | ||||
|         if shield.shield_type == ShieldType.llama_guard.value: | ||||
|             cfg = self.config.llama_guard_shield | ||||
|             return LlamaGuardShield( | ||||
|                 model=cfg.model, | ||||
|                 inference_api=self.inference_api, | ||||
|                 excluded_categories=cfg.excluded_categories, | ||||
|             ) | ||||
|         elif shield.shield_type == ShieldType.prompt_guard.value: | ||||
|             model_dir = model_local_dir(PROMPT_GUARD_MODEL) | ||||
|             subtype = shield.params.get("prompt_guard_type", "injection") | ||||
|             if subtype == "injection": | ||||
|                 return InjectionShield.instance(model_dir) | ||||
|             elif subtype == "jailbreak": | ||||
|                 return JailbreakShield.instance(model_dir) | ||||
|             else: | ||||
|                 raise ValueError(f"Unknown prompt guard type: {subtype}") | ||||
|         else: | ||||
|             raise ValueError(f"Unknown shield type: {shield.shield_type}") | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue