mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
add deprecation_error pointing meta-reference -> inline::llama-guard
This commit is contained in:
parent
fdaec91747
commit
984ba074e1
9 changed files with 84 additions and 74 deletions
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
import importlib
|
||||
import inspect
|
||||
import sys
|
||||
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
|
@ -102,10 +103,14 @@ async def resolve_impls(
|
|||
)
|
||||
|
||||
p = provider_registry[api][provider.provider_type]
|
||||
if p.deprecation_warning:
|
||||
if p.deprecation_error:
|
||||
cprint(p.deprecation_error, "red", attrs=["bold"])
|
||||
sys.exit(1)
|
||||
|
||||
elif p.deprecation_warning:
|
||||
cprint(
|
||||
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
||||
"red",
|
||||
"yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
p.deps__ = [a.value for a in p.api_dependencies]
|
||||
|
|
|
@ -90,6 +90,10 @@ class ProviderSpec(BaseModel):
|
|||
default=None,
|
||||
description="If this provider is deprecated, specify the warning message here",
|
||||
)
|
||||
deprecation_error: Optional[str] = Field(
|
||||
default=None,
|
||||
description="If this provider is deprecated and does NOT work, specify the error message here",
|
||||
)
|
||||
|
||||
# used internally by the resolver; this is a hack for now
|
||||
deps__: List[str] = Field(default_factory=list)
|
||||
|
|
|
@ -19,52 +19,6 @@ from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
|||
from .config import LlamaGuardConfig
|
||||
|
||||
|
||||
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: LlamaGuardConfig, deps) -> None:
|
||||
self.config = config
|
||||
self.inference_api = deps[Api.inference]
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.shield = LlamaGuardShield(
|
||||
model=self.config.model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=self.config.excluded_categories,
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
raise ValueError("Registering dynamic shields is not supported")
|
||||
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
return [
|
||||
ShieldDef(
|
||||
identifier=ShieldType.llama_guard.value,
|
||||
shield_type=ShieldType.llama_guard.value,
|
||||
params={},
|
||||
),
|
||||
]
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
identifier: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
shield_def = await self.shield_store.get_shield(identifier)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {identifier}")
|
||||
|
||||
messages = messages.copy()
|
||||
# some shields like llama-guard require the first message to be a user message
|
||||
# since this might be a tool call, first role might not be user
|
||||
if len(messages) > 0 and messages[0].role != Role.user.value:
|
||||
messages[0] = UserMessage(content=messages[0].content)
|
||||
|
||||
return await self.shield.run(messages)
|
||||
|
||||
|
||||
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||
|
||||
SAFE_RESPONSE = "safe"
|
||||
|
@ -158,6 +112,52 @@ PROMPT_TEMPLATE = Template(
|
|||
)
|
||||
|
||||
|
||||
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: LlamaGuardConfig, deps) -> None:
|
||||
self.config = config
|
||||
self.inference_api = deps[Api.inference]
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.shield = LlamaGuardShield(
|
||||
model=self.config.model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=self.config.excluded_categories,
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
raise ValueError("Registering dynamic shields is not supported")
|
||||
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
return [
|
||||
ShieldDef(
|
||||
identifier=ShieldType.llama_guard.value,
|
||||
shield_type=ShieldType.llama_guard.value,
|
||||
params={},
|
||||
),
|
||||
]
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
identifier: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
shield_def = await self.shield_store.get_shield(identifier)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {identifier}")
|
||||
|
||||
messages = messages.copy()
|
||||
# some shields like llama-guard require the first message to be a user message
|
||||
# since this might be a tool call, first role might not be user
|
||||
if len(messages) > 0 and messages[0].role != Role.user.value:
|
||||
messages[0] = UserMessage(content=messages[0].content)
|
||||
|
||||
return await self.shield.run(messages)
|
||||
|
||||
|
||||
class LlamaGuardShield:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -38,11 +38,11 @@ def available_providers() -> List[ProviderSpec]:
|
|||
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
|
||||
module="llama_stack.providers.inline.memory.faiss",
|
||||
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
|
||||
deprecation_warning="Please use the `faiss` provider instead.",
|
||||
deprecation_warning="Please use the `inline::faiss` provider instead.",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.memory,
|
||||
provider_type="faiss",
|
||||
provider_type="inline::faiss",
|
||||
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
|
||||
module="llama_stack.providers.inline.memory.faiss",
|
||||
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
|
||||
|
|
|
@ -29,11 +29,18 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api_dependencies=[
|
||||
Api.inference,
|
||||
],
|
||||
deprecation_warning="Please use the `llama-guard` / `prompt-guard` / `code-scanner` providers instead.",
|
||||
deprecation_error="""
|
||||
Provider `meta-reference` for API `safety` does not work with the latest Llama Stack.
|
||||
|
||||
- if you are using Llama Guard v3, please use the `inline::llama-guard` provider instead.
|
||||
- if you are using Prompt Guard, please use the `inline::prompt-guard` provider instead.
|
||||
- if you are using Code Scanner, please use the `inline::code-scanner` provider instead.
|
||||
|
||||
""",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="llama-guard",
|
||||
provider_type="inline::llama-guard",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.inline.safety.llama_guard",
|
||||
config_class="llama_stack.providers.inline.safety.llama_guard.LlamaGuardConfig",
|
||||
|
@ -43,7 +50,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="prompt-guard",
|
||||
provider_type="inline::prompt-guard",
|
||||
pip_packages=[
|
||||
"transformers",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu",
|
||||
|
@ -56,7 +63,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="code-scanner",
|
||||
provider_type="inline::code-scanner",
|
||||
pip_packages=[
|
||||
"codeshield",
|
||||
],
|
||||
|
|
|
@ -80,6 +80,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
continue
|
||||
|
||||
llama_model = ollama_to_llama[r["model"]]
|
||||
print(f"Found model {llama_model} in Ollama")
|
||||
ret.append(
|
||||
Model(
|
||||
identifier=llama_model,
|
||||
|
|
|
@ -18,7 +18,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"inference": "meta_reference",
|
||||
"safety": "meta_reference",
|
||||
"safety": "llama_guard",
|
||||
"memory": "meta_reference",
|
||||
"agents": "meta_reference",
|
||||
},
|
||||
|
@ -28,7 +28,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"inference": "ollama",
|
||||
"safety": "meta_reference",
|
||||
"safety": "llama_guard",
|
||||
"memory": "meta_reference",
|
||||
"agents": "meta_reference",
|
||||
},
|
||||
|
@ -38,7 +38,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"inference": "together",
|
||||
"safety": "meta_reference",
|
||||
"safety": "llama_guard",
|
||||
# make this work with Weaviate which is what the together distro supports
|
||||
"memory": "meta_reference",
|
||||
"agents": "meta_reference",
|
||||
|
|
|
@ -16,7 +16,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"inference": "meta_reference",
|
||||
"safety": "meta_reference",
|
||||
"safety": "llama_guard",
|
||||
},
|
||||
id="meta_reference",
|
||||
marks=pytest.mark.meta_reference,
|
||||
|
@ -24,7 +24,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"inference": "ollama",
|
||||
"safety": "meta_reference",
|
||||
"safety": "llama_guard",
|
||||
},
|
||||
id="ollama",
|
||||
marks=pytest.mark.ollama,
|
||||
|
@ -32,7 +32,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"inference": "together",
|
||||
"safety": "meta_reference",
|
||||
"safety": "llama_guard",
|
||||
},
|
||||
id="together",
|
||||
marks=pytest.mark.together,
|
||||
|
|
|
@ -10,12 +10,9 @@ import pytest_asyncio
|
|||
from llama_stack.apis.shields import ShieldType
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.inline.safety.meta_reference import (
|
||||
LlamaGuardShieldConfig,
|
||||
SafetyConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
|
||||
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
|
||||
from llama_stack.providers.tests.env import get_env_or_fail
|
||||
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
|
@ -34,17 +31,13 @@ def safety_model(request):
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def safety_meta_reference(safety_model) -> ProviderFixture:
|
||||
def safety_llama_guard(safety_model) -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="meta-reference",
|
||||
config=SafetyConfig(
|
||||
llama_guard_shield=LlamaGuardShieldConfig(
|
||||
model=safety_model,
|
||||
),
|
||||
).model_dump(),
|
||||
provider_id="inline::llama-guard",
|
||||
provider_type="inline::llama-guard",
|
||||
config=LlamaGuardConfig(model=safety_model).model_dump(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
@ -63,7 +56,7 @@ def safety_bedrock() -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
SAFETY_FIXTURES = ["meta_reference", "bedrock", "remote"]
|
||||
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue