fixed more pre-checks

This commit is contained in:
Chantal D Gama Rose 2025-02-21 07:26:18 +00:00
parent 66726241aa
commit 23a6255795
3 changed files with 12 additions and 16 deletions

View file

@ -3,11 +3,11 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
import os
from enum import Enum
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, SecretStr, field_validator
from pydantic import BaseModel, Field, field_validator
from llama_stack.schema_utils import json_schema_type

View file

@ -4,23 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
from typing import Any, Dict, List
from llama_stack.apis.inference import Message, UserMessage
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import (RunShieldResponse, Safety,
SafetyViolation, ViolationLevel)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.library_client import \
convert_pydantic_to_json_value
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import NVIDIASafetyConfig
logger = logging.getLogger(__name__)
@ -40,7 +36,7 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
async def register_shield(self, shield: Shield) -> None:
if not shield.provider_resource_id:
raise ValueError(f"Shield model not provided. ")
raise ValueError("Shield model not provided.")
async def run_shield(
self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None

View file

@ -8,9 +8,9 @@ from pathlib import Path
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.models import _MODEL_ENTRIES
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings