Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-03 19:57:35 +00:00)

Commit 0bbc029597: Merge 0543f74c32 into fdb144f009
2 changed files with 85 additions and 21 deletions
@@ -17,7 +17,7 @@ from llama_stack.apis.safety import (
     SafetyViolation,
     ViolationLevel,
 )
-from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
+from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, ShieldStore
 from llama_stack.apis.shields import Shield
 from llama_stack.core.datatypes import Api
 from llama_stack.log import get_logger
@@ -136,6 +136,8 @@ logger = get_logger(name=__name__, category="safety")


 class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
+    shield_store: ShieldStore
+
     def __init__(self, config: LlamaGuardConfig, deps) -> None:
         self.config = config
         self.inference_api = deps[Api.inference]
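The new `shield_store: ShieldStore` line is a class-level annotation: the attribute is assigned by the surrounding framework after construction rather than in `__init__`, so declaring it on the class tells mypy the attribute exists and what type it has. A minimal sketch of the idea, using stand-in names rather than the real llama-stack classes:

from typing import Protocol


class ShieldStore(Protocol):
    async def get_shield(self, shield_id: str) -> object: ...


class SafetyImpl:
    # Declared at class scope: no value is created here, but type checkers now
    # know that instances will have a `shield_store` attribute of this type.
    shield_store: ShieldStore

    def __init__(self) -> None:
        # Intentionally not assigned in __init__; the hosting runtime is
        # expected to inject it before the shield is used.
        pass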
@@ -160,7 +162,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         self,
         shield_id: str,
         messages: list[Message],
-        params: dict[str, Any] = None,
+        params: dict[str, Any] | None = None,
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
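The `params` change is the standard fix for a `None` default: mypy does not treat `dict[str, Any] = None` as implicitly optional, so the annotation has to spell out `| None`. A small sketch of the pattern, using a hypothetical function rather than the provider API:

from typing import Any


def run_shield(shield_id: str, params: dict[str, Any] | None = None) -> str:
    # Normalize the optional mapping inside the body; a mutable default
    # argument would be shared across calls, so None is the safer default.
    params = params or {}
    return f"{shield_id}: {params}"


print(run_shield("llama-guard"))
print(run_shield("llama-guard", {"threshold": 0.5}))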
@@ -175,6 +177,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         # Use the inference API's model resolution instead of hardcoded mappings
         # This allows the shield to work with any registered model
         model_id = shield.provider_resource_id
+        if not model_id:
+            raise ValueError("Shield must have a provider_resource_id (model_id)")

         # Determine safety categories based on the model type
         # For known Llama Guard models, use specific categories
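Raising when `provider_resource_id` is missing doubles as type narrowing: the field is an optional string, and after the guard the checker can treat `model_id` as a plain `str`. A stand-alone sketch of that narrowing, with an illustrative helper name:

def resolve_model_id(provider_resource_id: str | None) -> str:
    if not provider_resource_id:
        raise ValueError("Shield must have a provider_resource_id (model_id)")
    # After the guard, type checkers narrow the value to str.
    return provider_resource_id


print(resolve_model_id("llama-guard-3-8b"))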
@@ -202,7 +206,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
             messages = [input]

         # convert to user messages format with role
-        messages = [UserMessage(content=m) for m in messages]
+        user_messages: list[Message] = [UserMessage(content=m) for m in messages]

         # Determine safety categories based on the model type
         # For known Llama Guard models, use specific categories
@@ -221,7 +225,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
             safety_categories=safety_categories,
         )

-        return await impl.run_moderation(messages)
+        return await impl.run_moderation(user_messages)


 class LlamaGuardShield:
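Reassigning `messages` (a list of strings) to a list of `UserMessage` objects would change the variable's type mid-function, which mypy flags; binding the converted values to a separately annotated name such as `user_messages` avoids that. A toy version of the same refactor:

def to_user_messages(raw: list[str]) -> list[dict[str, str]]:
    # A new, explicitly typed name instead of rebinding `raw` with a new type.
    user_messages: list[dict[str, str]] = [{"role": "user", "content": text} for text in raw]
    return user_messages


print(to_user_messages(["hello", "how do I make a safe request?"]))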
@@ -271,7 +275,7 @@ class LlamaGuardShield:

         return final_categories

-    def validate_messages(self, messages: list[Message]) -> None:
+    def validate_messages(self, messages: list[Message]) -> list[Message]:
         if len(messages) == 0:
             raise ValueError("Messages must not be empty")
         if messages[0].role != Role.user.value:
@@ -283,7 +287,9 @@ class LlamaGuardShield:
         return messages

     async def run(self, messages: list[Message]) -> RunShieldResponse:
-        messages = self.validate_messages(messages)
+        validated_messages = self.validate_messages(messages)
+        if validated_messages is not None:
+            messages = validated_messages

         if self.model == CoreModelId.llama_guard_3_11b_vision.value:
             shield_input_message = self.build_vision_shield_input(messages)
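`validate_messages` already ended with `return messages`, so the signature change simply makes the annotation match the behavior; callers can then use the returned list without the checker believing it is `None`. A compact sketch of the same idea in a hypothetical helper:

def validate_messages(messages: list[str]) -> list[str]:
    # Previously annotated as returning None even though it returned the list;
    # matching the annotation to the real return type lets callers use it.
    if not messages:
        raise ValueError("Messages must not be empty")
    return messages


validated = validate_messages(["hi"])
print(len(validated))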
@@ -296,8 +302,14 @@ class LlamaGuardShield:
             messages=[shield_input_message],
             stream=False,
         )
-        content = response.completion_message.content
-        content = content.strip()
+        if hasattr(response, "completion_message"):
+            content = response.completion_message.content
+            if isinstance(content, str):
+                content = content.strip()
+            else:
+                raise ValueError(f"Expected string content, got {type(content)}")
+        else:
+            raise ValueError("Response does not have completion_message attribute")
         return self.get_shield_response(content)

     def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
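`completion_message.content` is a union (plain text or structured content items), so `.strip()` is only valid once the value has been narrowed to `str`; the added `isinstance` check performs that narrowing for both mypy and the runtime. A reduced sketch of the guard:

def extract_text(content: str | list[object]) -> str:
    # isinstance() narrows the union: inside this branch the checker knows
    # `content` is a str, so calling str methods is allowed.
    if isinstance(content, str):
        return content.strip()
    raise ValueError(f"Expected string content, got {type(content)}")


print(extract_text("  unsafe\nS1  "))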
@@ -315,27 +327,51 @@ class LlamaGuardShield:
                     most_recent_img = m.content
                 conversation.append(m)
             elif isinstance(m.content, list):
-                content = []
+                text_content: list[TextContentItem] = []
                 for c in m.content:
-                    if isinstance(c, str) or isinstance(c, TextContentItem):
-                        content.append(c)
+                    if isinstance(c, str):
+                        text_content.append(TextContentItem(text=c))
+                    elif isinstance(c, TextContentItem):
+                        text_content.append(c)
                     elif isinstance(c, ImageContentItem):
                         if most_recent_img is None and m.role == Role.user.value:
                             most_recent_img = c
-                        content.append(c)
+                        # Note: we handle images separately for vision models
                     else:
                         raise ValueError(f"Unknown content type: {c}")

-                conversation.append(UserMessage(content=content))
+                if len(text_content) == 1:
+                    conversation.append(UserMessage(content=text_content[0]))
+                elif len(text_content) > 1:
+                    # Cast to the expected type
+                    from typing import cast
+
+                    content_list = cast(list[ImageContentItem | TextContentItem], text_content)
+                    conversation.append(UserMessage(content=content_list))
+                else:
+                    conversation.append(UserMessage(content=""))
             else:
                 raise ValueError(f"Unknown content type: {m.content}")

-        prompt = []
+        prompt: list[ImageContentItem | str] = []
         if most_recent_img is not None:
             prompt.append(most_recent_img)
         prompt.append(self.build_prompt(conversation[::-1]))

-        return UserMessage(content=prompt)
+        # Convert the prompt list to the expected content type
+        if len(prompt) == 1:
+            # Single item case
+            single_content = prompt[0]
+            return UserMessage(content=single_content)
+        else:
+            # Multiple items - convert strings to TextContentItem
+            mixed_content: list[ImageContentItem | TextContentItem] = []
+            for item in prompt:
+                if isinstance(item, str):
+                    mixed_content.append(TextContentItem(text=item))
+                else:
+                    mixed_content.append(item)  # ImageContentItem
+            return UserMessage(content=mixed_content)

     def build_prompt(self, messages: list[Message]) -> str:
         categories = self.get_safety_categories()
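The `text_content`/`cast` handling reflects list invariance: mypy will not accept a `list[TextContentItem]` where a `list[ImageContentItem | TextContentItem]` is expected, even though every element would fit, so the code either casts or builds the list under the wider annotation. A self-contained sketch of the same situation with dummy item types:

from dataclasses import dataclass
from typing import cast


@dataclass
class TextItem:
    text: str


@dataclass
class ImageItem:
    url: str


def send(content: list[ImageItem | TextItem]) -> None:
    print(content)


texts: list[TextItem] = [TextItem("hello"), TextItem("world")]

# send(texts) is rejected by mypy because list is invariant in its element type.
send(cast(list[ImageItem | TextItem], texts))  # option 1: cast for the checker

mixed: list[ImageItem | TextItem] = list(texts)  # option 2: widen the annotation up front
send(mixed)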
@@ -377,13 +413,42 @@ class LlamaGuardShield:
         # TODO: Add Image based support for OpenAI Moderations
         shield_input_message = self.build_text_shield_input(messages)

+        # Convert to OpenAI format - we need to import the conversion function
+        from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
+
+        openai_message = await convert_message_to_openai_dict_new(shield_input_message)
+        # Cast to expected type to satisfy mypy
+        from typing import cast
+
+        from llama_stack.apis.inference import (
+            OpenAIAssistantMessageParam,
+            OpenAIDeveloperMessageParam,
+            OpenAISystemMessageParam,
+            OpenAIToolMessageParam,
+            OpenAIUserMessageParam,
+        )
+
+        openai_message_param = (
+            OpenAIUserMessageParam
+            | OpenAISystemMessageParam
+            | OpenAIAssistantMessageParam
+            | OpenAIToolMessageParam
+            | OpenAIDeveloperMessageParam
+        )
+        openai_messages = [cast(openai_message_param, openai_message)]
         response = await self.inference_api.openai_chat_completion(
             model=self.model,
-            messages=[shield_input_message],
+            messages=openai_messages,
             stream=False,
         )
-        content = response.choices[0].message.content
-        content = content.strip()
+        if hasattr(response, "choices") and len(response.choices) > 0:
+            message_content = response.choices[0].message.content
+            if isinstance(message_content, str):
+                content = message_content.strip()
+            else:
+                raise ValueError(f"Expected string content, got {type(message_content)}")
+        else:
+            raise ValueError("Response does not have choices or choices is empty")
         return self.get_moderation_object(content)

     def create_moderation_object(self, model: str, unsafe_code: str | None = None) -> ModerationObject:
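The moderation path now converts the shield input into an OpenAI-style message and uses `typing.cast` to a union of the accepted message-param types so the chat-completion call type-checks. A simplified, self-contained sketch of that cast-to-a-union pattern, with stand-in classes rather than the real llama-stack types:

from dataclasses import dataclass
from typing import cast


@dataclass
class UserParam:
    content: str


@dataclass
class SystemParam:
    content: str


# A union alias standing in for the accepted OpenAI message-param types.
MessageParam = UserParam | SystemParam


def chat(messages: list[MessageParam]) -> None:
    print(messages)


converted: object = UserParam(content="please moderate this")  # e.g. output of a dict/message converter
# cast() has no runtime effect; it only tells the checker which type to assume.
chat([cast(MessageParam, converted)])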
@@ -399,10 +464,10 @@ class LlamaGuardShield:
         # Set default values for safe case
         categories = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), False)
         category_scores = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), 1.0)
-        category_applied_input_types = {key: [] for key in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
+        category_applied_input_types: dict[str, list[str]] = {key: [] for key in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
         flagged = False
         user_message = None
-        metadata = {}
+        metadata: dict[str, Any] = {}

         # Handle unsafe case
         if unsafe_code:
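The two annotations added here exist because mypy cannot infer a useful type from an empty literal: `{}` on its own gives the checker nothing to work with, and later inserts then fail to check. A minimal sketch:

from typing import Any

# Annotating empty literals pins down their key/value types for the checker.
category_applied_input_types: dict[str, list[str]] = {}
metadata: dict[str, Any] = {}

category_applied_input_types["some-category"] = ["text"]
metadata["violation_type"] = ["S1"]

print(category_applied_input_types, metadata)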
@@ -268,7 +268,6 @@ exclude = [
     "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
     "^llama_stack/providers/inline/post_training/common/validator\\.py$",
     "^llama_stack/providers/inline/safety/code_scanner/",
-    "^llama_stack/providers/inline/safety/llama_guard/",
     "^llama_stack/providers/inline/scoring/basic/",
     "^llama_stack/providers/inline/scoring/braintrust/",
     "^llama_stack/providers/inline/scoring/llm_as_judge/",