chore: add mypy llama guard

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>

parent 1d8c00635c
commit 584d618e27

3 changed files with 46 additions and 30 deletions
llama_stack/providers/inline/safety/llama_guard/llama_guard.py

@@ -10,6 +10,7 @@ from typing import Any
 
 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
 from llama_stack.apis.inference import (
+    ChatCompletionResponse,
     Inference,
     Message,
     UserMessage,
@@ -18,6 +19,7 @@ from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
     SafetyViolation,
+    ShieldStore,
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
@@ -135,9 +137,10 @@ PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATIO
 
 
 class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
-    def __init__(self, config: LlamaGuardConfig, deps) -> None:
+    def __init__(self, config: LlamaGuardConfig, deps: dict[str, Any]) -> None:
         self.config = config
-        self.inference_api = deps[Api.inference]
+        self.inference_api: Inference = deps[Api.inference.value]
+        self.shield_store: ShieldStore = deps["shield_store"]
 
     async def initialize(self) -> None:
         pass
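The `__init__` hunk shows the typing pattern this commit leans on: the deps mapping stays `dict[str, Any]`, and the attribute annotations (`Inference`, `ShieldStore`) tell mypy what the looked-up values actually are. A minimal, self-contained sketch of that pattern (class names here are stand-ins, not the real llama-stack wiring):

    from typing import Any

    class Inference:  # stand-in for llama_stack.apis.inference.Inference
        pass

    class ShieldStore:  # stand-in for llama_stack.apis.safety.ShieldStore
        pass

    class SafetyImpl:
        def __init__(self, deps: dict[str, Any]) -> None:
            # The attribute annotations give mypy concrete types even though
            # every value in deps is typed as Any.
            self.inference_api: Inference = deps["inference"]
            self.shield_store: ShieldStore = deps["shield_store"]

    impl = SafetyImpl({"inference": Inference(), "shield_store": ShieldStore()})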
@@ -154,7 +157,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         self,
         shield_id: str,
         messages: list[Message],
-        params: dict[str, Any] = None,
+        params: dict[str, Any] | None = None,
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
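The signature fix above addresses mypy's implicit-Optional rule: a parameter that defaults to None must say so in its annotation, so `dict[str, Any] = None` becomes `dict[str, Any] | None = None`. A short sketch of the accepted form:

    from typing import Any

    # With implicit Optional disabled (mypy's current default),
    # a None default requires `| None` in the annotation.
    def run_shield(params: dict[str, Any] | None = None) -> int:
        params = params or {}
        return len(params)

    print(run_shield())          # 0
    print(run_shield({"k": 1}))  # 1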
@@ -166,6 +169,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         if len(messages) > 0 and messages[0].role != Role.user.value:
             messages[0] = UserMessage(content=messages[0].content)
 
+<<<<<<< HEAD
         # Use the inference API's model resolution instead of hardcoded mappings
         # This allows the shield to work with any registered model
         model_id = shield.provider_resource_id

@@ -180,6 +184,13 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         # For unknown models, use default Llama Guard 3 8B categories
         safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
 
+=======
+        provider_resource_id = shield.provider_resource_id
+        if provider_resource_id is None:
+            raise ValueError("Shield provider_resource_id is required")
+
+        model = LLAMA_GUARD_MODEL_IDS[provider_resource_id]
+>>>>>>> 26efb98b (chore: add mypy llama guard)
         impl = LlamaGuardShield(
             model=model_id,
             inference_api=self.inference_api,
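Note that these two hunks commit unresolved merge conflict markers (`<<<<<<< HEAD` … `>>>>>>> 26efb98b`), leaving both HEAD's registry-based model resolution and the incoming `LLAMA_GUARD_MODEL_IDS` lookup in the file. A hedged sketch of one plausible resolution, keeping HEAD's resolution plus the incoming None check (illustrative only; the commit itself leaves the conflict in place):

    from dataclasses import dataclass

    @dataclass
    class Shield:  # stand-in for llama_stack.apis.shields.Shield
        provider_resource_id: str | None

    def resolve_model_id(shield: Shield) -> str:
        # Keep the registry-based resolution, plus the None check mypy
        # requires before treating the id as a str.
        if shield.provider_resource_id is None:
            raise ValueError("Shield provider_resource_id is required")
        return shield.provider_resource_id

    print(resolve_model_id(Shield(provider_resource_id="llama-guard-3-8b")))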
@@ -237,7 +248,7 @@ class LlamaGuardShield:
 
         return final_categories
 
-    def validate_messages(self, messages: list[Message]) -> None:
+    def validate_messages(self, messages: list[Message]) -> list[Message]:
         if len(messages) == 0:
             raise ValueError("Messages must not be empty")
         if messages[0].role != Role.user.value:
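Changing `validate_messages` from `-> None` to `-> list[Message]` lets callers assign the validated list directly rather than relying on in-place mutation, which type-checks cleanly. A minimal sketch with a stand-in `Message`:

    from dataclasses import dataclass

    @dataclass
    class Message:  # stand-in for llama_stack.apis.inference.Message
        role: str
        content: str

    def validate_messages(messages: list[Message]) -> list[Message]:
        if len(messages) == 0:
            raise ValueError("Messages must not be empty")
        return messages

    msgs = validate_messages([Message(role="user", content="hello")])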
@@ -262,9 +273,23 @@
             messages=[shield_input_message],
             stream=False,
         )
-        content = response.completion_message.content
-        content = content.strip()
-        return self.get_shield_response(content)
+
+        # Handle streaming response
+        if isinstance(response, ChatCompletionResponse):
+            content = response.completion_message.content
+        else:
+            # Handle case where response is a stream
+            raise ValueError("Streaming response not supported for shield validation")
+
+        # Handle different content types
+        if isinstance(content, str):
+            content_str = content.strip()
+        elif hasattr(content, "strip"):
+            content_str = content.strip()
+        else:
+            content_str = str(content).strip()
+
+        return self.get_shield_response(content_str)
 
     def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
         return UserMessage(content=self.build_prompt(messages))
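This hunk handles the fact that a chat-completion call is typed as a union of a response object and a stream even when `stream=False`: `isinstance` narrows the union so mypy accepts the `.completion_message` access, and the stream case fails loudly. A self-contained sketch of that narrowing (types are stand-ins):

    from collections.abc import AsyncIterator
    from dataclasses import dataclass

    @dataclass
    class ChatResponse:  # stand-in for ChatCompletionResponse
        content: str

    def read_content(response: ChatResponse | AsyncIterator[str]) -> str:
        # isinstance narrows the union; mypy then knows .content exists.
        if isinstance(response, ChatResponse):
            return response.content.strip()
        raise ValueError("Streaming response not supported for shield validation")

    print(read_content(ChatResponse(content="  safe  ")))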
@@ -281,10 +306,10 @@ class LlamaGuardShield:
                     most_recent_img = m.content
                 conversation.append(m)
             elif isinstance(m.content, list):
-                content = []
+                content: list[TextContentItem | ImageContentItem] = []
                 for c in m.content:
                     if isinstance(c, str) or isinstance(c, TextContentItem):
-                        content.append(c)
+                        content.append(TextContentItem(text=c if isinstance(c, str) else c.text))
                     elif isinstance(c, ImageContentItem):
                         if most_recent_img is None and m.role == Role.user.value:
                             most_recent_img = c
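An empty list literal gives mypy nothing to infer from, hence the explicit `list[TextContentItem | ImageContentItem]` annotation; the hunk also normalizes raw strings into `TextContentItem` so every element matches the annotation. A sketch of the pattern (item classes are stand-ins); the next hunk applies the same annotation to `prompt`:

    from dataclasses import dataclass

    @dataclass
    class TextItem:  # stand-in for TextContentItem
        text: str

    @dataclass
    class ImageItem:  # stand-in for ImageContentItem
        url: str

    def collect(parts: list[str | TextItem | ImageItem]) -> list[TextItem | ImageItem]:
        out: list[TextItem | ImageItem] = []  # annotation needed for an empty literal
        for p in parts:
            if isinstance(p, str):
                out.append(TextItem(text=p))  # normalize raw strings
            else:
                out.append(p)
        return out

    print(collect(["hi", ImageItem(url="http://example.com/x.png")]))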
@@ -296,10 +321,10 @@ class LlamaGuardShield:
             else:
                 raise ValueError(f"Unknown content type: {m.content}")
 
-        prompt = []
+        prompt: list[ImageContentItem | TextContentItem] = []
         if most_recent_img is not None:
             prompt.append(most_recent_img)
-        prompt.append(self.build_prompt(conversation[::-1]))
+        prompt.append(TextContentItem(text=self.build_prompt(conversation[::-1])))
 
         return UserMessage(content=prompt)
 
pyproject.toml

@@ -257,7 +257,6 @@ exclude = [
     "^llama_stack/providers/inline/inference/vllm/",
     "^llama_stack/providers/inline/post_training/common/validator\\.py$",
     "^llama_stack/providers/inline/safety/code_scanner/",
-    "^llama_stack/providers/inline/safety/llama_guard/",
     "^llama_stack/providers/inline/safety/prompt_guard/",
     "^llama_stack/providers/inline/scoring/basic/",
     "^llama_stack/providers/inline/scoring/braintrust/",
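Dropping the `llama_guard` pattern from the mypy `exclude` list is what actually opts the package into type checking. To check just that package, mypy's Python API can be used; a sketch (assuming mypy is installed and this is run from the repo root):

    # mypy.api.run returns (report, errors, exit_status).
    from mypy import api

    report, errors, exit_status = api.run(["llama_stack/providers/inline/safety/llama_guard/"])
    print(report, end="")
    raise SystemExit(exit_status)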
Shell script that checks for missing __init__.py files (third changed file):

@@ -10,12 +10,6 @@
 
 set -euo pipefail
 
-# Use mapfile to get a faster way to iterate over directories
-if (( BASH_VERSINFO[0] < 4 )); then
-    echo "This script requires Bash 4.0 or higher for mapfile support."
-    exit 1
-fi
-
 PACKAGE_DIR="${1:-llama_stack}"
 
 if [ ! -d "$PACKAGE_DIR" ]; then
@@ -23,24 +17,22 @@ if [ ! -d "$PACKAGE_DIR" ]; then
     exit 1
 fi
 
-# Get all directories with Python files (excluding __init__.py)
-mapfile -t py_dirs < <(
-    find "$PACKAGE_DIR" \
-        -type f \
-        -name "*.py" ! -name "__init__.py" \
-        ! -path "*/.venv/*" \
-        ! -path "*/node_modules/*" \
-        -exec dirname {} \; | sort -u
-)
-
 missing_init_files=0
 
-for dir in "${py_dirs[@]}"; do
+# Get all directories with Python files (excluding __init__.py) and check each one
+while IFS= read -r -d '' dir; do
     if [ ! -f "$dir/__init__.py" ]; then
         echo "ERROR: Missing __init__.py in directory: $dir"
         echo "This directory contains Python files but no __init__.py, which may cause packaging issues."
         missing_init_files=1
     fi
-done
+done < <(
+    find "$PACKAGE_DIR" \
+        -type f \
+        -name "*.py" ! -name "__init__.py" \
+        ! -path "*/.venv/*" \
+        ! -path "*/node_modules/*" \
+        -exec dirname {} \; | sort -u | tr '\n' '\0'
+)
 
 exit $missing_init_files
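The rewrite drops the Bash-4-only `mapfile` in favor of a `while read -d ''` loop fed by NUL-delimited `find` output, which also removes the version check. For comparison, a minimal Python sketch of the same check (assuming the same layout conventions):

    import sys
    from pathlib import Path

    package_dir = Path(sys.argv[1] if len(sys.argv) > 1 else "llama_stack")
    missing = 0
    # Directories containing .py files other than __init__.py, skipping vendored trees.
    py_dirs = {
        p.parent
        for p in package_dir.rglob("*.py")
        if p.name != "__init__.py"
        and ".venv" not in p.parts
        and "node_modules" not in p.parts
    }
    for d in sorted(py_dirs):
        if not (d / "__init__.py").is_file():
            print(f"ERROR: Missing __init__.py in directory: {d}")
            missing = 1
    sys.exit(missing)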