setting up tests

Chantal D Gama Rose 2025-02-19 15:46:49 -08:00
parent b90ff9ca16
commit 685deeab8f
4 changed files with 24 additions and 11 deletions


@@ -49,8 +49,6 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
        self.shield = NeMoGuardrails(self.config, shield.provider_resource_id)
        return await self.shield.run(messages)


class NeMoGuardrails:
    def __init__(
@@ -107,4 +105,4 @@ class NeMoGuardrails:
                    metadata=metadata,
                )
            )
-        return RunShieldResponse(violation=None)
+        return RunShieldResponse(violation=None)
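The two hunks above touch the NVIDIA safety provider: the adapter builds a NeMoGuardrails client for the requested shield and delegates the check to it, and the client returns an empty RunShieldResponse when nothing is flagged. The standalone sketch below only illustrates that delegation pattern; the class and method signatures are simplified stand-ins, not the provider's actual definitions.

# Illustrative sketch only, not the provider's code: simplified stand-ins showing
# the delegation implied by the hunks above (adapter -> NeMoGuardrails -> response).
import asyncio
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class RunShieldResponse:
    # Stand-in for llama_stack.apis.safety.RunShieldResponse.
    violation: Optional[Any] = None


class NeMoGuardrails:
    def __init__(self, config: dict, guardrails_id: str) -> None:
        self.config = config
        self.guardrails_id = guardrails_id

    async def run(self, messages: list[dict]) -> RunShieldResponse:
        # A real client would call the NeMo Guardrails service here and map any
        # policy hit into a violation; "no violation" is the pass-through case.
        return RunShieldResponse(violation=None)


class NVIDIASafetyAdapter:
    def __init__(self, config: dict) -> None:
        self.config = config

    async def run_shield(self, shield_resource_id: str, messages: list[dict]) -> RunShieldResponse:
        # Mirrors the two visible lines: build a client for the requested shield,
        # then delegate the actual check to it.
        shield = NeMoGuardrails(self.config, shield_resource_id)
        return await shield.run(messages)


if __name__ == "__main__":
    adapter = NVIDIASafetyAdapter(config={})
    resp = asyncio.run(adapter.run_shield("my-shield", [{"role": "user", "content": "hi"}]))
    assert resp.violation is None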


@@ -51,11 +51,19 @@ DEFAULT_PROVIDER_COMBINATIONS = [
        id="remote",
        marks=pytest.mark.remote,
    ),
+    pytest.param(
+        {
+            "inference": "nvidia",
+            "safety": "nvidia",
+        },
+        id="nvidia",
+        marks=pytest.mark.nvidia,
+    ),
]


def pytest_configure(config):
-    for mark in ["meta_reference", "ollama", "together", "remote", "bedrock"]:
+    for mark in ["meta_reference", "ollama", "together", "remote", "bedrock", "nvidia"]:
        config.addinivalue_line(
            "markers",
            f"{mark}: marks tests as {mark} specific",

@@ -13,6 +13,7 @@ from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
+from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test

from ..conftest import ProviderFixture, remote_stack_fixture
@@ -94,6 +95,18 @@ def safety_bedrock() -> ProviderFixture:
        ],
    )


+@pytest.fixture(scope="session")
+def safety_nvidia() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="nvidia",
+                provider_type="remote::nvidia",
+                config=NVIDIASafetyConfig().model_dump(),
+            )
+        ],
+    )


SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
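The new safety_nvidia fixture wraps a single remote::nvidia Provider whose config comes from NVIDIASafetyConfig().model_dump(). The diff does not show that config class's fields, so the sketch below uses a hypothetical pydantic model (ExampleSafetyConfig) purely to illustrate what model_dump() contributes to the provider entry.

# Sketch with a hypothetical config class; NVIDIASafetyConfig's real fields are
# not visible in this diff, so ExampleSafetyConfig is a stand-in.
from pydantic import BaseModel


class ExampleSafetyConfig(BaseModel):
    # Hypothetical fields and defaults.
    guardrails_service_url: str = "http://localhost:7331"
    config_id: str = "self-check"


# model_dump() serializes the pydantic model into the plain dict that ends up
# under the provider's "config" key, mirroring the fixture above.
provider_entry = {
    "provider_id": "nvidia",
    "provider_type": "remote::nvidia",
    "config": ExampleSafetyConfig().model_dump(),
}

print(provider_entry)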


@@ -6,15 +6,9 @@
from pathlib import Path

-<<<<<<< Updated upstream
-from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput
-from llama_stack.models.llama.sku_list import all_registered_models
-=======
-from llama_models.sku_list import all_registered_models
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
->>>>>>> Stashed changes
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import _MODEL_ALIASES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
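With the stash conflict resolved, the template file imports ShieldInput and NVIDIASafetyConfig alongside the existing NVIDIA inference pieces. A hedged sketch of how those imports could be wired together follows; the shield id and the choice to instantiate NVIDIASafetyConfig with defaults are assumptions, not taken from this commit.

# Hedged sketch (not the template's actual code): one way the newly imported
# pieces could be combined into a safety provider plus a default shield entry.
from llama_stack.distribution.datatypes import Provider, ShieldInput
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig

# Remote NVIDIA safety provider, configured with defaults just like the
# safety_nvidia test fixture earlier in this commit.
safety_provider = Provider(
    provider_id="nvidia",
    provider_type="remote::nvidia",
    config=NVIDIASafetyConfig().model_dump(),
)

# A shield the run config could expose; the id here is a placeholder.
default_shield = ShieldInput(shield_id="nemo-guardrails-default")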