Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-31 16:01:46 +00:00
add deprecation_error pointing meta-reference -> inline::llama-guard
This commit is contained in: parent fdaec91747, commit 984ba074e1
9 changed files with 84 additions and 74 deletions

@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 import importlib
 import inspect
+import sys

 from typing import Any, Dict, List, Set

@@ -102,10 +103,14 @@ async def resolve_impls(
             )

             p = provider_registry[api][provider.provider_type]
-            if p.deprecation_warning:
+            if p.deprecation_error:
+                cprint(p.deprecation_error, "red", attrs=["bold"])
+                sys.exit(1)
+
+            elif p.deprecation_warning:
                 cprint(
                     f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
-                    "red",
+                    "yellow",
                     attrs=["bold"],
                 )
             p.deps__ = [a.value for a in p.api_dependencies]

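To make the new control flow concrete: a spec that sets deprecation_error now stops the stack at startup, while deprecation_warning only prints and continues. A minimal sketch of that behavior, using a hypothetical ToySpec in place of the real ProviderSpec (termcolor's cprint is assumed, as in the hunk above):

# Sketch only; ToySpec and check_deprecation are illustrative, not llama-stack code.
import sys
from dataclasses import dataclass
from typing import Optional

from termcolor import cprint


@dataclass
class ToySpec:
    provider_type: str
    deprecation_warning: Optional[str] = None
    deprecation_error: Optional[str] = None


def check_deprecation(spec: ToySpec, api: str) -> None:
    if spec.deprecation_error:
        # hard deprecation: print the error and stop the process
        cprint(spec.deprecation_error, "red", attrs=["bold"])
        sys.exit(1)
    elif spec.deprecation_warning:
        # soft deprecation: warn and keep resolving the provider
        cprint(
            f"Provider `{spec.provider_type}` for API `{api}` is deprecated "
            f"and will be removed in a future release: {spec.deprecation_warning}",
            "yellow",
            attrs=["bold"],
        )


check_deprecation(ToySpec("faiss", deprecation_warning="use `inline::faiss`"), "memory")
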
@@ -90,6 +90,10 @@ class ProviderSpec(BaseModel):
         default=None,
         description="If this provider is deprecated, specify the warning message here",
     )
+    deprecation_error: Optional[str] = Field(
+        default=None,
+        description="If this provider is deprecated and does NOT work, specify the error message here",
+    )

     # used internally by the resolver; this is a hack for now
     deps__: List[str] = Field(default_factory=list)

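A self-contained sketch of the field shape added here, using a toy pydantic model rather than the actual ProviderSpec; both deprecation fields default to None, so existing registry entries are unaffected:

# Toy model mirroring the two deprecation fields (illustrative only).
from typing import Optional

from pydantic import BaseModel, Field


class ToyProviderSpec(BaseModel):
    provider_type: str
    deprecation_warning: Optional[str] = Field(
        default=None,
        description="If this provider is deprecated, specify the warning message here",
    )
    deprecation_error: Optional[str] = Field(
        default=None,
        description="If this provider is deprecated and does NOT work, specify the error message here",
    )


ok = ToyProviderSpec(provider_type="inline::llama-guard")
dead = ToyProviderSpec(
    provider_type="meta-reference",
    deprecation_error="Please switch to `inline::llama-guard`.",
)
assert ok.deprecation_error is None and dead.deprecation_error is not None
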
@@ -19,52 +19,6 @@ from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from .config import LlamaGuardConfig


-class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
-    def __init__(self, config: LlamaGuardConfig, deps) -> None:
-        self.config = config
-        self.inference_api = deps[Api.inference]
-
-    async def initialize(self) -> None:
-        self.shield = LlamaGuardShield(
-            model=self.config.model,
-            inference_api=self.inference_api,
-            excluded_categories=self.config.excluded_categories,
-        )
-
-    async def shutdown(self) -> None:
-        pass
-
-    async def register_shield(self, shield: ShieldDef) -> None:
-        raise ValueError("Registering dynamic shields is not supported")
-
-    async def list_shields(self) -> List[ShieldDef]:
-        return [
-            ShieldDef(
-                identifier=ShieldType.llama_guard.value,
-                shield_type=ShieldType.llama_guard.value,
-                params={},
-            ),
-        ]
-
-    async def run_shield(
-        self,
-        identifier: str,
-        messages: List[Message],
-        params: Dict[str, Any] = None,
-    ) -> RunShieldResponse:
-        shield_def = await self.shield_store.get_shield(identifier)
-        if not shield_def:
-            raise ValueError(f"Unknown shield {identifier}")
-
-        messages = messages.copy()
-        # some shields like llama-guard require the first message to be a user message
-        # since this might be a tool call, first role might not be user
-        if len(messages) > 0 and messages[0].role != Role.user.value:
-            messages[0] = UserMessage(content=messages[0].content)
-
-        return await self.shield.run(messages)
-
-
 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"

 SAFE_RESPONSE = "safe"

@@ -158,6 +112,52 @@ PROMPT_TEMPLATE = Template(
 )


+class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
+    def __init__(self, config: LlamaGuardConfig, deps) -> None:
+        self.config = config
+        self.inference_api = deps[Api.inference]
+
+    async def initialize(self) -> None:
+        self.shield = LlamaGuardShield(
+            model=self.config.model,
+            inference_api=self.inference_api,
+            excluded_categories=self.config.excluded_categories,
+        )
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def register_shield(self, shield: ShieldDef) -> None:
+        raise ValueError("Registering dynamic shields is not supported")
+
+    async def list_shields(self) -> List[ShieldDef]:
+        return [
+            ShieldDef(
+                identifier=ShieldType.llama_guard.value,
+                shield_type=ShieldType.llama_guard.value,
+                params={},
+            ),
+        ]
+
+    async def run_shield(
+        self,
+        identifier: str,
+        messages: List[Message],
+        params: Dict[str, Any] = None,
+    ) -> RunShieldResponse:
+        shield_def = await self.shield_store.get_shield(identifier)
+        if not shield_def:
+            raise ValueError(f"Unknown shield {identifier}")
+
+        messages = messages.copy()
+        # some shields like llama-guard require the first message to be a user message
+        # since this might be a tool call, first role might not be user
+        if len(messages) > 0 and messages[0].role != Role.user.value:
+            messages[0] = UserMessage(content=messages[0].content)
+
+        return await self.shield.run(messages)
+
+
 class LlamaGuardShield:
     def __init__(
         self,

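One detail worth noting in the relocated run_shield body is the first-message normalization. A standalone sketch of just that step, with a toy message type standing in for the llama_stack message classes:

# Sketch of the "first message must be a user message" coercion used by run_shield.
from dataclasses import dataclass
from typing import List


@dataclass
class ToyMessage:
    role: str
    content: str


def normalize_for_llama_guard(messages: List[ToyMessage]) -> List[ToyMessage]:
    messages = messages.copy()
    # some shields like llama-guard require the first message to be a user
    # message; since the turn might start with a tool call, coerce it here
    if len(messages) > 0 and messages[0].role != "user":
        messages[0] = ToyMessage(role="user", content=messages[0].content)
    return messages


fixed = normalize_for_llama_guard(
    [ToyMessage("ipython", "tool output"), ToyMessage("user", "hi")]
)
assert fixed[0].role == "user"
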
@@ -38,11 +38,11 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
             module="llama_stack.providers.inline.memory.faiss",
             config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
-            deprecation_warning="Please use the `faiss` provider instead.",
+            deprecation_warning="Please use the `inline::faiss` provider instead.",
         ),
         InlineProviderSpec(
             api=Api.memory,
-            provider_type="faiss",
+            provider_type="inline::faiss",
             pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
             module="llama_stack.providers.inline.memory.faiss",
             config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",

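These renames follow the new scope::name convention for provider types (inline::faiss, inline::llama-guard, and so on). An illustrative helper, not part of llama-stack, that splits such an identifier:

# Hypothetical helper: split "inline::faiss" -> ("inline", "faiss").
def split_provider_type(provider_type: str) -> tuple:
    scope, sep, name = provider_type.partition("::")
    if not sep:
        # legacy, un-namespaced types such as "faiss"
        return ("", provider_type)
    return (scope, name)


assert split_provider_type("inline::faiss") == ("inline", "faiss")
assert split_provider_type("faiss") == ("", "faiss")
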
@@ -29,11 +29,18 @@ def available_providers() -> List[ProviderSpec]:
             api_dependencies=[
                 Api.inference,
             ],
-            deprecation_warning="Please use the `llama-guard` / `prompt-guard` / `code-scanner` providers instead.",
+            deprecation_error="""
+Provider `meta-reference` for API `safety` does not work with the latest Llama Stack.
+
+- if you are using Llama Guard v3, please use the `inline::llama-guard` provider instead.
+- if you are using Prompt Guard, please use the `inline::prompt-guard` provider instead.
+- if you are using Code Scanner, please use the `inline::code-scanner` provider instead.
+
+""",
         ),
         InlineProviderSpec(
             api=Api.safety,
-            provider_type="llama-guard",
+            provider_type="inline::llama-guard",
             pip_packages=[],
             module="llama_stack.providers.inline.safety.llama_guard",
             config_class="llama_stack.providers.inline.safety.llama_guard.LlamaGuardConfig",

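Because the meta-reference safety provider split into purpose-specific inline providers, the right replacement depends on which shield a deployment actually uses. A hypothetical migration helper (not shipped with llama-stack) reflecting the mapping spelled out in the error message above:

# Illustrative mapping from shield kind to the replacement provider type.
REPLACEMENTS = {
    "llama_guard": "inline::llama-guard",
    "prompt_guard": "inline::prompt-guard",
    "code_scanner": "inline::code-scanner",
}


def replacement_for_meta_reference_safety(shield_kind: str) -> str:
    try:
        return REPLACEMENTS[shield_kind]
    except KeyError:
        raise ValueError(f"No inline replacement known for shield kind `{shield_kind}`")


print(replacement_for_meta_reference_safety("llama_guard"))  # inline::llama-guard
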
@@ -43,7 +50,7 @@ def available_providers() -> List[ProviderSpec]:
         ),
         InlineProviderSpec(
             api=Api.safety,
-            provider_type="prompt-guard",
+            provider_type="inline::prompt-guard",
             pip_packages=[
                 "transformers",
                 "torch --index-url https://download.pytorch.org/whl/cpu",

@@ -56,7 +63,7 @@ def available_providers() -> List[ProviderSpec]:
         ),
         InlineProviderSpec(
             api=Api.safety,
-            provider_type="code-scanner",
+            provider_type="inline::code-scanner",
             pip_packages=[
                 "codeshield",
             ],

@@ -80,6 +80,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 continue

             llama_model = ollama_to_llama[r["model"]]
+            print(f"Found model {llama_model} in Ollama")
             ret.append(
                 Model(
                     identifier=llama_model,

@@ -18,7 +18,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "inference": "meta_reference",
-            "safety": "meta_reference",
+            "safety": "llama_guard",
             "memory": "meta_reference",
             "agents": "meta_reference",
         },

@@ -28,7 +28,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "inference": "ollama",
-            "safety": "meta_reference",
+            "safety": "llama_guard",
             "memory": "meta_reference",
             "agents": "meta_reference",
         },

@@ -38,7 +38,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "inference": "together",
-            "safety": "meta_reference",
+            "safety": "llama_guard",
             # make this work with Weaviate which is what the together distro supports
             "memory": "meta_reference",
             "agents": "meta_reference",

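The test matrices above now route safety through the llama_guard fixture by name. A generic, self-contained sketch of the pytest.param pattern these combinations rely on (the real conftest wires the dicts into per-API fixtures; the names here are illustrative):

# Generic example of provider-combination parametrization with ids and marks.
import pytest

PROVIDER_COMBINATIONS = [
    pytest.param(
        {"inference": "meta_reference", "safety": "llama_guard"},
        id="meta_reference",
        marks=pytest.mark.meta_reference,
    ),
    pytest.param(
        {"inference": "ollama", "safety": "llama_guard"},
        id="ollama",
        marks=pytest.mark.ollama,
    ),
]


@pytest.mark.parametrize("providers", PROVIDER_COMBINATIONS)
def test_selected_safety_provider(providers):
    # every combination in this commit now selects the llama_guard safety fixture
    assert providers["safety"] == "llama_guard"
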
@@ -16,7 +16,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "inference": "meta_reference",
-            "safety": "meta_reference",
+            "safety": "llama_guard",
         },
         id="meta_reference",
         marks=pytest.mark.meta_reference,

@@ -24,7 +24,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "inference": "ollama",
-            "safety": "meta_reference",
+            "safety": "llama_guard",
         },
         id="ollama",
         marks=pytest.mark.ollama,

@@ -32,7 +32,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "inference": "together",
-            "safety": "meta_reference",
+            "safety": "llama_guard",
         },
         id="together",
         marks=pytest.mark.together,

@@ -10,12 +10,9 @@ import pytest_asyncio
 from llama_stack.apis.shields import ShieldType

 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.safety.meta_reference import (
-    LlamaGuardShieldConfig,
-    SafetyConfig,
-)
+from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
 from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
-from llama_stack.providers.tests.env import get_env_or_fail
 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2

 from ..conftest import ProviderFixture, remote_stack_fixture

@@ -34,17 +31,13 @@ def safety_model(request):


 @pytest.fixture(scope="session")
-def safety_meta_reference(safety_model) -> ProviderFixture:
+def safety_llama_guard(safety_model) -> ProviderFixture:
     return ProviderFixture(
         providers=[
             Provider(
-                provider_id="meta-reference",
-                provider_type="meta-reference",
-                config=SafetyConfig(
-                    llama_guard_shield=LlamaGuardShieldConfig(
-                        model=safety_model,
-                    ),
-                ).model_dump(),
+                provider_id="inline::llama-guard",
+                provider_type="inline::llama-guard",
+                config=LlamaGuardConfig(model=safety_model).model_dump(),
             )
         ],
     )

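The fixture now builds a single LlamaGuardConfig and serializes it with pydantic's model_dump() before handing it to Provider. A minimal sketch of that serialization pattern, with a toy config class and a placeholder model id instead of the real LlamaGuardConfig and safety_model fixture:

# Toy config class; only the model_dump() pattern is the point here.
from typing import List

from pydantic import BaseModel


class ToyLlamaGuardConfig(BaseModel):
    model: str
    excluded_categories: List[str] = []


# Provider.config expects a plain dict, so the pydantic object is dumped first
config_dict = ToyLlamaGuardConfig(model="<safety model id>").model_dump()
assert config_dict == {"model": "<safety model id>", "excluded_categories": []}
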
@@ -63,7 +56,7 @@ def safety_bedrock() -> ProviderFixture:
     )


-SAFETY_FIXTURES = ["meta_reference", "bedrock", "remote"]
+SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]


 @pytest_asyncio.fixture(scope="session")