From 23a6255795f33960c9b1d3b4746d80ab884d4365 Mon Sep 17 00:00:00 2001 From: Chantal D Gama Rose Date: Fri, 21 Feb 2025 07:26:18 +0000 Subject: [PATCH] fixed more pre-checks --- .../providers/remote/safety/nvidia/config.py | 4 ++-- .../providers/remote/safety/nvidia/nvidia.py | 22 ++++++++----------- llama_stack/templates/nvidia/nvidia.py | 2 +- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/llama_stack/providers/remote/safety/nvidia/config.py b/llama_stack/providers/remote/safety/nvidia/config.py index e688db1b9..187a16dc5 100644 --- a/llama_stack/providers/remote/safety/nvidia/config.py +++ b/llama_stack/providers/remote/safety/nvidia/config.py @@ -3,11 +3,11 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum import os +from enum import Enum from typing import Any, Dict, Optional -from pydantic import BaseModel, Field, SecretStr, field_validator +from pydantic import BaseModel, Field, field_validator from llama_stack.schema_utils import json_schema_type diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py index 11bdd14a2..ce9b2953b 100644 --- a/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -4,23 +4,19 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json import logging from typing import Any, Dict, List -from llama_stack.apis.inference import Message, UserMessage -from llama_stack.apis.safety import ( - RunShieldResponse, - Safety, - SafetyViolation, - ViolationLevel, -) -from llama_stack.apis.shields import Shield -from llama_stack.distribution.library_client import convert_pydantic_to_json_value -from llama_stack.models.llama.datatypes import CoreModelId -from llama_stack.providers.datatypes import ShieldsProtocolPrivate import requests +from llama_stack.apis.inference import Message +from llama_stack.apis.safety import (RunShieldResponse, Safety, + SafetyViolation, ViolationLevel) +from llama_stack.apis.shields import Shield +from llama_stack.distribution.library_client import \ + convert_pydantic_to_json_value +from llama_stack.providers.datatypes import ShieldsProtocolPrivate + from .config import NVIDIASafetyConfig logger = logging.getLogger(__name__) @@ -40,7 +36,7 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): async def register_shield(self, shield: Shield) -> None: if not shield.provider_resource_id: - raise ValueError(f"Shield model not provided. ") + raise ValueError("Shield model not provided.") async def run_shield( self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py index 035784641..6636978db 100644 --- a/llama_stack/templates/nvidia/nvidia.py +++ b/llama_stack/templates/nvidia/nvidia.py @@ -8,9 +8,9 @@ from pathlib import Path from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput from llama_stack.models.llama.sku_list import all_registered_models -from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.nvidia.models import _MODEL_ENTRIES +from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig from llama_stack.templates.template import DistributionTemplate, RunConfigSettings