chore: add mypy llama guard

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Mustafa Elbehery 2025-07-09 01:49:55 +02:00
parent 1d8c00635c
commit 584d618e27
3 changed files with 46 additions and 30 deletions

@@ -10,6 +10,7 @@ from typing import Any
 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
+    Inference,
     Message,
     UserMessage,
@@ -18,6 +19,7 @@ from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
     SafetyViolation,
+    ShieldStore,
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
@@ -135,9 +137,10 @@ PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATIO
 class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
-    def __init__(self, config: LlamaGuardConfig, deps) -> None:
+    def __init__(self, config: LlamaGuardConfig, deps: dict[str, Any]) -> None:
         self.config = config
-        self.inference_api = deps[Api.inference]
+        self.inference_api: Inference = deps[Api.inference.value]
+        self.shield_store: ShieldStore = deps["shield_store"]

     async def initialize(self) -> None:
         pass
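[Note: typing the deps mapping is what lets mypy check the lookups above. A minimal sketch of the failure mode, using a stand-in Api enum rather than the real llama_stack one:

    from enum import Enum
    from typing import Any

    class Api(Enum):
        inference = "inference"

    deps: dict[str, Any] = {"inference": object()}

    # With the dict keyed by str, indexing with the enum member itself is
    # rejected, with an error along the lines of:
    #   Invalid index type "Api" for "dict[str, Any]"; expected type "str"
    # which is presumably why the lookup switches to Api.inference.value.
    inference_api = deps[Api.inference.value]
]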
@@ -154,7 +157,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         self,
         shield_id: str,
         messages: list[Message],
-        params: dict[str, Any] = None,
+        params: dict[str, Any] | None = None,
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
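[Note: the params change fixes an implicit-Optional default, which mypy rejects by default since release 0.990. A short illustration with hypothetical function names:

    from typing import Any

    # mypy: Incompatible default for argument "params" (default has type
    # "None", argument has type "dict[str, Any]")
    def run_shield_old(params: dict[str, Any] = None) -> None: ...

    # Spelling out the union satisfies the checker and documents the contract
    def run_shield_new(params: dict[str, Any] | None = None) -> None: ...
]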
@@ -166,6 +169,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         if len(messages) > 0 and messages[0].role != Role.user.value:
             messages[0] = UserMessage(content=messages[0].content)

         # Use the inference API's model resolution instead of hardcoded mappings
         # This allows the shield to work with any registered model
         model_id = shield.provider_resource_id
+        if model_id is None:
+            raise ValueError("Shield provider_resource_id is required")
@@ -180,6 +184,13 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
             # For unknown models, use default Llama Guard 3 8B categories
             safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]

         impl = LlamaGuardShield(
             model=model_id,
             inference_api=self.inference_api,
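[Note: the new ValueError guard doubles as type narrowing. Assuming provider_resource_id is declared Optional on the shield resource (which the guard added here implies), mypy needs the explicit check before the value is passed where a plain str is required:

    def resolve_model_id(provider_resource_id: str | None) -> str:
        # Without the guard, mypy reports an error along the lines of:
        #   Argument 1 has incompatible type "str | None"; expected "str"
        if provider_resource_id is None:
            raise ValueError("Shield provider_resource_id is required")
        return provider_resource_id  # narrowed to str from here on
]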
@@ -237,7 +248,7 @@ class LlamaGuardShield:
         return final_categories

-    def validate_messages(self, messages: list[Message]) -> None:
+    def validate_messages(self, messages: list[Message]) -> list[Message]:
         if len(messages) == 0:
             raise ValueError("Messages must not be empty")
         if messages[0].role != Role.user.value:
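[Note: returning the validated list rather than None lets call sites reassign in one checked step; presumably the body now ends with return messages. A sketch with simplified types:

    def validate_messages(messages: list[str]) -> list[str]:
        if len(messages) == 0:
            raise ValueError("Messages must not be empty")
        return messages

    def run(messages: list[str]) -> None:
        messages = validate_messages(messages)  # type-checked round trip
]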
@@ -262,9 +273,23 @@ class LlamaGuardShield:
             messages=[shield_input_message],
             stream=False,
         )
-        content = response.completion_message.content
-        content = content.strip()
-        return self.get_shield_response(content)
+
+        # Handle streaming response
+        if isinstance(response, ChatCompletionResponse):
+            content = response.completion_message.content
+        else:
+            # Handle case where response is a stream
+            raise ValueError("Streaming response not supported for shield validation")
+
+        # Handle different content types
+        if isinstance(content, str):
+            content_str = content.strip()
+        elif hasattr(content, "strip"):
+            content_str = content.strip()
+        else:
+            content_str = str(content).strip()
+
+        return self.get_shield_response(content_str)

     def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
         return UserMessage(content=self.build_prompt(messages))
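[Note: the isinstance guard is standard union narrowing. The inference API is declared to return either a completed response or a stream, and passing stream=False alone does not narrow the static type. A self-contained sketch with stand-in types:

    from collections.abc import AsyncIterator

    class ChatCompletionResponse:
        content = "safe"

    async def chat_completion(stream: bool = False) -> ChatCompletionResponse | AsyncIterator[str]:
        return ChatCompletionResponse()

    async def run() -> str:
        response = await chat_completion(stream=False)
        if not isinstance(response, ChatCompletionResponse):
            raise ValueError("Streaming response not supported for shield validation")
        return response.content  # mypy now knows the non-streaming type
]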
@@ -281,10 +306,10 @@ class LlamaGuardShield:
                     most_recent_img = m.content
                 conversation.append(m)
             elif isinstance(m.content, list):
-                content = []
+                content: list[TextContentItem | ImageContentItem] = []
                 for c in m.content:
                     if isinstance(c, str) or isinstance(c, TextContentItem):
-                        content.append(c)
+                        content.append(TextContentItem(text=c if isinstance(c, str) else c.text))
                     elif isinstance(c, ImageContentItem):
                         if most_recent_img is None and m.role == Role.user.value:
                             most_recent_img = c
@@ -296,10 +321,10 @@ class LlamaGuardShield:
             else:
                 raise ValueError(f"Unknown content type: {m.content}")

-        prompt = []
+        prompt: list[ImageContentItem | TextContentItem] = []
         if most_recent_img is not None:
             prompt.append(most_recent_img)
-        prompt.append(self.build_prompt(conversation[::-1]))
+        prompt.append(TextContentItem(text=self.build_prompt(conversation[::-1])))

         return UserMessage(content=prompt)
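[Note: wrapping raw strings in TextContentItem keeps the element type uniform so the new list annotations hold; appending a bare str to a list declared to hold content items is a type error. A reduced example, with TextContentItem simplified:

    class TextContentItem:
        def __init__(self, text: str) -> None:
            self.text = text

    def normalize(parts: list[str | TextContentItem]) -> list[TextContentItem]:
        out: list[TextContentItem] = []  # a bare [] would draw "Need type annotation"
        for p in parts:
            # wrap plain strings so every element matches the declared type
            out.append(TextContentItem(text=p if isinstance(p, str) else p.text))
        return out
]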

@@ -257,7 +257,6 @@ exclude = [
     "^llama_stack/providers/inline/inference/vllm/",
     "^llama_stack/providers/inline/post_training/common/validator\\.py$",
     "^llama_stack/providers/inline/safety/code_scanner/",
-    "^llama_stack/providers/inline/safety/llama_guard/",
     "^llama_stack/providers/inline/safety/prompt_guard/",
     "^llama_stack/providers/inline/scoring/basic/",
     "^llama_stack/providers/inline/scoring/braintrust/",

@@ -10,12 +10,6 @@
 set -euo pipefail

-# Use mapfile to get a faster way to iterate over directories
-if (( BASH_VERSINFO[0] < 4 )); then
-    echo "This script requires Bash 4.0 or higher for mapfile support."
-    exit 1
-fi
-
 PACKAGE_DIR="${1:-llama_stack}"

 if [ ! -d "$PACKAGE_DIR" ]; then
@@ -23,24 +17,22 @@ if [ ! -d "$PACKAGE_DIR" ]; then
     exit 1
 fi

-# Get all directories with Python files (excluding __init__.py)
-mapfile -t py_dirs < <(
-    find "$PACKAGE_DIR" \
-        -type f \
-        -name "*.py" ! -name "__init__.py" \
-        ! -path "*/.venv/*" \
-        ! -path "*/node_modules/*" \
-        -exec dirname {} \; | sort -u
-)
-
 missing_init_files=0
-for dir in "${py_dirs[@]}"; do
+
+# Get all directories with Python files (excluding __init__.py) and check each one
+while IFS= read -r -d '' dir; do
     if [ ! -f "$dir/__init__.py" ]; then
         echo "ERROR: Missing __init__.py in directory: $dir"
         echo "This directory contains Python files but no __init__.py, which may cause packaging issues."
         missing_init_files=1
     fi
-done
+done < <(
+    find "$PACKAGE_DIR" \
+        -type f \
+        -name "*.py" ! -name "__init__.py" \
+        ! -path "*/.venv/*" \
+        ! -path "*/node_modules/*" \
+        -exec dirname {} \; | sort -u | tr '\n' '\0'
+)

 exit $missing_init_files
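[Note: the rewritten loop streams NUL-delimited directory names from find into while read, dropping the Bash 4 mapfile requirement that the deleted version guard enforced. For comparison only, a hypothetical Python equivalent of the same check:

    import sys
    from pathlib import Path

    package_dir = Path(sys.argv[1] if len(sys.argv) > 1 else "llama_stack")

    missing = 0
    # directories containing .py files other than __init__.py,
    # skipping virtualenv and node_modules trees like the find call above
    dirs = {
        p.parent
        for p in package_dir.rglob("*.py")
        if p.name != "__init__.py"
        and ".venv" not in p.parts
        and "node_modules" not in p.parts
    }
    for d in sorted(dirs):
        if not (d / "__init__.py").is_file():
            print(f"ERROR: Missing __init__.py in directory: {d}")
            missing = 1

    sys.exit(missing)
]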