chore: enable mypy type checking for llama_guard safety provider

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Mustafa Elbehery 2025-08-25 14:36:31 +02:00
parent ed418653ec
commit 0543f74c32
2 changed files with 85 additions and 21 deletions


@@ -17,7 +17,7 @@ from llama_stack.apis.safety import (
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, ShieldStore
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import Api
from llama_stack.log import get_logger
@@ -136,6 +136,8 @@ logger = get_logger(name=__name__, category="safety")
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
shield_store: ShieldStore
def __init__(self, config: LlamaGuardConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
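The class-level `shield_store: ShieldStore` annotation is there so mypy knows the attribute's type even though it is presumably assigned outside `__init__` by the framework. A minimal sketch of the same pattern with a protocol-style stand-in; `Store`, `Provider`, and `wire` are illustrative names, not llama_stack APIs:

```python
from typing import Protocol


class Store(Protocol):
    async def get_shield(self, shield_id: str) -> str | None: ...


class Provider:
    # Declared at class level so mypy knows the attribute exists and its type,
    # even though the framework assigns it only after construction.
    shield_store: Store

    def __init__(self, config: dict) -> None:
        self.config = config


def wire(provider: Provider, store: Store) -> None:
    # This assignment type-checks because of the class-level declaration above.
    provider.shield_store = store
```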
@@ -160,7 +162,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
self,
shield_id: str,
messages: list[Message],
params: dict[str, Any] = None,
params: dict[str, Any] | None = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
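The signature change from `params: dict[str, Any] = None` to `dict[str, Any] | None = None` is the standard fix for mypy's implicit-Optional error: a `None` default is not a valid `dict[str, Any]`. A hedged, self-contained sketch (the function and its body are illustrative only):

```python
from typing import Any


def run_shield(shield_id: str, params: dict[str, Any] | None = None) -> str:
    # Normalize the optional argument so the rest of the body sees a real dict.
    params = params or {}
    return f"{shield_id}: {sorted(params)}"
```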
@@ -175,6 +177,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
# Use the inference API's model resolution instead of hardcoded mappings
# This allows the shield to work with any registered model
model_id = shield.provider_resource_id
if not model_id:
raise ValueError("Shield must have a provider_resource_id (model_id)")
# Determine safety categories based on the model type
# For known Llama Guard models, use specific categories
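The added guard narrows `shield.provider_resource_id` from `str | None` to `str`, which is what allows the later uses of `model_id` to type-check. A small sketch of the narrowing pattern, using a made-up `ShieldLike` dataclass rather than the real `Shield` type:

```python
from dataclasses import dataclass


@dataclass
class ShieldLike:
    provider_resource_id: str | None


def resolve_model_id(shield: ShieldLike) -> str:
    model_id = shield.provider_resource_id
    if not model_id:
        # Raising here narrows model_id to str for everything below.
        raise ValueError("Shield must have a provider_resource_id (model_id)")
    return model_id
```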
@@ -202,7 +206,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
messages = [input]
# convert to user messages format with role
messages = [UserMessage(content=m) for m in messages]
user_messages: list[Message] = [UserMessage(content=m) for m in messages]
# Determine safety categories based on the model type
# For known Llama Guard models, use specific categories
@@ -221,7 +225,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
safety_categories=safety_categories,
)
return await impl.run_moderation(messages)
return await impl.run_moderation(user_messages)
class LlamaGuardShield:
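Re-binding `messages` to a list of `UserMessage` objects would change its inferred element type mid-function, which mypy reports as an incompatible assignment; hence the new, separately annotated `user_messages` name that is then passed to `run_moderation`. An illustrative sketch with stand-in types (`Msg` is not a llama_stack class):

```python
from dataclasses import dataclass


@dataclass
class Msg:
    content: str


def to_messages(inputs: list[str]) -> list[Msg]:
    # A fresh, explicitly typed name instead of re-binding `inputs`
    # to a list with a different element type.
    user_messages: list[Msg] = [Msg(content=text) for text in inputs]
    return user_messages
```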
@@ -271,7 +275,7 @@ class LlamaGuardShield:
return final_categories
def validate_messages(self, messages: list[Message]) -> None:
def validate_messages(self, messages: list[Message]) -> list[Message]:
if len(messages) == 0:
raise ValueError("Messages must not be empty")
if messages[0].role != Role.user.value:
@@ -283,7 +287,9 @@ class LlamaGuardShield:
return messages
async def run(self, messages: list[Message]) -> RunShieldResponse:
messages = self.validate_messages(messages)
validated_messages = self.validate_messages(messages)
if validated_messages is not None:
messages = validated_messages
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
shield_input_message = self.build_vision_shield_input(messages)
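`validate_messages` previously returned the list while being annotated `-> None`; declaring `-> list[Message]` lets callers such as `run` bind and reuse the result without a type error. A tiny sketch of the same fix under illustrative names:

```python
def validate(items: list[str]) -> list[str]:
    # Annotating the real return type lets callers assign and reuse the result.
    if not items:
        raise ValueError("Messages must not be empty")
    return items


validated = validate(["hello"])  # mypy now knows this is list[str]
```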
@@ -296,8 +302,14 @@ class LlamaGuardShield:
messages=[shield_input_message],
stream=False,
)
content = response.completion_message.content
content = content.strip()
if hasattr(response, "completion_message"):
content = response.completion_message.content
if isinstance(content, str):
content = content.strip()
else:
raise ValueError(f"Expected string content, got {type(content)}")
else:
raise ValueError("Response does not have completion_message attribute")
return self.get_shield_response(content)
def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
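The `hasattr`/`isinstance` guards are there because the completion message's `content` is typed as a union (it may be structured content rather than a plain string), so mypy will not allow `.strip()` until the value is narrowed. A standalone sketch of the guard style; `FakeMessage` and `FakeResponse` are placeholders, not real response types:

```python
from dataclasses import dataclass


@dataclass
class FakeMessage:
    content: object  # may be a str or structured content items


@dataclass
class FakeResponse:
    completion_message: FakeMessage


def extract_text(response: FakeResponse) -> str:
    content = response.completion_message.content
    if isinstance(content, str):
        # isinstance() narrows `content` to str, so .strip() type-checks.
        return content.strip()
    raise ValueError(f"Expected string content, got {type(content)}")
```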
@@ -315,27 +327,51 @@ class LlamaGuardShield:
most_recent_img = m.content
conversation.append(m)
elif isinstance(m.content, list):
content = []
text_content: list[TextContentItem] = []
for c in m.content:
if isinstance(c, str) or isinstance(c, TextContentItem):
content.append(c)
if isinstance(c, str):
text_content.append(TextContentItem(text=c))
elif isinstance(c, TextContentItem):
text_content.append(c)
elif isinstance(c, ImageContentItem):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = c
content.append(c)
# Note: we handle images separately for vision models
else:
raise ValueError(f"Unknown content type: {c}")
conversation.append(UserMessage(content=content))
if len(text_content) == 1:
conversation.append(UserMessage(content=text_content[0]))
elif len(text_content) > 1:
# Cast to the expected type
from typing import cast
content_list = cast(list[ImageContentItem | TextContentItem], text_content)
conversation.append(UserMessage(content=content_list))
else:
conversation.append(UserMessage(content=""))
else:
raise ValueError(f"Unknown content type: {m.content}")
prompt = []
prompt: list[ImageContentItem | str] = []
if most_recent_img is not None:
prompt.append(most_recent_img)
prompt.append(self.build_prompt(conversation[::-1]))
return UserMessage(content=prompt)
# Convert the prompt list to the expected content type
if len(prompt) == 1:
# Single item case
single_content = prompt[0]
return UserMessage(content=single_content)
else:
# Multiple items - convert strings to TextContentItem
mixed_content: list[ImageContentItem | TextContentItem] = []
for item in prompt:
if isinstance(item, str):
mixed_content.append(TextContentItem(text=item))
else:
mixed_content.append(item) # ImageContentItem
return UserMessage(content=mixed_content)
def build_prompt(self, messages: list[Message]) -> str:
categories = self.get_safety_categories()
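The vision-input hunk replaces the untyped `content = []` accumulator with lists whose element types are declared up front, wrapping bare strings in `TextContentItem` so the `UserMessage` content ends up with a single, checkable element union. A standalone sketch of that normalisation using stand-in item classes (`TextItem` and `ImageItem` are not llama_stack names):

```python
from dataclasses import dataclass


@dataclass
class TextItem:
    text: str


@dataclass
class ImageItem:
    url: str


def normalize(parts: list[str | TextItem | ImageItem]) -> list[TextItem | ImageItem]:
    # Declaring the element type up front lets mypy check every append.
    out: list[TextItem | ImageItem] = []
    for part in parts:
        if isinstance(part, str):
            out.append(TextItem(text=part))  # wrap bare strings
        else:
            out.append(part)
    return out
```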
@@ -377,13 +413,42 @@ class LlamaGuardShield:
# TODO: Add Image based support for OpenAI Moderations
shield_input_message = self.build_text_shield_input(messages)
# Convert to OpenAI format - we need to import the conversion function
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
openai_message = await convert_message_to_openai_dict_new(shield_input_message)
# Cast to expected type to satisfy mypy
from typing import cast
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIDeveloperMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
openai_message_param = (
OpenAIUserMessageParam
| OpenAISystemMessageParam
| OpenAIAssistantMessageParam
| OpenAIToolMessageParam
| OpenAIDeveloperMessageParam
)
openai_messages = [cast(openai_message_param, openai_message)]
response = await self.inference_api.openai_chat_completion(
model=self.model,
messages=[shield_input_message],
messages=openai_messages,
stream=False,
)
content = response.choices[0].message.content
content = content.strip()
if hasattr(response, "choices") and len(response.choices) > 0:
message_content = response.choices[0].message.content
if isinstance(message_content, str):
content = message_content.strip()
else:
raise ValueError(f"Expected string content, got {type(message_content)}")
else:
raise ValueError("Response does not have choices or choices is empty")
return self.get_moderation_object(content)
def create_moderation_object(self, model: str, unsafe_code: str | None = None) -> ModerationObject:
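`convert_message_to_openai_dict_new` is annotated more loosely than the message type `openai_chat_completion` expects, so the hunk builds a union alias of the OpenAI message param classes and `cast`s the converted message to it; `cast` is purely a static-typing hint with no runtime effect. A generic sketch of that bridging pattern with placeholder classes (none of these names are the real llama_stack types):

```python
from typing import cast


class UserParam: ...


class SystemParam: ...


# Union alias mirroring the set of message param types the API accepts.
MessageParam = UserParam | SystemParam


def loosely_typed_convert(raw: dict) -> object:
    # Stand-in for a helper whose return annotation is wider than needed.
    return UserParam()


def build_request(raw: dict) -> list[MessageParam]:
    converted = loosely_typed_convert(raw)
    # cast() is a no-op at runtime; it only tells mypy which union member
    # this value is expected to be.
    return [cast(MessageParam, converted)]
```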
@@ -399,10 +464,10 @@ class LlamaGuardShield:
# Set default values for safe case
categories = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), False)
category_scores = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), 1.0)
category_applied_input_types = {key: [] for key in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
category_applied_input_types: dict[str, list[str]] = {key: [] for key in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
flagged = False
user_message = None
metadata = {}
metadata: dict[str, Any] = {}
# Handle unsafe case
if unsafe_code:
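The empty literals in `{key: [] for ...}` and `{}` give mypy nothing to infer value types from, so the hunk adds explicit `dict[str, list[str]]` and `dict[str, Any]` annotations. A minimal sketch with a placeholder category map:

```python
from typing import Any

CATEGORY_CODES = {"violent": "S1", "hate": "S10"}  # placeholder mapping

# Explicit annotations because the empty literals leave the value types open.
category_applied_input_types: dict[str, list[str]] = {key: [] for key in CATEGORY_CODES}
metadata: dict[str, Any] = {}
```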


@@ -271,7 +271,6 @@ exclude = [
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
"^llama_stack/providers/inline/post_training/common/validator\\.py$",
"^llama_stack/providers/inline/safety/code_scanner/",
"^llama_stack/providers/inline/safety/llama_guard/",
"^llama_stack/providers/inline/scoring/basic/",
"^llama_stack/providers/inline/scoring/braintrust/",
"^llama_stack/providers/inline/scoring/llm_as_judge/",