chore!: Safety api refactoring to use OpenAIMessageParam (#3796)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.12) (push) Failing after 1s
Python Package Build Test / build (3.13) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 4s
Vector IO Integration Tests / test-matrix (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 4s
Unit Tests / unit-tests (3.13) (push) Failing after 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 13s
UI Tests / ui-tests (22) (push) Successful in 40s
Pre-commit / pre-commit (push) Successful in 1m28s

# What does this PR do?
Remove usage of deprecated `Message` from Safety apis


## Test Plan
CI
This commit is contained in:
slekkala1 2025-10-12 08:01:00 -07:00 committed by GitHub
parent 82cbcada39
commit 3bb6ef351b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 2455 additions and 1050 deletions

View file

@ -9,7 +9,7 @@ from typing import Any
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@ -22,9 +22,7 @@ from llama_stack.apis.shields import Shield
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from .config import PromptGuardConfig, PromptGuardType
@ -56,7 +54,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
async def run_shield(
self,
shield_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
params: dict[str, Any],
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
@ -93,7 +91,7 @@ class PromptGuardShield:
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
async def run(self, messages: list[Message]) -> RunShieldResponse:
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
message = messages[-1]
text = interleaved_content_as_str(message.content)