From 98d0a3c4eee03f3aa4a07cd073be4787e704b264 Mon Sep 17 00:00:00 2001 From: Chantal D Gama Rose Date: Fri, 21 Feb 2025 00:38:53 +0000 Subject: [PATCH] pytest fixes --- llama_stack/providers/remote/safety/nvidia/config.py | 2 +- llama_stack/providers/remote/safety/nvidia/nvidia.py | 3 ++- llama_stack/providers/tests/safety/conftest.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/llama_stack/providers/remote/safety/nvidia/config.py b/llama_stack/providers/remote/safety/nvidia/config.py index d98278c94..c3d7b609d 100644 --- a/llama_stack/providers/remote/safety/nvidia/config.py +++ b/llama_stack/providers/remote/safety/nvidia/config.py @@ -36,7 +36,7 @@ class NVIDIASafetyConfig(BaseModel): variable to set the api_key. Please do not put your API key in code. """ guardrails_service_url: str = Field( - default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://0.0.0.0:7331"), + default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "http://0.0.0.0:7331"), description="The url for accessing the guardrails service", ) config_id: Optional[str] = Field( diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py index 9b5d051dd..39379c2eb 100644 --- a/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -59,7 +59,8 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): shield = await self.shield_store.get_shield(shield_id) if not shield: raise ValueError(f"Shield {shield_id} not found") - self.shield = NeMoGuardrails(self.config, SHIELD_IDS_TO_MODEL_MAPPING[shield.shield_id]) + # print(shield.provider_shield_id) + self.shield = NeMoGuardrails(self.config, shield.shield_id) return await self.shield.run(messages) diff --git a/llama_stack/providers/tests/safety/conftest.py b/llama_stack/providers/tests/safety/conftest.py index 6452c90eb..67ca23cb6 100644 --- a/llama_stack/providers/tests/safety/conftest.py +++ b/llama_stack/providers/tests/safety/conftest.py @@ -57,7 +57,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ "safety": "nvidia", }, id="nvidia", - marks=pytest.mark.meta_reference, + marks=pytest.mark.nvidia, ), ]