added nvidia as safety provider

This commit is contained in:
Chantal D Gama Rose 2025-02-25 08:16:49 +00:00
parent 07a992ef90
commit 0593408c19
14 changed files with 354 additions and 78 deletions

View file

@ -13,6 +13,7 @@ from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
@ -95,7 +96,20 @@ def safety_bedrock() -> ProviderFixture:
)
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
@pytest.fixture(scope="session")
def safety_nvidia() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIASafetyConfig().model_dump(),
)
],
)
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote", "nvidia"]
@pytest_asyncio.fixture(scope="session")