added nvidia as safety provider

This commit is contained in:
Chantal D Gama Rose 2025-02-25 08:16:49 +00:00
parent 07a992ef90
commit 0593408c19
14 changed files with 354 additions and 78 deletions

View file

@ -51,11 +51,19 @@ DEFAULT_PROVIDER_COMBINATIONS = [
id="remote",
marks=pytest.mark.remote,
),
pytest.param(
{
"inference": "nvidia",
"safety": "nvidia",
},
id="nvidia",
marks=pytest.mark.nvidia,
),
]
def pytest_configure(config):
for mark in ["meta_reference", "ollama", "together", "remote", "bedrock"]:
for mark in ["meta_reference", "ollama", "together", "remote", "bedrock", "nvidia"]:
config.addinivalue_line(
"markers",
f"{mark}: marks tests as {mark} specific",

View file

@ -13,6 +13,7 @@ from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
@ -95,7 +96,20 @@ def safety_bedrock() -> ProviderFixture:
)
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
@pytest.fixture(scope="session")
def safety_nvidia() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIASafetyConfig().model_dump(),
)
],
)
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote", "nvidia"]
@pytest_asyncio.fixture(scope="session")