Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-12 12:06:04 +00:00
clean up
This commit is contained in:
parent 67de6af0f0, commit bc5eeef6f3

2 changed files with 2 additions and 10 deletions
@@ -13,7 +13,6 @@ from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety,
 from llama_stack.apis.shields import Shield
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
-from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new

 from .config import NVIDIASafetyConfig

@@ -132,10 +131,9 @@ class NeMoGuardrails:
         Raises:
             requests.HTTPError: If the POST request fails.
         """
-        request_messages = [await convert_message_to_openai_dict_new(message) for message in messages]
         request_data = {
             "model": self.model,
-            "messages": request_messages,
+            "messages": messages,
             "temperature": self.temperature,
             "top_p": 1,
             "frequency_penalty": 0,
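The first changed file is the NVIDIA safety provider (it imports NVIDIASafetyConfig and defines a NeMoGuardrails class): the commit drops the now-unused convert_message_to_openai_dict_new import and sends the messages to the guardrails service unconverted. A minimal, self-contained sketch of the resulting request construction follows; only the request_data fields come from the diff, while the function shape and endpoint path are assumptions for illustration, not the provider's actual code.

```python
# Hypothetical minimal reproduction of the NeMo Guardrails request after
# this commit. The request_data fields mirror the diff; the wrapper
# function and the endpoint path are assumed for illustration.
import requests


def run_guardrails(base_url: str, model: str, messages: list[dict], temperature: float) -> dict:
    """POST a guardrails check; raises requests.HTTPError if the request fails."""
    request_data = {
        "model": model,
        # After this commit the raw messages are sent as-is, with no
        # convert_message_to_openai_dict_new pass over them.
        "messages": messages,
        "temperature": temperature,
        "top_p": 1,
        "frequency_penalty": 0,
    }
    response = requests.post(f"{base_url}/v1/guardrail/checks", json=request_data)  # path is an assumption
    response.raise_for_status()  # surfaces requests.HTTPError, as documented above
    return response.json()
```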
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import json
 from typing import Any

 import litellm
@@ -21,7 +20,6 @@ from llama_stack.apis.shields import Shield
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
-from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new

 from .config import SambaNovaSafetyConfig

@@ -80,12 +78,8 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide

         shield_params = shield.params
         logger.debug(f"run_shield::{shield_params}::messages={messages}")
-        content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
-        logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")

-        response = litellm.completion(
-            model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()
-        )
+        response = litellm.completion(model=shield.provider_resource_id, messages=messages, api_key=self._get_api_key())
         shield_message = response.choices[0].message.content

         if "unsafe" in shield_message.lower():
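The second changed file is the SambaNova safety adapter (it imports SambaNovaSafetyConfig): the same cleanup applies here, removing the unused json import, the content_messages conversion, and its debug log, and collapsing the litellm.completion call onto one line. Below is a minimal standalone sketch of the simplified shield check; the call shape mirrors the diff, while the wrapper function is hypothetical.

```python
# Hypothetical standalone sketch of the simplified SambaNova shield call
# after this commit; the wrapper function is an assumption for illustration.
import litellm


def is_unsafe(provider_resource_id: str, messages: list[dict], api_key: str) -> bool:
    """Run the shield model over the conversation and flag unsafe content."""
    # The raw messages now go straight to litellm.completion; the
    # intermediate convert_message_to_openai_dict_new pass is gone.
    response = litellm.completion(model=provider_resource_id, messages=messages, api_key=api_key)
    shield_message = response.choices[0].message.content
    return "unsafe" in shield_message.lower()
```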