mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-10 19:43:16 +00:00
Significantly upgrade the interactive configuration experience
This commit is contained in:
parent
8d157a8197
commit
5a7b01d292
7 changed files with 217 additions and 156 deletions
|
|
@ -113,8 +113,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
model: str,
|
||||
inference_api: Inference,
|
||||
excluded_categories: List[str] = None,
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
):
|
||||
super().__init__(on_violation_action)
|
||||
|
|
@ -132,8 +130,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
self.model = model
|
||||
self.inference_api = inference_api
|
||||
self.excluded_categories = excluded_categories
|
||||
self.disable_input_check = disable_input_check
|
||||
self.disable_output_check = disable_output_check
|
||||
|
||||
def check_unsafe_response(self, response: str) -> Optional[str]:
|
||||
match = re.match(r"^unsafe\n(.*)$", response)
|
||||
|
|
@ -180,12 +176,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
|
||||
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||
messages = self.validate_messages(messages)
|
||||
if self.disable_input_check and messages[-1].role == Role.user.value:
|
||||
return ShieldResponse(is_violation=False)
|
||||
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
|
||||
return ShieldResponse(
|
||||
is_violation=False,
|
||||
)
|
||||
|
||||
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
|
||||
shield_input_message = self.build_vision_shield_input(messages)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue