mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-03 19:32:16 +00:00
added nvidia as safety provider
This commit is contained in:
parent
07a992ef90
commit
0593408c19
14 changed files with 354 additions and 78 deletions
|
|
@ -13,6 +13,7 @@ from llama_stack.distribution.datatypes import Api, Provider
|
|||
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
|
||||
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
|
||||
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
|
||||
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
|
||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
|
|
@ -95,7 +96,20 @@ def safety_bedrock() -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
|
||||
@pytest.fixture(scope="session")
|
||||
def safety_nvidia() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
config=NVIDIASafetyConfig().model_dump(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote", "nvidia"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue