mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 11:08:20 +00:00
Merge conflicts
This commit is contained in:
parent
dd1c0876f7
commit
1de5949e48
2 changed files with 0 additions and 241 deletions
|
@ -1,104 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from ..conftest import get_provider_fixture_overrides
|
||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||
from .fixtures import SAFETY_FIXTURES
|
||||
|
||||
DEFAULT_PROVIDER_COMBINATIONS = [
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "meta_reference",
|
||||
"safety": "llama_guard",
|
||||
},
|
||||
id="meta_reference",
|
||||
marks=pytest.mark.meta_reference,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "ollama",
|
||||
"safety": "llama_guard",
|
||||
},
|
||||
id="ollama",
|
||||
marks=pytest.mark.ollama,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "together",
|
||||
"safety": "llama_guard",
|
||||
},
|
||||
id="together",
|
||||
marks=pytest.mark.together,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "bedrock",
|
||||
"safety": "bedrock",
|
||||
},
|
||||
id="bedrock",
|
||||
marks=pytest.mark.bedrock,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "remote",
|
||||
"safety": "remote",
|
||||
},
|
||||
id="remote",
|
||||
marks=pytest.mark.remote,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "nvidia",
|
||||
"safety": "nvidia",
|
||||
},
|
||||
id="nvidia",
|
||||
marks=pytest.mark.nvidia,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
for mark in ["meta_reference", "ollama", "together", "remote", "bedrock", "nvidia"]:
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
f"{mark}: marks tests as {mark} specific",
|
||||
)
|
||||
|
||||
|
||||
SAFETY_SHIELD_PARAMS = [
|
||||
pytest.param("meta-llama/Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
|
||||
]
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
# We use this method to make sure we have built-in simple combos for safety tests
|
||||
# But a user can also pass in a custom combination via the CLI by doing
|
||||
# `--providers inference=together,safety=meta_reference`
|
||||
|
||||
if "safety_shield" in metafunc.fixturenames:
|
||||
shield_id = metafunc.config.getoption("--safety-shield")
|
||||
if shield_id:
|
||||
params = [pytest.param(shield_id, id="")]
|
||||
else:
|
||||
params = SAFETY_SHIELD_PARAMS
|
||||
for fixture in ["inference_model", "safety_shield"]:
|
||||
metafunc.parametrize(
|
||||
fixture,
|
||||
params,
|
||||
indirect=True,
|
||||
)
|
||||
|
||||
if "safety_stack" in metafunc.fixturenames:
|
||||
available_fixtures = {
|
||||
"inference": INFERENCE_FIXTURES,
|
||||
"safety": SAFETY_FIXTURES,
|
||||
}
|
||||
combinations = (
|
||||
get_provider_fixture_overrides(metafunc.config, available_fixtures) or DEFAULT_PROVIDER_COMBINATIONS
|
||||
)
|
||||
metafunc.parametrize("safety_stack", combinations, indirect=True)
|
|
@ -1,137 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.models import ModelInput
|
||||
from llama_stack.apis.shields import ShieldInput
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
|
||||
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
|
||||
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
|
||||
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
|
||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
from ..env import get_env_or_fail
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def safety_remote() -> ProviderFixture:
|
||||
return remote_stack_fixture()
|
||||
|
||||
|
||||
def safety_model_from_shield(shield_id):
|
||||
if shield_id in ("Bedrock", "CodeScanner", "CodeShield"):
|
||||
return None
|
||||
|
||||
return shield_id
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def safety_shield(request):
|
||||
if hasattr(request, "param"):
|
||||
shield_id = request.param
|
||||
else:
|
||||
shield_id = request.config.getoption("--safety-shield", None)
|
||||
|
||||
if shield_id == "bedrock":
|
||||
shield_id = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
|
||||
params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
|
||||
else:
|
||||
params = {}
|
||||
|
||||
if not shield_id:
|
||||
return None
|
||||
|
||||
return ShieldInput(
|
||||
shield_id=shield_id,
|
||||
params=params,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def safety_llama_guard() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="llama-guard",
|
||||
provider_type="inline::llama-guard",
|
||||
config=LlamaGuardConfig().model_dump(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# TODO: this is not tested yet; we would need to configure the run_shield() test
|
||||
# and parametrize it with the "prompt" for testing depending on the safety fixture
|
||||
# we are using.
|
||||
@pytest.fixture(scope="session")
|
||||
def safety_prompt_guard() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="prompt-guard",
|
||||
provider_type="inline::prompt-guard",
|
||||
config=PromptGuardConfig().model_dump(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def safety_bedrock() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="bedrock",
|
||||
provider_type="remote::bedrock",
|
||||
config=BedrockSafetyConfig().model_dump(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def safety_nvidia() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
config=NVIDIASafetyConfig().model_dump(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote", "nvidia"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def safety_stack(inference_model, safety_shield, request):
|
||||
# We need an inference + safety fixture to test safety
|
||||
fixture_dict = request.param
|
||||
|
||||
providers = {}
|
||||
provider_data = {}
|
||||
for key in ["inference", "safety"]:
|
||||
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
||||
providers[key] = fixture.providers
|
||||
if fixture.provider_data:
|
||||
provider_data.update(fixture.provider_data)
|
||||
|
||||
test_stack = await construct_stack_for_test(
|
||||
[Api.safety, Api.shields, Api.inference],
|
||||
providers,
|
||||
provider_data,
|
||||
models=[ModelInput(model_id=inference_model)],
|
||||
shields=[safety_shield],
|
||||
)
|
||||
|
||||
shield = await test_stack.impls[Api.shields].get_shield(safety_shield.shield_id)
|
||||
return test_stack.impls[Api.safety], test_stack.impls[Api.shields], shield
|
Loading…
Add table
Add a link
Reference in a new issue