fixed more pre-checks

Chantal D Gama Rose 2025-02-21 07:26:18 +00:00
parent 66726241aa
commit 23a6255795
3 changed files with 12 additions and 16 deletions


@@ -3,11 +3,11 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from enum import Enum
 import os
+from enum import Enum
 from typing import Any, Dict, Optional

-from pydantic import BaseModel, Field, SecretStr, field_validator
+from pydantic import BaseModel, Field, field_validator

 from llama_stack.schema_utils import json_schema_type


@@ -4,23 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import json
 import logging
 from typing import Any, Dict, List

-from llama_stack.apis.inference import Message, UserMessage
-from llama_stack.apis.safety import (
-    RunShieldResponse,
-    Safety,
-    SafetyViolation,
-    ViolationLevel,
-)
-from llama_stack.apis.shields import Shield
-from llama_stack.distribution.library_client import convert_pydantic_to_json_value
-from llama_stack.models.llama.datatypes import CoreModelId
-from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 import requests
+
+from llama_stack.apis.inference import Message
+from llama_stack.apis.safety import (RunShieldResponse, Safety,
+                                     SafetyViolation, ViolationLevel)
+from llama_stack.apis.shields import Shield
+from llama_stack.distribution.library_client import \
+    convert_pydantic_to_json_value
+from llama_stack.providers.datatypes import ShieldsProtocolPrivate

 from .config import NVIDIASafetyConfig

 logger = logging.getLogger(__name__)
@@ -40,7 +36,7 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
     async def register_shield(self, shield: Shield) -> None:
         if not shield.provider_resource_id:
-            raise ValueError(f"Shield model not provided. ")
+            raise ValueError("Shield model not provided.")

     async def run_shield(
         self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None


@@ -8,9 +8,9 @@ from pathlib import Path
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
 from llama_stack.models.llama.sku_list import all_registered_models
-from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.nvidia.models import _MODEL_ENTRIES
+from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings