mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-15 01:26:10 +00:00
Fix precommit check after moving to ruff (#927)
Lint check in main branch is failing. This fixes the lint check after we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to a `ruff.toml` file as well as fixing and ignoring some additional checks. Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
parent
4773092dd1
commit
34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
|
@ -67,10 +67,6 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
|||
violation = SafetyViolation(
|
||||
violation_level=(ViolationLevel.ERROR),
|
||||
user_message="Sorry, I found security concerns in the code.",
|
||||
metadata={
|
||||
"violation_type": ",".join(
|
||||
[issue.pattern_id for issue in result.issues_found]
|
||||
)
|
||||
},
|
||||
metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
|
||||
)
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
|
|
@ -10,9 +10,7 @@ from .config import LlamaGuardConfig
|
|||
async def get_provider_impl(config: LlamaGuardConfig, deps):
|
||||
from .llama_guard import LlamaGuardSafetyImpl
|
||||
|
||||
assert isinstance(
|
||||
config, LlamaGuardConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = LlamaGuardSafetyImpl(config, deps)
|
||||
await impl.initialize()
|
||||
|
|
|
@ -102,8 +102,7 @@ LLAMA_GUARD_MODEL_IDS = {
|
|||
}
|
||||
|
||||
MODEL_TO_SAFETY_CATEGORIES_MAP = {
|
||||
"meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES
|
||||
+ [CAT_CODE_INTERPRETER_ABUSE],
|
||||
"meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE],
|
||||
"meta-llama/Llama-Guard-3-1B": DEFAULT_LG_V3_SAFETY_CATEGORIES,
|
||||
"meta-llama/Llama-Guard-3-11B-Vision": DEFAULT_LG_V3_SAFETY_CATEGORIES,
|
||||
}
|
||||
|
@ -133,9 +132,7 @@ Provide your safety assessment for ONLY THE LAST $agent_type message in the abov
|
|||
- If unsafe, a second line must include a comma-separated list of violated categories."""
|
||||
|
||||
|
||||
PROMPT_TEMPLATE = Template(
|
||||
f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}"
|
||||
)
|
||||
PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}")
|
||||
|
||||
|
||||
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||
|
@ -233,9 +230,7 @@ class LlamaGuardShield:
|
|||
if messages[0].role != Role.user.value:
|
||||
raise ValueError("Messages must start with user")
|
||||
|
||||
if len(messages) >= 2 and (
|
||||
messages[0].role == Role.user.value and messages[1].role == Role.user.value
|
||||
):
|
||||
if len(messages) >= 2 and (messages[0].role == Role.user.value and messages[1].role == Role.user.value):
|
||||
messages = messages[1:]
|
||||
|
||||
for i in range(1, len(messages)):
|
||||
|
@ -263,10 +258,7 @@ class LlamaGuardShield:
|
|||
stream=True,
|
||||
):
|
||||
event = chunk.event
|
||||
if (
|
||||
event.event_type == ChatCompletionResponseEventType.progress
|
||||
and event.delta.type == "text"
|
||||
):
|
||||
if event.event_type == ChatCompletionResponseEventType.progress and event.delta.type == "text":
|
||||
content += event.delta.text
|
||||
|
||||
content = content.strip()
|
||||
|
@ -313,10 +305,7 @@ class LlamaGuardShield:
|
|||
categories = self.get_safety_categories()
|
||||
categories_str = "\n".join(categories)
|
||||
conversations_str = "\n\n".join(
|
||||
[
|
||||
f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}"
|
||||
for m in messages
|
||||
]
|
||||
[f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}" for m in messages]
|
||||
)
|
||||
return PROMPT_TEMPLATE.substitute(
|
||||
agent_type=messages[-1].role.capitalize(),
|
||||
|
|
|
@ -46,9 +46,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
if shield.provider_resource_id != PROMPT_GUARD_MODEL:
|
||||
raise ValueError(
|
||||
f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. "
|
||||
)
|
||||
raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
|
@ -71,9 +69,7 @@ class PromptGuardShield:
|
|||
threshold: float = 0.9,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
assert (
|
||||
model_dir is not None
|
||||
), "Must provide a model directory for prompt injection shield"
|
||||
assert model_dir is not None, "Must provide a model directory for prompt injection shield"
|
||||
if temperature <= 0:
|
||||
raise ValueError("Temperature must be greater than 0")
|
||||
|
||||
|
@ -85,9 +81,7 @@ class PromptGuardShield:
|
|||
|
||||
# load model and tokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_dir, device_map=self.device
|
||||
)
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
|
||||
|
||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||
message = messages[-1]
|
||||
|
@ -117,10 +111,7 @@ class PromptGuardShield:
|
|||
"violation_type": f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
|
||||
},
|
||||
)
|
||||
elif (
|
||||
self.config.guard_type == PromptGuardType.jailbreak.value
|
||||
and score_malicious > self.threshold
|
||||
):
|
||||
elif self.config.guard_type == PromptGuardType.jailbreak.value and score_malicious > self.threshold:
|
||||
violation = SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
violation_type=f"prompt_injection:malicious={score_malicious}",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue