chore!: Safety api refactoring to use OpenAIMessageParam (#3796)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.12) (push) Failing after 1s
Python Package Build Test / build (3.13) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 4s
Vector IO Integration Tests / test-matrix (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 4s
Unit Tests / unit-tests (3.13) (push) Failing after 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 13s
UI Tests / ui-tests (22) (push) Successful in 40s
Pre-commit / pre-commit (push) Successful in 1m28s

# What does this PR do?
Remove usage of deprecated `Message` from Safety apis


## Test Plan
CI
This commit is contained in:
slekkala1 2025-10-12 08:01:00 -07:00 committed by GitHub
parent 82cbcada39
commit 3bb6ef351b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 2455 additions and 1050 deletions

View file

@ -8,12 +8,11 @@ from typing import Any
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import NVIDIASafetyConfig
@ -44,7 +43,7 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
pass
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
) -> RunShieldResponse:
"""
Run a safety shield check against the provided messages.
@ -118,7 +117,7 @@ class NeMoGuardrails:
response.raise_for_status()
return response.json()
async def run(self, messages: list[Message]) -> RunShieldResponse:
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
"""
Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
@ -132,10 +131,9 @@ class NeMoGuardrails:
Raises:
requests.HTTPError: If the POST request fails.
"""
request_messages = [await convert_message_to_openai_dict_new(message) for message in messages]
request_data = {
"model": self.model,
"messages": request_messages,
"messages": [{"role": message.role, "content": message.content} for message in messages],
"temperature": self.temperature,
"top_p": 1,
"frequency_penalty": 0,