diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index 9d359e053..1f8985a94 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -10,6 +10,7 @@ from typing import Any
 
 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
 from llama_stack.apis.inference import (
+    ChatCompletionResponse,
     Inference,
     Message,
     UserMessage,
@@ -18,6 +19,7 @@ from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
     SafetyViolation,
+    ShieldStore,
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
@@ -135,9 +137,10 @@ PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATIO
 
 
 class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
-    def __init__(self, config: LlamaGuardConfig, deps) -> None:
+    def __init__(self, config: LlamaGuardConfig, deps: dict[str, Any]) -> None:
         self.config = config
-        self.inference_api = deps[Api.inference]
+        self.inference_api: Inference = deps[Api.inference.value]
+        self.shield_store: ShieldStore = deps["shield_store"]
 
     async def initialize(self) -> None:
         pass
@@ -154,7 +157,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         self,
         shield_id: str,
         messages: list[Message],
-        params: dict[str, Any] = None,
+        params: dict[str, Any] | None = None,
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
@@ -166,11 +169,14 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         if len(messages) > 0 and messages[0].role != Role.user.value:
             messages[0] = UserMessage(content=messages[0].content)
 
         # Use the inference API's model resolution instead of hardcoded mappings
         # This allows the shield to work with any registered model
         model_id = shield.provider_resource_id
+        if model_id is None:
+            raise ValueError("Shield provider_resource_id is required")
+
         # For unknown models, use default Llama Guard 3 8B categories
         safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
         impl = LlamaGuardShield(
             model=model_id,
             inference_api=self.inference_api,
@@ -237,7 +243,7 @@ class LlamaGuardShield:
 
         return final_categories
 
-    def validate_messages(self, messages: list[Message]) -> None:
+    def validate_messages(self, messages: list[Message]) -> list[Message]:
         if len(messages) == 0:
             raise ValueError("Messages must not be empty")
         if messages[0].role != Role.user.value:
@@ -262,9 +268,23 @@ class LlamaGuardShield:
             messages=[shield_input_message],
             stream=False,
         )
-        content = response.completion_message.content
-        content = content.strip()
-        return self.get_shield_response(content)
+
+        # chat_completion with stream=False should return a complete response,
+        # but the declared return type is a union, so narrow it explicitly
+        if isinstance(response, ChatCompletionResponse):
+            content = response.completion_message.content
+        else:
+            raise ValueError("Streaming response not supported for shield validation")
+
+        # Handle different content types
+        if isinstance(content, str):
+            content_str = content.strip()
+        elif hasattr(content, "strip"):
+            content_str = content.strip()
+        else:
+            content_str = str(content).strip()
+
+        return self.get_shield_response(content_str)
 
     def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
         return UserMessage(content=self.build_prompt(messages))
 
@@ -281,10 +301,10 @@ class LlamaGuardShield:
                     most_recent_img = m.content
                 conversation.append(m)
             elif isinstance(m.content, list):
-                content = []
+                content: list[TextContentItem | ImageContentItem] = []
                 for c in m.content:
                     if isinstance(c, str) or isinstance(c, TextContentItem):
-                        content.append(c)
+                        content.append(TextContentItem(text=c if isinstance(c, str) else c.text))
                     elif isinstance(c, ImageContentItem):
                         if most_recent_img is None and m.role == Role.user.value:
                             most_recent_img = c
@@ -296,10 +316,10 @@ class LlamaGuardShield:
             else:
                 raise ValueError(f"Unknown content type: {m.content}")
 
-        prompt = []
+        prompt: list[ImageContentItem | TextContentItem] = []
         if most_recent_img is not None:
             prompt.append(most_recent_img)
-        prompt.append(self.build_prompt(conversation[::-1]))
+        prompt.append(TextContentItem(text=self.build_prompt(conversation[::-1])))
 
         return UserMessage(content=prompt)
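Note on the stream handling added above: even with stream=False, the isinstance check is needed presumably because chat_completion's declared return type is a union of a complete response and a streaming iterator, so mypy cannot assume .completion_message exists until the union is narrowed. A minimal, self-contained sketch of the pattern follows; every name in it is a hypothetical stand-in, not the llama_stack API:

    from __future__ import annotations

    from collections.abc import AsyncIterator
    from dataclasses import dataclass


    # Hypothetical stand-ins for the real response types
    @dataclass
    class CompletionMessage:
        content: str


    @dataclass
    class ChatCompletionResponse:
        completion_message: CompletionMessage


    def extract_content(response: ChatCompletionResponse | AsyncIterator[str]) -> str:
        # isinstance narrows the union: mypy only permits the attribute
        # access on the ChatCompletionResponse branch.
        if not isinstance(response, ChatCompletionResponse):
            raise ValueError("Streaming response not supported for shield validation")
        return response.completion_message.content.strip()
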
diff --git a/pyproject.toml b/pyproject.toml
index 04a6a685e..03d3c2c1f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -257,7 +257,6 @@ exclude = [
     "^llama_stack/providers/inline/inference/vllm/",
     "^llama_stack/providers/inline/post_training/common/validator\\.py$",
     "^llama_stack/providers/inline/safety/code_scanner/",
-    "^llama_stack/providers/inline/safety/llama_guard/",
     "^llama_stack/providers/inline/safety/prompt_guard/",
     "^llama_stack/providers/inline/scoring/basic/",
     "^llama_stack/providers/inline/scoring/braintrust/",
diff --git a/scripts/check-init-py.sh b/scripts/check-init-py.sh
index c6e8fd417..88ebba286 100755
--- a/scripts/check-init-py.sh
+++ b/scripts/check-init-py.sh
@@ -10,12 +10,6 @@
 
 set -euo pipefail
 
-# Use mapfile to get a faster way to iterate over directories
-if (( BASH_VERSINFO[0] < 4 )); then
-    echo "This script requires Bash 4.0 or higher for mapfile support."
-    exit 1
-fi
-
 PACKAGE_DIR="${1:-llama_stack}"
 
 if [ ! -d "$PACKAGE_DIR" ]; then
@@ -23,24 +17,22 @@ if [ ! -d "$PACKAGE_DIR" ]; then
     exit 1
 fi
 
-# Get all directories with Python files (excluding __init__.py)
-mapfile -t py_dirs < <(
-    find "$PACKAGE_DIR" \
-        -type f \
-        -name "*.py" ! -name "__init__.py" \
-        ! -path "*/.venv/*" \
-        ! -path "*/node_modules/*" \
-        -exec dirname {} \; | sort -u
-)
-
 missing_init_files=0
 
-for dir in "${py_dirs[@]}"; do
+# Get all directories with Python files (excluding __init__.py) and check each one
+while IFS= read -r -d '' dir; do
     if [ ! -f "$dir/__init__.py" ]; then
         echo "ERROR: Missing __init__.py in directory: $dir"
         echo "This directory contains Python files but no __init__.py, which may cause packaging issues."
         missing_init_files=1
     fi
-done
+done < <(
+    find "$PACKAGE_DIR" \
+        -type f \
+        -name "*.py" ! -name "__init__.py" \
+        ! -path "*/.venv/*" \
+        ! -path "*/node_modules/*" \
+        -exec dirname {} \; | sort -u | tr '\n' '\0'
+)
 
 exit $missing_init_files
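Note on the check-init-py.sh change: mapfile requires bash 4.0, but the stock bash on macOS is 3.2, so the old version guard would reject developer machines. The replacement feeds a NUL-delimited stream into while IFS= read -r -d '', which works on bash 3.2 and, because process substitution keeps the loop in the current shell, lets missing_init_files survive past the loop. A standalone sketch of the pattern, with printf supplying stand-in data instead of find:

    #!/usr/bin/env bash
    set -euo pipefail

    count=0
    while IFS= read -r -d '' dir; do
        echo "would check: $dir"
        count=$((count + 1))  # runs in the current shell, so the update persists
    done < <(printf '%s\n' "a b" "c d" | tr '\n' '\0')
    echo "checked $count directories"

One caveat: building the stream with tr '\n' '\0' still splits names that themselves contain newlines; find -print0 is the fully NUL-safe variant, at the cost of needing sort -z for the dedup step.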