chore: add mypy llama guard

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
Mustafa Elbehery 2025-07-09 01:49:55 +02:00
parent 1d8c00635c
commit 584d618e27
3 changed files with 46 additions and 30 deletions

View file

@ -10,6 +10,7 @@ from typing import Any
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponse,
Inference, Inference,
Message, Message,
UserMessage, UserMessage,
@ -18,6 +19,7 @@ from llama_stack.apis.safety import (
RunShieldResponse, RunShieldResponse,
Safety, Safety,
SafetyViolation, SafetyViolation,
ShieldStore,
ViolationLevel, ViolationLevel,
) )
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
@ -135,9 +137,10 @@ PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATIO
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
def __init__(self, config: LlamaGuardConfig, deps) -> None: def __init__(self, config: LlamaGuardConfig, deps: dict[str, Any]) -> None:
self.config = config self.config = config
self.inference_api = deps[Api.inference] self.inference_api: Inference = deps[Api.inference.value]
self.shield_store: ShieldStore = deps["shield_store"]
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
@ -154,7 +157,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
self, self,
shield_id: str, shield_id: str,
messages: list[Message], messages: list[Message],
params: dict[str, Any] = None, params: dict[str, Any] | None = None,
) -> RunShieldResponse: ) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id) shield = await self.shield_store.get_shield(shield_id)
if not shield: if not shield:
@ -166,6 +169,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
if len(messages) > 0 and messages[0].role != Role.user.value: if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content) messages[0] = UserMessage(content=messages[0].content)
<<<<<<< HEAD
# Use the inference API's model resolution instead of hardcoded mappings # Use the inference API's model resolution instead of hardcoded mappings
# This allows the shield to work with any registered model # This allows the shield to work with any registered model
model_id = shield.provider_resource_id model_id = shield.provider_resource_id
@ -180,6 +184,13 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
# For unknown models, use default Llama Guard 3 8B categories # For unknown models, use default Llama Guard 3 8B categories
safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE] safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
=======
provider_resource_id = shield.provider_resource_id
if provider_resource_id is None:
raise ValueError("Shield provider_resource_id is required")
model = LLAMA_GUARD_MODEL_IDS[provider_resource_id]
>>>>>>> 26efb98b (chore: add mypy llama guard)
impl = LlamaGuardShield( impl = LlamaGuardShield(
model=model_id, model=model_id,
inference_api=self.inference_api, inference_api=self.inference_api,
@ -237,7 +248,7 @@ class LlamaGuardShield:
return final_categories return final_categories
def validate_messages(self, messages: list[Message]) -> None: def validate_messages(self, messages: list[Message]) -> list[Message]:
if len(messages) == 0: if len(messages) == 0:
raise ValueError("Messages must not be empty") raise ValueError("Messages must not be empty")
if messages[0].role != Role.user.value: if messages[0].role != Role.user.value:
@ -262,9 +273,23 @@ class LlamaGuardShield:
messages=[shield_input_message], messages=[shield_input_message],
stream=False, stream=False,
) )
# Handle streaming response
if isinstance(response, ChatCompletionResponse):
content = response.completion_message.content content = response.completion_message.content
content = content.strip() else:
return self.get_shield_response(content) # Handle case where response is a stream
raise ValueError("Streaming response not supported for shield validation")
# Handle different content types
if isinstance(content, str):
content_str = content.strip()
elif hasattr(content, "strip"):
content_str = content.strip()
else:
content_str = str(content).strip()
return self.get_shield_response(content_str)
def build_text_shield_input(self, messages: list[Message]) -> UserMessage: def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
return UserMessage(content=self.build_prompt(messages)) return UserMessage(content=self.build_prompt(messages))
@ -281,10 +306,10 @@ class LlamaGuardShield:
most_recent_img = m.content most_recent_img = m.content
conversation.append(m) conversation.append(m)
elif isinstance(m.content, list): elif isinstance(m.content, list):
content = [] content: list[TextContentItem | ImageContentItem] = []
for c in m.content: for c in m.content:
if isinstance(c, str) or isinstance(c, TextContentItem): if isinstance(c, str) or isinstance(c, TextContentItem):
content.append(c) content.append(TextContentItem(text=c if isinstance(c, str) else c.text))
elif isinstance(c, ImageContentItem): elif isinstance(c, ImageContentItem):
if most_recent_img is None and m.role == Role.user.value: if most_recent_img is None and m.role == Role.user.value:
most_recent_img = c most_recent_img = c
@ -296,10 +321,10 @@ class LlamaGuardShield:
else: else:
raise ValueError(f"Unknown content type: {m.content}") raise ValueError(f"Unknown content type: {m.content}")
prompt = [] prompt: list[ImageContentItem | TextContentItem] = []
if most_recent_img is not None: if most_recent_img is not None:
prompt.append(most_recent_img) prompt.append(most_recent_img)
prompt.append(self.build_prompt(conversation[::-1])) prompt.append(TextContentItem(text=self.build_prompt(conversation[::-1])))
return UserMessage(content=prompt) return UserMessage(content=prompt)

View file

@ -257,7 +257,6 @@ exclude = [
"^llama_stack/providers/inline/inference/vllm/", "^llama_stack/providers/inline/inference/vllm/",
"^llama_stack/providers/inline/post_training/common/validator\\.py$", "^llama_stack/providers/inline/post_training/common/validator\\.py$",
"^llama_stack/providers/inline/safety/code_scanner/", "^llama_stack/providers/inline/safety/code_scanner/",
"^llama_stack/providers/inline/safety/llama_guard/",
"^llama_stack/providers/inline/safety/prompt_guard/", "^llama_stack/providers/inline/safety/prompt_guard/",
"^llama_stack/providers/inline/scoring/basic/", "^llama_stack/providers/inline/scoring/basic/",
"^llama_stack/providers/inline/scoring/braintrust/", "^llama_stack/providers/inline/scoring/braintrust/",

View file

@ -10,12 +10,6 @@
set -euo pipefail set -euo pipefail
# Use mapfile to get a faster way to iterate over directories
if (( BASH_VERSINFO[0] < 4 )); then
echo "This script requires Bash 4.0 or higher for mapfile support."
exit 1
fi
PACKAGE_DIR="${1:-llama_stack}" PACKAGE_DIR="${1:-llama_stack}"
if [ ! -d "$PACKAGE_DIR" ]; then if [ ! -d "$PACKAGE_DIR" ]; then
@ -23,24 +17,22 @@ if [ ! -d "$PACKAGE_DIR" ]; then
exit 1 exit 1
fi fi
# Get all directories with Python files (excluding __init__.py)
mapfile -t py_dirs < <(
find "$PACKAGE_DIR" \
-type f \
-name "*.py" ! -name "__init__.py" \
! -path "*/.venv/*" \
! -path "*/node_modules/*" \
-exec dirname {} \; | sort -u
)
missing_init_files=0 missing_init_files=0
for dir in "${py_dirs[@]}"; do # Get all directories with Python files (excluding __init__.py) and check each one
while IFS= read -r -d '' dir; do
if [ ! -f "$dir/__init__.py" ]; then if [ ! -f "$dir/__init__.py" ]; then
echo "ERROR: Missing __init__.py in directory: $dir" echo "ERROR: Missing __init__.py in directory: $dir"
echo "This directory contains Python files but no __init__.py, which may cause packaging issues." echo "This directory contains Python files but no __init__.py, which may cause packaging issues."
missing_init_files=1 missing_init_files=1
fi fi
done done < <(
find "$PACKAGE_DIR" \
-type f \
-name "*.py" ! -name "__init__.py" \
! -path "*/.venv/*" \
! -path "*/node_modules/*" \
-exec dirname {} \; | sort -u | tr '\n' '\0'
)
exit $missing_init_files exit $missing_init_files