mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-10 04:08:31 +00:00
Merge conflicts
This commit is contained in:
parent
dd1c0876f7
commit
1de5949e48
2 changed files with 0 additions and 241 deletions
|
@ -1,104 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from ..conftest import get_provider_fixture_overrides
|
|
||||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
|
||||||
from .fixtures import SAFETY_FIXTURES
|
|
||||||
|
|
||||||
# (inference provider, safety provider, param id — also used as the marker name)
_DEFAULT_COMBOS = [
    ("meta_reference", "llama_guard", "meta_reference"),
    ("ollama", "llama_guard", "ollama"),
    ("together", "llama_guard", "together"),
    ("bedrock", "bedrock", "bedrock"),
    ("remote", "remote", "remote"),
    ("nvidia", "nvidia", "nvidia"),
]

# Built-in inference + safety provider combinations for the safety test suite.
DEFAULT_PROVIDER_COMBINATIONS = [
    pytest.param(
        {"inference": inference, "safety": safety},
        id=combo_id,
        marks=getattr(pytest.mark, combo_id),
    )
    for inference, safety, combo_id in _DEFAULT_COMBOS
]
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure(config):
    """Register the provider-specific pytest markers used by the safety tests."""
    provider_marks = ("meta_reference", "ollama", "together", "remote", "bedrock", "nvidia")
    for provider in provider_marks:
        config.addinivalue_line("markers", f"{provider}: marks tests as {provider} specific")
|
|
||||||
|
|
||||||
|
|
||||||
# Default shield(s) to parametrize over when no --safety-shield CLI override is given.
SAFETY_SHIELD_PARAMS = [
    pytest.param("meta-llama/Llama-Guard-3-1B", id="guard_1b", marks=pytest.mark.guard_1b),
]
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
    """Parametrize safety tests over shields and provider combinations.

    Built-in simple combos are supplied by default, but a user can pass a custom
    combination via the CLI, e.g. `--providers inference=together,safety=meta_reference`.
    """
    if "safety_shield" in metafunc.fixturenames:
        cli_shield = metafunc.config.getoption("--safety-shield")
        # A CLI-selected shield replaces the built-in parametrization entirely.
        shield_params = [pytest.param(cli_shield, id="")] if cli_shield else SAFETY_SHIELD_PARAMS
        for fixture_name in ("inference_model", "safety_shield"):
            metafunc.parametrize(fixture_name, shield_params, indirect=True)

    if "safety_stack" in metafunc.fixturenames:
        available_fixtures = {
            "inference": INFERENCE_FIXTURES,
            "safety": SAFETY_FIXTURES,
        }
        overrides = get_provider_fixture_overrides(metafunc.config, available_fixtures)
        metafunc.parametrize("safety_stack", overrides or DEFAULT_PROVIDER_COMBINATIONS, indirect=True)
|
|
@ -1,137 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import pytest_asyncio
|
|
||||||
|
|
||||||
from llama_stack.apis.models import ModelInput
|
|
||||||
from llama_stack.apis.shields import ShieldInput
|
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
|
||||||
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
|
|
||||||
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
|
|
||||||
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
|
|
||||||
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
|
|
||||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
|
||||||
|
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
|
||||||
from ..env import get_env_or_fail
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
def safety_remote() -> ProviderFixture:
    """Safety fixture backed by the shared remote stack (see ``remote_stack_fixture``)."""
    return remote_stack_fixture()
|
|
||||||
|
|
||||||
|
|
||||||
# Shields that are standalone services and therefore have no backing model.
_MODELLESS_SHIELDS = frozenset({"Bedrock", "CodeScanner", "CodeShield"})


def safety_model_from_shield(shield_id):
    """Return the model id backing *shield_id*, or ``None`` for model-less shields."""
    return None if shield_id in _MODELLESS_SHIELDS else shield_id
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
def safety_shield(request):
    """Resolve the shield under test into a ShieldInput (or None when unset).

    The shield id comes from the parametrization (``request.param``) when present,
    otherwise from the ``--safety-shield`` CLI option. The special id ``bedrock``
    is expanded from BEDROCK_* environment variables.
    """
    try:
        shield_id = request.param
    except AttributeError:
        shield_id = request.config.getoption("--safety-shield", None)

    params = {}
    if shield_id == "bedrock":
        shield_id = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
        params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}

    if not shield_id:
        return None

    return ShieldInput(shield_id=shield_id, params=params)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
def safety_llama_guard() -> ProviderFixture:
    """Safety fixture for the inline Llama Guard provider with default config."""
    provider = Provider(
        provider_id="llama-guard",
        provider_type="inline::llama-guard",
        config=LlamaGuardConfig().model_dump(),
    )
    return ProviderFixture(providers=[provider])
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: this is not tested yet; we would need to configure the run_shield() test
# and parametrize it with the "prompt" for testing depending on the safety fixture
# we are using.
@pytest.fixture(scope="session")
def safety_prompt_guard() -> ProviderFixture:
    """Safety fixture for the inline Prompt Guard provider with default config."""
    provider = Provider(
        provider_id="prompt-guard",
        provider_type="inline::prompt-guard",
        config=PromptGuardConfig().model_dump(),
    )
    return ProviderFixture(providers=[provider])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
def safety_bedrock() -> ProviderFixture:
    """Safety fixture for the remote Bedrock provider with default config."""
    provider = Provider(
        provider_id="bedrock",
        provider_type="remote::bedrock",
        config=BedrockSafetyConfig().model_dump(),
    )
    return ProviderFixture(providers=[provider])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
def safety_nvidia() -> ProviderFixture:
    """Safety fixture for the remote NVIDIA provider with default config."""
    provider = Provider(
        provider_id="nvidia",
        provider_type="remote::nvidia",
        config=NVIDIASafetyConfig().model_dump(),
    )
    return ProviderFixture(providers=[provider])
|
|
||||||
|
|
||||||
|
|
||||||
# Names of the available safety provider fixtures; each "<name>" entry is
# resolved to a `safety_<name>` fixture defined in this module.
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote", "nvidia"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
async def safety_stack(inference_model, safety_shield, request):
    """Build a test stack with inference + safety providers and resolve the shield.

    Returns a ``(safety_impl, shields_impl, shield)`` triple. Safety tests need an
    inference fixture alongside the safety fixture, so both are assembled here from
    the provider combination in ``request.param``.
    """
    selections = request.param

    providers = {}
    provider_data = {}
    for api in ("inference", "safety"):
        fixture = request.getfixturevalue(f"{api}_{selections[api]}")
        providers[api] = fixture.providers
        if fixture.provider_data:
            provider_data.update(fixture.provider_data)

    test_stack = await construct_stack_for_test(
        [Api.safety, Api.shields, Api.inference],
        providers,
        provider_data,
        models=[ModelInput(model_id=inference_model)],
        shields=[safety_shield],
    )

    shields_impl = test_stack.impls[Api.shields]
    shield = await shields_impl.get_shield(safety_shield.shield_id)
    return test_stack.impls[Api.safety], shields_impl, shield
|
|
Loading…
Add table
Add a link
Reference in a new issue