diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py
index 6da5bc54a..c1918f0bc 100644
--- a/llama_stack/providers/remote/safety/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py
@@ -49,8 +49,6 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
         self.shield = NeMoGuardrails(self.config, shield.provider_resource_id)
         return await self.shield.run(messages)
 
-
-
 
 class NeMoGuardrails:
     def __init__(
@@ -107,4 +105,4 @@ class NeMoGuardrails:
                     metadata=metadata,
                 )
             )
-        return RunShieldResponse(violation=None)
\ No newline at end of file
+        return RunShieldResponse(violation=None)
diff --git a/llama_stack/providers/tests/safety/conftest.py b/llama_stack/providers/tests/safety/conftest.py
index 3e46f0d50..6452c90eb 100644
--- a/llama_stack/providers/tests/safety/conftest.py
+++ b/llama_stack/providers/tests/safety/conftest.py
@@ -51,11 +51,19 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         id="remote",
         marks=pytest.mark.remote,
     ),
+    pytest.param(
+        {
+            "inference": "nvidia",
+            "safety": "nvidia",
+        },
+        id="nvidia",
+        marks=pytest.mark.nvidia,
+    ),
 ]
 
 
 def pytest_configure(config):
-    for mark in ["meta_reference", "ollama", "together", "remote", "bedrock"]:
+    for mark in ["meta_reference", "ollama", "together", "remote", "bedrock", "nvidia"]:
         config.addinivalue_line(
             "markers",
             f"{mark}: marks tests as {mark} specific",
diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py
index a0c00ee7c..67d1ecac4 100644
--- a/llama_stack/providers/tests/safety/fixtures.py
+++ b/llama_stack/providers/tests/safety/fixtures.py
@@ -13,6 +13,7 @@ from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
 from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
 from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
+from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
 from llama_stack.providers.tests.resolver import construct_stack_for_test
 
 from ..conftest import ProviderFixture, remote_stack_fixture
@@ -94,6 +95,18 @@ def safety_bedrock() -> ProviderFixture:
         ],
     )
 
+@pytest.fixture(scope="session")
+def safety_nvidia() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="nvidia",
+                provider_type="remote::nvidia",
+                config=NVIDIASafetyConfig().model_dump(),
+            )
+        ],
+    )
+
 
-SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
+SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote", "nvidia"]
 
diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py
index f7100ac72..b6f85949b 100644
--- a/llama_stack/templates/nvidia/nvidia.py
+++ b/llama_stack/templates/nvidia/nvidia.py
@@ -6,15 +6,9 @@
 
 from pathlib import Path
 
-<<<<<<< Updated upstream
-from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput
-from llama_stack.models.llama.sku_list import all_registered_models
-=======
-from llama_models.sku_list import all_registered_models
-
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
+from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
->>>>>>> Stashed changes
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.nvidia.nvidia import _MODEL_ALIASES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings