setting up tests

This commit is contained in:
Chantal D Gama Rose 2025-02-19 15:46:49 -08:00
parent b90ff9ca16
commit 685deeab8f
4 changed files with 24 additions and 11 deletions

View file

@ -50,8 +50,6 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
return await self.shield.run(messages) return await self.shield.run(messages)
class NeMoGuardrails: class NeMoGuardrails:
def __init__( def __init__(
self, self,

View file

@ -51,11 +51,19 @@ DEFAULT_PROVIDER_COMBINATIONS = [
id="remote", id="remote",
marks=pytest.mark.remote, marks=pytest.mark.remote,
), ),
pytest.param(
{
"inference": "nvidia",
"safety": "nvidia",
},
id="nvidia",
marks=pytest.mark.meta_reference,
),
] ]
def pytest_configure(config): def pytest_configure(config):
for mark in ["meta_reference", "ollama", "together", "remote", "bedrock"]: for mark in ["meta_reference", "ollama", "together", "remote", "bedrock", "nvidia"]:
config.addinivalue_line( config.addinivalue_line(
"markers", "markers",
f"{mark}: marks tests as {mark} specific", f"{mark}: marks tests as {mark} specific",

View file

@ -13,6 +13,7 @@ from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture from ..conftest import ProviderFixture, remote_stack_fixture
@ -94,6 +95,18 @@ def safety_bedrock() -> ProviderFixture:
], ],
) )
@pytest.fixture(scope="session")
def safety_nvidia() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIASafetyConfig().model_dump(),
)
],
)
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"] SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]

View file

@ -6,15 +6,9 @@
from pathlib import Path from pathlib import Path
<<<<<<< Updated upstream
from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput
from llama_stack.models.llama.sku_list import all_registered_models
=======
from llama_models.sku_list import all_registered_models
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
>>>>>>> Stashed changes
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import _MODEL_ALIASES from llama_stack.providers.remote.inference.nvidia.nvidia import _MODEL_ALIASES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings from llama_stack.templates.template import DistributionTemplate, RunConfigSettings