add deprecation_error pointing meta-reference -> inline::llama-guard

Author: Ashwin Bharambe, 2024-11-07 15:23:15 -08:00
Commit: 984ba074e1 (parent: fdaec91747)
9 changed files with 84 additions and 74 deletions


@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 import importlib
 import inspect
+import sys
 from typing import Any, Dict, List, Set
@@ -102,10 +103,14 @@ async def resolve_impls(
                 )
 
             p = provider_registry[api][provider.provider_type]
-            if p.deprecation_warning:
+            if p.deprecation_error:
+                cprint(p.deprecation_error, "red", attrs=["bold"])
+                sys.exit(1)
+
+            elif p.deprecation_warning:
                 cprint(
                     f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
-                    "red",
+                    "yellow",
                     attrs=["bold"],
                 )
             p.deps__ = [a.value for a in p.api_dependencies]
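
The new check hard-stops the server when a provider is marked with `deprecation_error`, while `deprecation_warning` keeps the old non-fatal behavior (now printed in yellow instead of red). A minimal standalone sketch of that branching, with `ProviderSpec` stubbed down to just the two fields involved (not the real class):

import sys
from typing import Optional

from termcolor import cprint  # same helper the resolver uses


class _SpecStub:
    """Stand-in for ProviderSpec with only the fields this check reads."""

    def __init__(
        self,
        deprecation_error: Optional[str] = None,
        deprecation_warning: Optional[str] = None,
    ) -> None:
        self.deprecation_error = deprecation_error
        self.deprecation_warning = deprecation_warning


def check_deprecation(p: _SpecStub) -> None:
    if p.deprecation_error:
        # provider no longer works at all: print the migration message and abort startup
        cprint(p.deprecation_error, "red", attrs=["bold"])
        sys.exit(1)
    elif p.deprecation_warning:
        # provider still works: warn and continue resolving
        cprint(p.deprecation_warning, "yellow", attrs=["bold"])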


@@ -90,6 +90,10 @@ class ProviderSpec(BaseModel):
         default=None,
         description="If this provider is deprecated, specify the warning message here",
     )
+    deprecation_error: Optional[str] = Field(
+        default=None,
+        description="If this provider is deprecated and does NOT work, specify the error message here",
+    )
 
     # used internally by the resolver; this is a hack for now
     deps__: List[str] = Field(default_factory=list)
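
For illustration, here is how the two fields differ in intent, using a trimmed-down pydantic model rather than the actual `ProviderSpec` (which carries many more fields such as `api`, `module`, and `config_class`):

from typing import Optional

from pydantic import BaseModel, Field


class MiniProviderSpec(BaseModel):
    """Illustrative subset of ProviderSpec: provider_type plus the two deprecation fields."""

    provider_type: str
    deprecation_warning: Optional[str] = Field(
        default=None,
        description="Provider still works; nudge users to migrate",
    )
    deprecation_error: Optional[str] = Field(
        default=None,
        description="Provider is known-broken; the resolver refuses to start with it",
    )


# a spec that only warns at startup vs. one that refuses to start
soft = MiniProviderSpec(
    provider_type="faiss",
    deprecation_warning="Please use the `inline::faiss` provider instead.",
)
hard = MiniProviderSpec(
    provider_type="meta-reference",
    deprecation_error="Please use the `inline::llama-guard` provider instead.",
)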


@@ -19,52 +19,6 @@ from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from .config import LlamaGuardConfig
 
 
-class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
-    def __init__(self, config: LlamaGuardConfig, deps) -> None:
-        self.config = config
-        self.inference_api = deps[Api.inference]
-
-    async def initialize(self) -> None:
-        self.shield = LlamaGuardShield(
-            model=self.config.model,
-            inference_api=self.inference_api,
-            excluded_categories=self.config.excluded_categories,
-        )
-
-    async def shutdown(self) -> None:
-        pass
-
-    async def register_shield(self, shield: ShieldDef) -> None:
-        raise ValueError("Registering dynamic shields is not supported")
-
-    async def list_shields(self) -> List[ShieldDef]:
-        return [
-            ShieldDef(
-                identifier=ShieldType.llama_guard.value,
-                shield_type=ShieldType.llama_guard.value,
-                params={},
-            ),
-        ]
-
-    async def run_shield(
-        self,
-        identifier: str,
-        messages: List[Message],
-        params: Dict[str, Any] = None,
-    ) -> RunShieldResponse:
-        shield_def = await self.shield_store.get_shield(identifier)
-        if not shield_def:
-            raise ValueError(f"Unknown shield {identifier}")
-
-        messages = messages.copy()
-        # some shields like llama-guard require the first message to be a user message
-        # since this might be a tool call, first role might not be user
-        if len(messages) > 0 and messages[0].role != Role.user.value:
-            messages[0] = UserMessage(content=messages[0].content)
-
-        return await self.shield.run(messages)
-
-
 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
 
 SAFE_RESPONSE = "safe"
@@ -158,6 +112,52 @@ PROMPT_TEMPLATE = Template(
 )
 
 
+class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
+    def __init__(self, config: LlamaGuardConfig, deps) -> None:
+        self.config = config
+        self.inference_api = deps[Api.inference]
+
+    async def initialize(self) -> None:
+        self.shield = LlamaGuardShield(
+            model=self.config.model,
+            inference_api=self.inference_api,
+            excluded_categories=self.config.excluded_categories,
+        )
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def register_shield(self, shield: ShieldDef) -> None:
+        raise ValueError("Registering dynamic shields is not supported")
+
+    async def list_shields(self) -> List[ShieldDef]:
+        return [
+            ShieldDef(
+                identifier=ShieldType.llama_guard.value,
+                shield_type=ShieldType.llama_guard.value,
+                params={},
+            ),
+        ]
+
+    async def run_shield(
+        self,
+        identifier: str,
+        messages: List[Message],
+        params: Dict[str, Any] = None,
+    ) -> RunShieldResponse:
+        shield_def = await self.shield_store.get_shield(identifier)
+        if not shield_def:
+            raise ValueError(f"Unknown shield {identifier}")
+
+        messages = messages.copy()
+        # some shields like llama-guard require the first message to be a user message
+        # since this might be a tool call, first role might not be user
+        if len(messages) > 0 and messages[0].role != Role.user.value:
+            messages[0] = UserMessage(content=messages[0].content)
+
+        return await self.shield.run(messages)
+
+
 class LlamaGuardShield:
     def __init__(
         self,


@@ -38,11 +38,11 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
             module="llama_stack.providers.inline.memory.faiss",
             config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
-            deprecation_warning="Please use the `faiss` provider instead.",
+            deprecation_warning="Please use the `inline::faiss` provider instead.",
         ),
         InlineProviderSpec(
             api=Api.memory,
-            provider_type="faiss",
+            provider_type="inline::faiss",
             pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
             module="llama_stack.providers.inline.memory.faiss",
             config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",


@@ -29,11 +29,18 @@ def available_providers() -> List[ProviderSpec]:
             api_dependencies=[
                 Api.inference,
             ],
-            deprecation_warning="Please use the `llama-guard` / `prompt-guard` / `code-scanner` providers instead.",
+            deprecation_error="""
+Provider `meta-reference` for API `safety` does not work with the latest Llama Stack.
+
+- if you are using Llama Guard v3, please use the `inline::llama-guard` provider instead.
+- if you are using Prompt Guard, please use the `inline::prompt-guard` provider instead.
+- if you are using Code Scanner, please use the `inline::code-scanner` provider instead.
+
+""",
         ),
         InlineProviderSpec(
             api=Api.safety,
-            provider_type="llama-guard",
+            provider_type="inline::llama-guard",
             pip_packages=[],
             module="llama_stack.providers.inline.safety.llama_guard",
             config_class="llama_stack.providers.inline.safety.llama_guard.LlamaGuardConfig",
@@ -43,7 +50,7 @@ def available_providers() -> List[ProviderSpec]:
         ),
         InlineProviderSpec(
             api=Api.safety,
-            provider_type="prompt-guard",
+            provider_type="inline::prompt-guard",
             pip_packages=[
                 "transformers",
                 "torch --index-url https://download.pytorch.org/whl/cpu",
@@ -56,7 +63,7 @@ def available_providers() -> List[ProviderSpec]:
         ),
         InlineProviderSpec(
             api=Api.safety,
-            provider_type="code-scanner",
+            provider_type="inline::code-scanner",
             pip_packages=[
                 "codeshield",
             ],
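
In practice this forces a config migration: a run config that still selects `meta-reference` for the safety API now aborts at startup with the error above and should switch to `inline::llama-guard`. A rough before/after sketch written as Python dicts; the exact run-config schema and the model identifier here are illustrative only (the new config shape mirrors `LlamaGuardConfig` as used in the test fixtures below):

# Old entry: the resolver now exits with the deprecation_error shown above.
old_safety_provider = {
    "provider_id": "meta-reference",
    "provider_type": "meta-reference",
    "config": {"llama_guard_shield": {"model": "Llama-Guard-3-8B"}},  # hypothetical model id
}

# New entry: inline Llama Guard provider with a flat LlamaGuardConfig-style config.
new_safety_provider = {
    "provider_id": "llama-guard",
    "provider_type": "inline::llama-guard",
    "config": {"model": "Llama-Guard-3-8B", "excluded_categories": []},  # hypothetical model id
}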


@@ -80,6 +80,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 continue
 
             llama_model = ollama_to_llama[r["model"]]
+            print(f"Found model {llama_model} in Ollama")
             ret.append(
                 Model(
                     identifier=llama_model,


@@ -18,7 +18,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "inference": "meta_reference",
-            "safety": "meta_reference",
+            "safety": "llama_guard",
             "memory": "meta_reference",
             "agents": "meta_reference",
         },
@@ -28,7 +28,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "inference": "ollama",
-            "safety": "meta_reference",
+            "safety": "llama_guard",
             "memory": "meta_reference",
             "agents": "meta_reference",
         },
@@ -38,7 +38,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "inference": "together",
-            "safety": "meta_reference",
+            "safety": "llama_guard",
             # make this work with Weaviate which is what the together distro supports
             "memory": "meta_reference",
             "agents": "meta_reference",


@@ -16,7 +16,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "inference": "meta_reference",
-            "safety": "meta_reference",
+            "safety": "llama_guard",
         },
         id="meta_reference",
         marks=pytest.mark.meta_reference,
@@ -24,7 +24,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "inference": "ollama",
-            "safety": "meta_reference",
+            "safety": "llama_guard",
         },
         id="ollama",
         marks=pytest.mark.ollama,
@@ -32,7 +32,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "inference": "together",
-            "safety": "meta_reference",
+            "safety": "llama_guard",
         },
         id="together",
         marks=pytest.mark.together,


@@ -10,12 +10,9 @@ import pytest_asyncio
 from llama_stack.apis.shields import ShieldType
 
 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.safety.meta_reference import (
-    LlamaGuardShieldConfig,
-    SafetyConfig,
-)
+from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
 from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
-from llama_stack.providers.tests.env import get_env_or_fail
 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
 
 from ..conftest import ProviderFixture, remote_stack_fixture
@@ -34,17 +31,13 @@ def safety_model(request):
 
 @pytest.fixture(scope="session")
-def safety_meta_reference(safety_model) -> ProviderFixture:
+def safety_llama_guard(safety_model) -> ProviderFixture:
     return ProviderFixture(
         providers=[
             Provider(
-                provider_id="meta-reference",
-                provider_type="meta-reference",
-                config=SafetyConfig(
-                    llama_guard_shield=LlamaGuardShieldConfig(
-                        model=safety_model,
-                    ),
-                ).model_dump(),
+                provider_id="inline::llama-guard",
+                provider_type="inline::llama-guard",
+                config=LlamaGuardConfig(model=safety_model).model_dump(),
             )
         ],
     )
@@ -63,7 +56,7 @@ def safety_bedrock() -> ProviderFixture:
     )
 
 
-SAFETY_FIXTURES = ["meta_reference", "bedrock", "remote"]
+SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
 
 
 @pytest_asyncio.fixture(scope="session")