skip tests

This commit is contained in:
Swapna Lekkala 2025-08-12 16:21:15 -07:00
parent 8d0d6f3813
commit 356b4dff88
2 changed files with 8 additions and 21 deletions

View file

@ -8,6 +8,7 @@ import logging
from typing import Any
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import (
@ -24,7 +25,6 @@ from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from .config import PromptGuardConfig, PromptGuardType
@ -48,9 +48,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
async def register_shield(self, shield: Shield) -> None:
if shield.provider_resource_id != PROMPT_GUARD_MODEL:
raise ValueError(
f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. "
)
raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")
async def unregister_shield(self, identifier: str) -> None:
pass
@ -67,9 +65,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
return await self.shield.run(messages)
async def run_moderation(
self, input: str | list[str], model: str
) -> ModerationObject:
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
@ -81,9 +77,7 @@ class PromptGuardShield:
threshold: float = 0.9,
temperature: float = 1.0,
):
assert (
model_dir is not None
), "Must provide a model directory for prompt injection shield"
assert model_dir is not None, "Must provide a model directory for prompt injection shield"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
@ -97,9 +91,7 @@ class PromptGuardShield:
# load model and tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSequenceClassification.from_pretrained(
model_dir, device_map=self.device
)
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
async def run(self, messages: list[Message]) -> RunShieldResponse:
message = messages[-1]
@ -129,10 +121,7 @@ class PromptGuardShield:
"violation_type": f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
},
)
elif (
self.config.guard_type == PromptGuardType.jailbreak.value
and score_malicious > self.threshold
):
elif self.config.guard_type == PromptGuardType.jailbreak.value and score_malicious > self.threshold:
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
user_message="Sorry, I cannot do this.",