chore!: Safety api refactoring to use OpenAIMessageParam (#3796)

# What does this PR do?
Remove usage of the deprecated `Message` type from the Safety APIs.
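
For reviewers, a minimal sketch of the resulting call-site shape; the shield id, message text, and helper name here are illustrative, not part of this PR:

```python
from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.apis.safety import RunShieldResponse, Safety


async def check_input(safety: Safety) -> RunShieldResponse:
    # run_shield now takes OpenAI-compatible message params instead of the
    # deprecated Message types (e.g. UserMessage).
    return await safety.run_shield(
        shield_id="my-shield",  # illustrative shield id
        messages=[OpenAIUserMessageParam(content="Is this prompt safe?")],
    )
```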


## Test Plan
CI
slekkala1 2025-10-12 08:01:00 -07:00 committed by GitHub
parent 82cbcada39
commit 3bb6ef351b
37 changed files with 2455 additions and 1050 deletions


@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from codeshield.cs import CodeShieldScanResult
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@@ -53,7 +53,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
async def run_shield(
self,
shield_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
params: dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)


@@ -12,10 +12,9 @@ from typing import Any
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.inference import (
Inference,
Message,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIMessageParam,
OpenAIUserMessageParam,
UserMessage,
)
from llama_stack.apis.safety import (
RunShieldResponse,
@@ -165,7 +164,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
async def run_shield(
self,
shield_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
params: dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
@@ -175,8 +174,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
if len(messages) > 0 and messages[0].role != "user":
messages[0] = OpenAIUserMessageParam(content=messages[0].content)
# Use the inference API's model resolution instead of hardcoded mappings
# This allows the shield to work with any registered model
@@ -208,7 +207,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
messages = [input]
# convert to user messages format with role
messages = [UserMessage(content=m) for m in messages]
messages = [OpenAIUserMessageParam(content=m) for m in messages]
# Determine safety categories based on the model type
# For known Llama Guard models, use specific categories
@@ -277,7 +276,7 @@ class LlamaGuardShield:
return final_categories
def validate_messages(self, messages: list[Message]) -> None:
def validate_messages(self, messages: list[OpenAIMessageParam]) -> list[OpenAIMessageParam]:
if len(messages) == 0:
raise ValueError("Messages must not be empty")
if messages[0].role != Role.user.value:
@@ -288,7 +287,7 @@ class LlamaGuardShield:
return messages
async def run(self, messages: list[Message]) -> RunShieldResponse:
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
messages = self.validate_messages(messages)
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
@@ -307,10 +306,10 @@ class LlamaGuardShield:
content = content.strip()
return self.get_shield_response(content)
def build_text_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
return OpenAIUserMessageParam(role="user", content=self.build_prompt(messages))
def build_text_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
return OpenAIUserMessageParam(content=self.build_prompt(messages))
def build_vision_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
def build_vision_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
conversation = []
most_recent_img = None
@@ -333,7 +332,7 @@ class LlamaGuardShield:
else:
raise ValueError(f"Unknown content type: {c}")
conversation.append(UserMessage(content=content))
conversation.append(OpenAIUserMessageParam(content=content))
else:
raise ValueError(f"Unknown content type: {m.content}")
@@ -342,9 +341,9 @@ class LlamaGuardShield:
prompt.append(most_recent_img)
prompt.append(self.build_prompt(conversation[::-1]))
return OpenAIUserMessageParam(role="user", content=prompt)
return OpenAIUserMessageParam(content=prompt)
def build_prompt(self, messages: list[Message]) -> str:
def build_prompt(self, messages: list[OpenAIMessageParam]) -> str:
categories = self.get_safety_categories()
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
@@ -377,7 +376,7 @@ class LlamaGuardShield:
raise ValueError(f"Unexpected response: {response}")
async def run_moderation(self, messages: list[Message]) -> ModerationObject:
async def run_moderation(self, messages: list[OpenAIMessageParam]) -> ModerationObject:
if not messages:
return self.create_moderation_object(self.model)
@@ -388,6 +387,7 @@ class LlamaGuardShield:
model=self.model,
messages=[shield_input_message],
stream=False,
temperature=0.0, # default is 1, which is too high for safety
)
response = await self.inference_api.openai_chat_completion(params)
content = response.choices[0].message.content
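
Condensed from the Llama Guard hunk above, a sketch of the first-turn normalization under the new types (the function name is mine; the logic mirrors the diff):

```python
from llama_stack.apis.inference import OpenAIMessageParam, OpenAIUserMessageParam


def ensure_first_turn_is_user(messages: list[OpenAIMessageParam]) -> list[OpenAIMessageParam]:
    # Llama Guard expects the conversation to open with a user turn; the
    # first message may come from a tool call, so re-wrap its content.
    messages = list(messages)
    if messages and messages[0].role != "user":
        messages[0] = OpenAIUserMessageParam(content=messages[0].content)
    return messages
```

The same hunk also pins `temperature=0.0` on the moderation completion, since the provider default of 1 adds unwanted randomness to what is effectively a classifier call.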


@@ -9,7 +9,7 @@ from typing import Any
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@@ -22,9 +22,7 @@ from llama_stack.apis.shields import Shield
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from .config import PromptGuardConfig, PromptGuardType
@@ -56,7 +54,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
async def run_shield(
self,
shield_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
params: dict[str, Any],
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
@@ -93,7 +91,7 @@ class PromptGuardShield:
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
async def run(self, messages: list[Message]) -> RunShieldResponse:
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
message = messages[-1]
text = interleaved_content_as_str(message.content)
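
A minimal sketch of how the PromptGuard path consumes the new params (the helper name is mine): only the most recent message is scored, and its content is flattened to plain text first:

```python
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str


def last_message_text(messages: list[OpenAIMessageParam]) -> str:
    # PromptGuard classifies only the latest message; structured content
    # parts are flattened into a single string before tokenization.
    return interleaved_content_as_str(messages[-1].content)
```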


@@ -7,7 +7,7 @@
import json
from typing import Any
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@@ -56,7 +56,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
pass
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] = None
self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] = None
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:

View file

@@ -8,12 +8,11 @@ from typing import Any
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import NVIDIASafetyConfig
@@ -44,7 +43,7 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
pass
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
) -> RunShieldResponse:
"""
Run a safety shield check against the provided messages.
@@ -118,7 +117,7 @@ class NeMoGuardrails:
response.raise_for_status()
return response.json()
async def run(self, messages: list[Message]) -> RunShieldResponse:
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
"""
Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
@@ -132,10 +131,9 @@ class NeMoGuardrails:
Raises:
requests.HTTPError: If the POST request fails.
"""
request_messages = [await convert_message_to_openai_dict_new(message) for message in messages]
request_data = {
"model": self.model,
"messages": request_messages,
"messages": [{"role": message.role, "content": message.content} for message in messages],
"temperature": self.temperature,
"top_p": 1,
"frequency_penalty": 0,


@@ -4,13 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any
import litellm
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@@ -21,7 +20,6 @@ from llama_stack.apis.shields import Shield
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import SambaNovaSafetyConfig
@@ -72,7 +70,7 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
pass
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
@@ -80,12 +78,8 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
shield_params = shield.params
logger.debug(f"run_shield::{shield_params}::messages={messages}")
content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
response = litellm.completion(
model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()
)
response = litellm.completion(model=shield.provider_resource_id, messages=messages, api_key=self._get_api_key())
shield_message = response.choices[0].message.content
if "unsafe" in shield_message.lower():


@@ -9,6 +9,7 @@ import base64
import io
import json
import re
from typing import Any
import httpx
from PIL import Image as PIL_Image
@@ -23,6 +24,9 @@ from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
Message,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIFile,
ResponseFormat,
ResponseFormatType,
SystemMessage,
@@ -74,14 +78,22 @@ def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessag
return formatter.decode_assistant_message_from_content(content, stop_reason)
def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
def interleaved_content_as_str(
content: Any,
sep: str = " ",
) -> str:
if content is None:
return ""
def _process(c) -> str:
if isinstance(c, str):
return c
elif isinstance(c, ImageContentItem):
return "<image>"
elif isinstance(c, TextContentItem):
elif isinstance(c, TextContentItem) or isinstance(c, OpenAIChatCompletionContentPartTextParam):
return c.text
elif isinstance(c, ImageContentItem) or isinstance(c, OpenAIChatCompletionContentPartImageParam):
return "<image>"
elif isinstance(c, OpenAIFile):
return "<file>"
else:
raise ValueError(f"Unsupported content type: {type(c)}")
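
Finally, a usage sketch of the widened `interleaved_content_as_str`, which now flattens OpenAI content parts alongside the legacy items (constructor defaults are assumed to match the diff):

```python
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import OpenAIChatCompletionContentPartTextParam
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str

# Plain strings, legacy TextContentItem, and OpenAI text parts all flatten to
# their text; images become "<image>", files "<file>", and None the empty string.
print(interleaved_content_as_str("plain text"))
print(interleaved_content_as_str(TextContentItem(text="a legacy item")))
print(interleaved_content_as_str(OpenAIChatCompletionContentPartTextParam(text="an OpenAI part")))
print(interleaved_content_as_str(None))
```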