diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index 5c7f30aa7..e562d0a1d 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -17,7 +17,7 @@ from llama_stack.apis.safety import (
     SafetyViolation,
     ViolationLevel,
 )
-from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
+from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, ShieldStore
 from llama_stack.apis.shields import Shield
 from llama_stack.core.datatypes import Api
 from llama_stack.log import get_logger
@@ -136,6 +136,8 @@ logger = get_logger(name=__name__, category="safety")
 
 
 class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
+    shield_store: ShieldStore
+
     def __init__(self, config: LlamaGuardConfig, deps) -> None:
         self.config = config
         self.inference_api = deps[Api.inference]
@@ -160,7 +162,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         self,
         shield_id: str,
         messages: list[Message],
-        params: dict[str, Any] = None,
+        params: dict[str, Any] | None = None,
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
@@ -175,6 +177,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         # Use the inference API's model resolution instead of hardcoded mappings
         # This allows the shield to work with any registered model
         model_id = shield.provider_resource_id
+        if not model_id:
+            raise ValueError("Shield must have a provider_resource_id (model_id)")
 
         # Determine safety categories based on the model type
         # For known Llama Guard models, use specific categories
@@ -202,7 +206,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
             messages = [input]
 
         # convert to user messages format with role
-        messages = [UserMessage(content=m) for m in messages]
+        user_messages: list[Message] = [UserMessage(content=m) for m in messages]
 
         # Determine safety categories based on the model type
         # For known Llama Guard models, use specific categories
@@ -221,7 +225,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
             safety_categories=safety_categories,
         )
 
-        return await impl.run_moderation(messages)
+        return await impl.run_moderation(user_messages)
 
 
 class LlamaGuardShield:
@@ -271,7 +275,7 @@ class LlamaGuardShield:
 
         return final_categories
 
-    def validate_messages(self, messages: list[Message]) -> None:
+    def validate_messages(self, messages: list[Message]) -> list[Message]:
         if len(messages) == 0:
             raise ValueError("Messages must not be empty")
         if messages[0].role != Role.user.value:
@@ -283,7 +287,9 @@ class LlamaGuardShield:
         return messages
 
     async def run(self, messages: list[Message]) -> RunShieldResponse:
-        messages = self.validate_messages(messages)
+        validated_messages = self.validate_messages(messages)
+        if validated_messages is not None:
+            messages = validated_messages
 
         if self.model == CoreModelId.llama_guard_3_11b_vision.value:
             shield_input_message = self.build_vision_shield_input(messages)
@@ -296,8 +302,14 @@ class LlamaGuardShield:
             messages=[shield_input_message],
             stream=False,
         )
-        content = response.completion_message.content
-        content = content.strip()
+        if hasattr(response, "completion_message"):
+            content = response.completion_message.content
+            if isinstance(content, str):
+                content = content.strip()
+            else:
+                raise ValueError(f"Expected string content, got {type(content)}")
+        else:
+            raise ValueError("Response does not have completion_message attribute")
         return self.get_shield_response(content)
 
     def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
@@ -315,27 +327,51 @@ class LlamaGuardShield:
                     most_recent_img = m.content
                 conversation.append(m)
             elif isinstance(m.content, list):
-                content = []
+                text_content: list[TextContentItem] = []
                 for c in m.content:
-                    if isinstance(c, str) or isinstance(c, TextContentItem):
-                        content.append(c)
+                    if isinstance(c, str):
+                        text_content.append(TextContentItem(text=c))
+                    elif isinstance(c, TextContentItem):
+                        text_content.append(c)
                     elif isinstance(c, ImageContentItem):
                         if most_recent_img is None and m.role == Role.user.value:
                             most_recent_img = c
-                        content.append(c)
+                        # Note: we handle images separately for vision models
                     else:
                         raise ValueError(f"Unknown content type: {c}")
 
-                conversation.append(UserMessage(content=content))
+                if len(text_content) == 1:
+                    conversation.append(UserMessage(content=text_content[0]))
+                elif len(text_content) > 1:
+                    # Cast to the expected type
+                    from typing import cast
+
+                    content_list = cast(list[ImageContentItem | TextContentItem], text_content)
+                    conversation.append(UserMessage(content=content_list))
+                else:
+                    conversation.append(UserMessage(content=""))
             else:
                 raise ValueError(f"Unknown content type: {m.content}")
 
-        prompt = []
+        prompt: list[ImageContentItem | str] = []
         if most_recent_img is not None:
             prompt.append(most_recent_img)
         prompt.append(self.build_prompt(conversation[::-1]))
 
-        return UserMessage(content=prompt)
+        # Convert the prompt list to the expected content type
+        if len(prompt) == 1:
+            # Single item case
+            single_content = prompt[0]
+            return UserMessage(content=single_content)
+        else:
+            # Multiple items - convert strings to TextContentItem
+            mixed_content: list[ImageContentItem | TextContentItem] = []
+            for item in prompt:
+                if isinstance(item, str):
+                    mixed_content.append(TextContentItem(text=item))
+                else:
+                    mixed_content.append(item)  # ImageContentItem
+            return UserMessage(content=mixed_content)
 
     def build_prompt(self, messages: list[Message]) -> str:
         categories = self.get_safety_categories()
@@ -377,13 +413,42 @@ class LlamaGuardShield:
         # TODO: Add Image based support for OpenAI Moderations
         shield_input_message = self.build_text_shield_input(messages)
 
+        # Convert to OpenAI format - we need to import the conversion function
+        from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
+
+        openai_message = await convert_message_to_openai_dict_new(shield_input_message)
+        # Cast to expected type to satisfy mypy
+        from typing import cast
+
+        from llama_stack.apis.inference import (
+            OpenAIAssistantMessageParam,
+            OpenAIDeveloperMessageParam,
+            OpenAISystemMessageParam,
+            OpenAIToolMessageParam,
+            OpenAIUserMessageParam,
+        )
+
+        openai_message_param = (
+            OpenAIUserMessageParam
+            | OpenAISystemMessageParam
+            | OpenAIAssistantMessageParam
+            | OpenAIToolMessageParam
+            | OpenAIDeveloperMessageParam
+        )
+        openai_messages = [cast(openai_message_param, openai_message)]
         response = await self.inference_api.openai_chat_completion(
             model=self.model,
-            messages=[shield_input_message],
+            messages=openai_messages,
             stream=False,
         )
-        content = response.choices[0].message.content
-        content = content.strip()
+        if hasattr(response, "choices") and len(response.choices) > 0:
+            message_content = response.choices[0].message.content
+            if isinstance(message_content, str):
+                content = message_content.strip()
+            else:
+                raise ValueError(f"Expected string content, got {type(message_content)}")
+        else:
+            raise ValueError("Response does not have choices or choices is empty")
         return self.get_moderation_object(content)
 
     def create_moderation_object(self, model: str, unsafe_code: str | None = None) -> ModerationObject:
@@ -399,10 +464,10 @@ class LlamaGuardShield:
         # Set default values for safe case
         categories = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), False)
         category_scores = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), 1.0)
-        category_applied_input_types = {key: [] for key in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
+        category_applied_input_types: dict[str, list[str]] = {key: [] for key in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
         flagged = False
         user_message = None
-        metadata = {}
+        metadata: dict[str, Any] = {}
 
         # Handle unsafe case
         if unsafe_code:
diff --git a/pyproject.toml b/pyproject.toml
index 98bae47c5..630d20adb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -268,7 +268,6 @@ exclude = [
     "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
     "^llama_stack/providers/inline/post_training/common/validator\\.py$",
     "^llama_stack/providers/inline/safety/code_scanner/",
-    "^llama_stack/providers/inline/safety/llama_guard/",
    "^llama_stack/providers/inline/scoring/basic/",
     "^llama_stack/providers/inline/scoring/braintrust/",
     "^llama_stack/providers/inline/scoring/llm_as_judge/",
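
Reviewer note on the explicit cast in build_vision_shield_input: mypy treats list as invariant, so a list[TextContentItem] is not accepted where list[ImageContentItem | TextContentItem] is expected, even though every element fits the union. A minimal self-contained sketch of that pattern, using simplified stand-in classes rather than the real llama_stack types:

    from typing import cast

    class TextContentItem:  # simplified stand-in for the llama_stack type
        def __init__(self, text: str) -> None:
            self.text = text

    class ImageContentItem:  # simplified stand-in
        pass

    def as_mixed_content(parts: list[str]) -> list[ImageContentItem | TextContentItem]:
        text_items: list[TextContentItem] = [TextContentItem(text=p) for p in parts]
        # Because list is invariant, returning text_items directly would fail type
        # checking; the cast widens the element type without copying the list.
        return cast(list[ImageContentItem | TextContentItem], text_items)

This matters here because the pyproject.toml change removes the llama_guard provider from mypy's exclude list, so the file is now type-checked.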