mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00
pytest fixes
This commit is contained in:
parent
3b0a4d0f5e
commit
98d0a3c4ee
3 changed files with 4 additions and 3 deletions
|
@ -36,7 +36,7 @@ class NVIDIASafetyConfig(BaseModel):
|
|||
variable to set the api_key. Please do not put your API key in code.
|
||||
"""
|
||||
guardrails_service_url: str = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://0.0.0.0:7331"),
|
||||
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "http://0.0.0.0:7331"),
|
||||
description="The url for accessing the guardrails service",
|
||||
)
|
||||
config_id: Optional[str] = Field(
|
||||
|
|
|
@ -59,7 +59,8 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
raise ValueError(f"Shield {shield_id} not found")
|
||||
self.shield = NeMoGuardrails(self.config, SHIELD_IDS_TO_MODEL_MAPPING[shield.shield_id])
|
||||
# print(shield.provider_shield_id)
|
||||
self.shield = NeMoGuardrails(self.config, shield.shield_id)
|
||||
return await self.shield.run(messages)
|
||||
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
"safety": "nvidia",
|
||||
},
|
||||
id="nvidia",
|
||||
marks=pytest.mark.meta_reference,
|
||||
marks=pytest.mark.nvidia,
|
||||
),
|
||||
]
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue