fix: Fix messages format in NVIDIA safety check request body

This commit is contained in:
Jash Gulabrai 2025-04-30 10:20:15 -04:00
parent 653e8526ec
commit 5fcf20d934
2 changed files with 9 additions and 10 deletions

View file

@@ -12,8 +12,8 @@ import requests
from llama_stack.apis.inference import Message from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import NVIDIASafetyConfig from .config import NVIDIASafetyConfig
@@ -28,7 +28,6 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
Args: Args:
config (NVIDIASafetyConfig): The configuration containing the guardrails service URL and config ID. config (NVIDIASafetyConfig): The configuration containing the guardrails service URL and config ID.
""" """
print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
self.config = config self.config = config
async def initialize(self) -> None: async def initialize(self) -> None:
@@ -127,9 +126,10 @@ class NeMoGuardrails:
Raises: Raises:
requests.HTTPError: If the POST request fails. requests.HTTPError: If the POST request fails.
""" """
request_messages = [await convert_message_to_openai_dict_new(message) for message in messages]
request_data = { request_data = {
"model": self.model, "model": self.model,
"messages": convert_pydantic_to_json_value(messages), "messages": request_messages,
"temperature": self.temperature, "temperature": self.temperature,
"top_p": 1, "top_p": 1,
"frequency_penalty": 0, "frequency_penalty": 0,

View file

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json
import os import os
import unittest import unittest
from typing import Any from typing import Any
@@ -139,8 +138,8 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
data={ data={
"model": shield_id, "model": shield_id,
"messages": [ "messages": [
json.loads(messages[0].model_dump_json()), {"role": "user", "content": "Hello, how are you?"},
json.loads(messages[1].model_dump_json()), {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
], ],
"temperature": 1.0, "temperature": 1.0,
"top_p": 1, "top_p": 1,
@@ -193,8 +192,8 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
data={ data={
"model": shield_id, "model": shield_id,
"messages": [ "messages": [
json.loads(messages[0].model_dump_json()), {"role": "user", "content": "Hello, how are you?"},
json.loads(messages[1].model_dump_json()), {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
], ],
"temperature": 1.0, "temperature": 1.0,
"top_p": 1, "top_p": 1,
@@ -269,8 +268,8 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
data={ data={
"model": shield_id, "model": shield_id,
"messages": [ "messages": [
json.loads(messages[0].model_dump_json()), {"role": "user", "content": "Hello, how are you?"},
json.loads(messages[1].model_dump_json()), {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
], ],
"temperature": 1.0, "temperature": 1.0,
"top_p": 1, "top_p": 1,