add deprecation_error pointing meta-reference -> inline::llama-guard

This commit is contained in:
Ashwin Bharambe 2024-11-07 15:23:15 -08:00
parent fdaec91747
commit 984ba074e1
9 changed files with 84 additions and 74 deletions

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import importlib
import inspect
import sys
from typing import Any, Dict, List, Set
@ -102,10 +103,14 @@ async def resolve_impls(
)
p = provider_registry[api][provider.provider_type]
if p.deprecation_warning:
if p.deprecation_error:
cprint(p.deprecation_error, "red", attrs=["bold"])
sys.exit(1)
elif p.deprecation_warning:
cprint(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
"red",
"yellow",
attrs=["bold"],
)
p.deps__ = [a.value for a in p.api_dependencies]

View file

@ -90,6 +90,10 @@ class ProviderSpec(BaseModel):
default=None,
description="If this provider is deprecated, specify the warning message here",
)
deprecation_error: Optional[str] = Field(
default=None,
description="If this provider is deprecated and does NOT work, specify the error message here",
)
# used internally by the resolver; this is a hack for now
deps__: List[str] = Field(default_factory=list)

View file

@ -19,52 +19,6 @@ from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import LlamaGuardConfig
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
def __init__(self, config: LlamaGuardConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
async def initialize(self) -> None:
self.shield = LlamaGuardShield(
model=self.config.model,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
)
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: ShieldDef) -> None:
raise ValueError("Registering dynamic shields is not supported")
async def list_shields(self) -> List[ShieldDef]:
return [
ShieldDef(
identifier=ShieldType.llama_guard.value,
shield_type=ShieldType.llama_guard.value,
params={},
),
]
async def run_shield(
self,
identifier: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shield_def = await self.shield_store.get_shield(identifier)
if not shield_def:
raise ValueError(f"Unknown shield {identifier}")
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
return await self.shield.run(messages)
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
SAFE_RESPONSE = "safe"
@ -158,6 +112,52 @@ PROMPT_TEMPLATE = Template(
)
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
def __init__(self, config: LlamaGuardConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
async def initialize(self) -> None:
self.shield = LlamaGuardShield(
model=self.config.model,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
)
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: ShieldDef) -> None:
raise ValueError("Registering dynamic shields is not supported")
async def list_shields(self) -> List[ShieldDef]:
return [
ShieldDef(
identifier=ShieldType.llama_guard.value,
shield_type=ShieldType.llama_guard.value,
params={},
),
]
async def run_shield(
self,
identifier: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shield_def = await self.shield_store.get_shield(identifier)
if not shield_def:
raise ValueError(f"Unknown shield {identifier}")
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
return await self.shield.run(messages)
class LlamaGuardShield:
def __init__(
self,

View file

@ -38,11 +38,11 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
module="llama_stack.providers.inline.memory.faiss",
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
deprecation_warning="Please use the `faiss` provider instead.",
deprecation_warning="Please use the `inline::faiss` provider instead.",
),
InlineProviderSpec(
api=Api.memory,
provider_type="faiss",
provider_type="inline::faiss",
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
module="llama_stack.providers.inline.memory.faiss",
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",

View file

@ -29,11 +29,18 @@ def available_providers() -> List[ProviderSpec]:
api_dependencies=[
Api.inference,
],
deprecation_warning="Please use the `llama-guard` / `prompt-guard` / `code-scanner` providers instead.",
deprecation_error="""
Provider `meta-reference` for API `safety` does not work with the latest Llama Stack.
- if you are using Llama Guard v3, please use the `inline::llama-guard` provider instead.
- if you are using Prompt Guard, please use the `inline::prompt-guard` provider instead.
- if you are using Code Scanner, please use the `inline::code-scanner` provider instead.
""",
),
InlineProviderSpec(
api=Api.safety,
provider_type="llama-guard",
provider_type="inline::llama-guard",
pip_packages=[],
module="llama_stack.providers.inline.safety.llama_guard",
config_class="llama_stack.providers.inline.safety.llama_guard.LlamaGuardConfig",
@ -43,7 +50,7 @@ def available_providers() -> List[ProviderSpec]:
),
InlineProviderSpec(
api=Api.safety,
provider_type="prompt-guard",
provider_type="inline::prompt-guard",
pip_packages=[
"transformers",
"torch --index-url https://download.pytorch.org/whl/cpu",
@ -56,7 +63,7 @@ def available_providers() -> List[ProviderSpec]:
),
InlineProviderSpec(
api=Api.safety,
provider_type="code-scanner",
provider_type="inline::code-scanner",
pip_packages=[
"codeshield",
],

View file

@ -80,6 +80,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
continue
llama_model = ollama_to_llama[r["model"]]
print(f"Found model {llama_model} in Ollama")
ret.append(
Model(
identifier=llama_model,

View file

@ -18,7 +18,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"inference": "meta_reference",
"safety": "meta_reference",
"safety": "llama_guard",
"memory": "meta_reference",
"agents": "meta_reference",
},
@ -28,7 +28,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"inference": "ollama",
"safety": "meta_reference",
"safety": "llama_guard",
"memory": "meta_reference",
"agents": "meta_reference",
},
@ -38,7 +38,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"inference": "together",
"safety": "meta_reference",
"safety": "llama_guard",
# make this work with Weaviate which is what the together distro supports
"memory": "meta_reference",
"agents": "meta_reference",

View file

@ -16,7 +16,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"inference": "meta_reference",
"safety": "meta_reference",
"safety": "llama_guard",
},
id="meta_reference",
marks=pytest.mark.meta_reference,
@ -24,7 +24,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"inference": "ollama",
"safety": "meta_reference",
"safety": "llama_guard",
},
id="ollama",
marks=pytest.mark.ollama,
@ -32,7 +32,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"inference": "together",
"safety": "meta_reference",
"safety": "llama_guard",
},
id="together",
marks=pytest.mark.together,

View file

@ -10,12 +10,9 @@ import pytest_asyncio
from llama_stack.apis.shields import ShieldType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.meta_reference import (
LlamaGuardShieldConfig,
SafetyConfig,
)
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
from llama_stack.providers.tests.env import get_env_or_fail
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture, remote_stack_fixture
@ -34,17 +31,13 @@ def safety_model(request):
@pytest.fixture(scope="session")
def safety_meta_reference(safety_model) -> ProviderFixture:
def safety_llama_guard(safety_model) -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="meta-reference",
provider_type="meta-reference",
config=SafetyConfig(
llama_guard_shield=LlamaGuardShieldConfig(
model=safety_model,
),
).model_dump(),
provider_id="inline::llama-guard",
provider_type="inline::llama-guard",
config=LlamaGuardConfig(model=safety_model).model_dump(),
)
],
)
@ -63,7 +56,7 @@ def safety_bedrock() -> ProviderFixture:
)
SAFETY_FIXTURES = ["meta_reference", "bedrock", "remote"]
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
@pytest_asyncio.fixture(scope="session")