This commit is contained in:
Swapna Lekkala 2025-10-12 06:43:43 -07:00
parent 67de6af0f0
commit bc5eeef6f3
2 changed files with 2 additions and 10 deletions

View file

@ -13,7 +13,6 @@ from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety,
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import NVIDIASafetyConfig from .config import NVIDIASafetyConfig
@ -132,10 +131,9 @@ class NeMoGuardrails:
Raises: Raises:
requests.HTTPError: If the POST request fails. requests.HTTPError: If the POST request fails.
""" """
request_messages = [await convert_message_to_openai_dict_new(message) for message in messages]
request_data = { request_data = {
"model": self.model, "model": self.model,
"messages": request_messages, "messages": messages,
"temperature": self.temperature, "temperature": self.temperature,
"top_p": 1, "top_p": 1,
"frequency_penalty": 0, "frequency_penalty": 0,

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json
from typing import Any from typing import Any
import litellm import litellm
@ -21,7 +20,6 @@ from llama_stack.apis.shields import Shield
from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import SambaNovaSafetyConfig from .config import SambaNovaSafetyConfig
@ -80,12 +78,8 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
shield_params = shield.params shield_params = shield.params
logger.debug(f"run_shield::{shield_params}::messages={messages}") logger.debug(f"run_shield::{shield_params}::messages={messages}")
content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
response = litellm.completion( response = litellm.completion(model=shield.provider_resource_id, messages=messages, api_key=self._get_api_key())
model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()
)
shield_message = response.choices[0].message.content shield_message = response.choices[0].message.content
if "unsafe" in shield_message.lower(): if "unsafe" in shield_message.lower():