chore!: Safety api refactoring to use OpenAIMessageParam (#3796)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.12) (push) Failing after 1s
Python Package Build Test / build (3.13) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 4s
Vector IO Integration Tests / test-matrix (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 4s
Unit Tests / unit-tests (3.13) (push) Failing after 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 13s
UI Tests / ui-tests (22) (push) Successful in 40s
Pre-commit / pre-commit (push) Successful in 1m28s

# What does this PR do?
Remove usage of deprecated `Message` from Safety apis


## Test Plan
CI
This commit is contained in:
slekkala1 2025-10-12 08:01:00 -07:00 committed by GitHub
parent 82cbcada39
commit 3bb6ef351b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 2455 additions and 1050 deletions

View file

@ -4,13 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any
import litellm
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@ -21,7 +20,6 @@ from llama_stack.apis.shields import Shield
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import SambaNovaSafetyConfig
@ -72,7 +70,7 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
pass
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
@ -80,12 +78,8 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
shield_params = shield.params
logger.debug(f"run_shield::{shield_params}::messages={messages}")
content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
response = litellm.completion(
model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()
)
response = litellm.completion(model=shield.provider_resource_id, messages=messages, api_key=self._get_api_key())
shield_message = response.choices[0].message.content
if "unsafe" in shield_message.lower():