add deprecation_error pointing meta-reference -> inline::llama-guard

This commit is contained in:
Ashwin Bharambe 2024-11-07 15:23:15 -08:00
parent fdaec91747
commit 984ba074e1
9 changed files with 84 additions and 74 deletions

View file

@ -16,7 +16,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"inference": "meta_reference",
"safety": "meta_reference",
"safety": "llama_guard",
},
id="meta_reference",
marks=pytest.mark.meta_reference,
@ -24,7 +24,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"inference": "ollama",
"safety": "meta_reference",
"safety": "llama_guard",
},
id="ollama",
marks=pytest.mark.ollama,
@ -32,7 +32,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"inference": "together",
"safety": "meta_reference",
"safety": "llama_guard",
},
id="together",
marks=pytest.mark.together,

View file

@ -10,12 +10,9 @@ import pytest_asyncio
from llama_stack.apis.shields import ShieldType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.meta_reference import (
LlamaGuardShieldConfig,
SafetyConfig,
)
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
from llama_stack.providers.tests.env import get_env_or_fail
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture, remote_stack_fixture
@ -34,17 +31,13 @@ def safety_model(request):
@pytest.fixture(scope="session")
def safety_meta_reference(safety_model) -> ProviderFixture:
def safety_llama_guard(safety_model) -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="meta-reference",
provider_type="meta-reference",
config=SafetyConfig(
llama_guard_shield=LlamaGuardShieldConfig(
model=safety_model,
),
).model_dump(),
provider_id="inline::llama-guard",
provider_type="inline::llama-guard",
config=LlamaGuardConfig(model=safety_model).model_dump(),
)
],
)
@ -63,7 +56,7 @@ def safety_bedrock() -> ProviderFixture:
)
SAFETY_FIXTURES = ["meta_reference", "bedrock", "remote"]
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
@pytest_asyncio.fixture(scope="session")