pytest fixes

This commit is contained in:
Chantal D Gama Rose 2025-02-21 00:38:53 +00:00
parent 3b0a4d0f5e
commit 98d0a3c4ee
3 changed files with 4 additions and 3 deletions

View file

@ -36,7 +36,7 @@ class NVIDIASafetyConfig(BaseModel):
variable to set the api_key. Please do not put your API key in code.
"""
guardrails_service_url: str = Field(
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://0.0.0.0:7331"),
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "http://0.0.0.0:7331"),
description="The url for accessing the guardrails service",
)
config_id: Optional[str] = Field(

View file

@ -59,7 +59,8 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
self.shield = NeMoGuardrails(self.config, SHIELD_IDS_TO_MODEL_MAPPING[shield.shield_id])
# print(shield.provider_shield_id)
self.shield = NeMoGuardrails(self.config, shield.shield_id)
return await self.shield.run(messages)

View file

@ -57,7 +57,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"safety": "nvidia",
},
id="nvidia",
marks=pytest.mark.meta_reference,
marks=pytest.mark.nvidia,
),
]