From 984ba074e11edba328eaf7e6aca669f8658bca5a Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 7 Nov 2024 15:23:15 -0800 Subject: [PATCH] add deprecation_error pointing meta-reference -> inline::llama-guard --- llama_stack/distribution/resolver.py | 9 +- llama_stack/providers/datatypes.py | 4 + .../inline/safety/llama_guard/llama_guard.py | 92 +++++++++---------- llama_stack/providers/registry/memory.py | 4 +- llama_stack/providers/registry/safety.py | 15 ++- .../remote/inference/ollama/ollama.py | 1 + .../providers/tests/agents/conftest.py | 6 +- .../providers/tests/safety/conftest.py | 6 +- .../providers/tests/safety/fixtures.py | 21 ++--- 9 files changed, 84 insertions(+), 74 deletions(-) diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index aac7ae5b6..b689b00c9 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import importlib import inspect +import sys from typing import Any, Dict, List, Set @@ -102,10 +103,14 @@ async def resolve_impls( ) p = provider_registry[api][provider.provider_type] - if p.deprecation_warning: + if p.deprecation_error: + cprint(p.deprecation_error, "red", attrs=["bold"]) + sys.exit(1) + + elif p.deprecation_warning: cprint( f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}", - "red", + "yellow", attrs=["bold"], ) p.deps__ = [a.value for a in p.api_dependencies] diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index cacfa39d1..7aa2b976f 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -90,6 +90,10 @@ class ProviderSpec(BaseModel): default=None, description="If this provider is deprecated, specify the warning message here", ) + deprecation_error: Optional[str] = Field( + default=None, + description="If this provider is deprecated and does NOT work, specify the error message here", + ) # used internally by the resolver; this is a hack for now deps__: List[str] = Field(default_factory=list) diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index bac153a4d..17577b0c9 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -19,52 +19,6 @@ from llama_stack.providers.datatypes import ShieldsProtocolPrivate from .config import LlamaGuardConfig -class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): - def __init__(self, config: LlamaGuardConfig, deps) -> None: - self.config = config - self.inference_api = deps[Api.inference] - - async def initialize(self) -> None: - self.shield = LlamaGuardShield( - model=self.config.model, - inference_api=self.inference_api, - excluded_categories=self.config.excluded_categories, - ) - - async def shutdown(self) -> None: - pass - - async def register_shield(self, shield: ShieldDef) -> None: - raise ValueError("Registering dynamic shields is not supported") - - async def list_shields(self) -> List[ShieldDef]: - return [ - ShieldDef( - identifier=ShieldType.llama_guard.value, - shield_type=ShieldType.llama_guard.value, - params={}, - ), - ] - - async def run_shield( - self, - identifier: str, - messages: List[Message], - params: Dict[str, Any] = None, - ) -> RunShieldResponse: - shield_def = await self.shield_store.get_shield(identifier) - if not 
shield_def: - raise ValueError(f"Unknown shield {identifier}") - - messages = messages.copy() - # some shields like llama-guard require the first message to be a user message - # since this might be a tool call, first role might not be user - if len(messages) > 0 and messages[0].role != Role.user.value: - messages[0] = UserMessage(content=messages[0].content) - - return await self.shield.run(messages) - - CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" SAFE_RESPONSE = "safe" @@ -158,6 +112,52 @@ PROMPT_TEMPLATE = Template( ) +class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): + def __init__(self, config: LlamaGuardConfig, deps) -> None: + self.config = config + self.inference_api = deps[Api.inference] + + async def initialize(self) -> None: + self.shield = LlamaGuardShield( + model=self.config.model, + inference_api=self.inference_api, + excluded_categories=self.config.excluded_categories, + ) + + async def shutdown(self) -> None: + pass + + async def register_shield(self, shield: ShieldDef) -> None: + raise ValueError("Registering dynamic shields is not supported") + + async def list_shields(self) -> List[ShieldDef]: + return [ + ShieldDef( + identifier=ShieldType.llama_guard.value, + shield_type=ShieldType.llama_guard.value, + params={}, + ), + ] + + async def run_shield( + self, + identifier: str, + messages: List[Message], + params: Dict[str, Any] = None, + ) -> RunShieldResponse: + shield_def = await self.shield_store.get_shield(identifier) + if not shield_def: + raise ValueError(f"Unknown shield {identifier}") + + messages = messages.copy() + # some shields like llama-guard require the first message to be a user message + # since this might be a tool call, first role might not be user + if len(messages) > 0 and messages[0].role != Role.user.value: + messages[0] = UserMessage(content=messages[0].content) + + return await self.shield.run(messages) + + class LlamaGuardShield: def __init__( self, diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py index 93ecb7c13..50fd64d7b 100644 --- a/llama_stack/providers/registry/memory.py +++ b/llama_stack/providers/registry/memory.py @@ -38,11 +38,11 @@ def available_providers() -> List[ProviderSpec]: pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], module="llama_stack.providers.inline.memory.faiss", config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", - deprecation_warning="Please use the `faiss` provider instead.", + deprecation_warning="Please use the `inline::faiss` provider instead.", ), InlineProviderSpec( api=Api.memory, - provider_type="faiss", + provider_type="inline::faiss", pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], module="llama_stack.providers.inline.memory.faiss", config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 668419338..3479671b2 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -29,11 +29,18 @@ def available_providers() -> List[ProviderSpec]: api_dependencies=[ Api.inference, ], - deprecation_warning="Please use the `llama-guard` / `prompt-guard` / `code-scanner` providers instead.", + deprecation_error=""" +Provider `meta-reference` for API `safety` does not work with the latest Llama Stack. + +- if you are using Llama Guard v3, please use the `inline::llama-guard` provider instead. 
+- if you are using Prompt Guard, please use the `inline::prompt-guard` provider instead. +- if you are using Code Scanner, please use the `inline::code-scanner` provider instead. + + """, ), InlineProviderSpec( api=Api.safety, - provider_type="llama-guard", + provider_type="inline::llama-guard", pip_packages=[], module="llama_stack.providers.inline.safety.llama_guard", config_class="llama_stack.providers.inline.safety.llama_guard.LlamaGuardConfig", @@ -43,7 +50,7 @@ def available_providers() -> List[ProviderSpec]: ), InlineProviderSpec( api=Api.safety, - provider_type="prompt-guard", + provider_type="inline::prompt-guard", pip_packages=[ "transformers", "torch --index-url https://download.pytorch.org/whl/cpu", @@ -56,7 +63,7 @@ def available_providers() -> List[ProviderSpec]: ), InlineProviderSpec( api=Api.safety, - provider_type="code-scanner", + provider_type="inline::code-scanner", pip_packages=[ "codeshield", ], diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 18cfef50d..938d05c08 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -80,6 +80,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): continue llama_model = ollama_to_llama[r["model"]] + print(f"Found model {llama_model} in Ollama") ret.append( Model( identifier=llama_model, diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index 7b16242cf..c2e1261f7 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -18,7 +18,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( { "inference": "meta_reference", - "safety": "meta_reference", + "safety": "llama_guard", "memory": "meta_reference", "agents": "meta_reference", }, @@ -28,7 +28,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( { "inference": "ollama", - "safety": "meta_reference", + "safety": "llama_guard", "memory": "meta_reference", "agents": "meta_reference", }, @@ -38,7 +38,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( { "inference": "together", - "safety": "meta_reference", + "safety": "llama_guard", # make this work with Weaviate which is what the together distro supports "memory": "meta_reference", "agents": "meta_reference", diff --git a/llama_stack/providers/tests/safety/conftest.py b/llama_stack/providers/tests/safety/conftest.py index daf16aefc..cb380ce57 100644 --- a/llama_stack/providers/tests/safety/conftest.py +++ b/llama_stack/providers/tests/safety/conftest.py @@ -16,7 +16,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( { "inference": "meta_reference", - "safety": "meta_reference", + "safety": "llama_guard", }, id="meta_reference", marks=pytest.mark.meta_reference, @@ -24,7 +24,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( { "inference": "ollama", - "safety": "meta_reference", + "safety": "llama_guard", }, id="ollama", marks=pytest.mark.ollama, @@ -32,7 +32,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( { "inference": "together", - "safety": "meta_reference", + "safety": "llama_guard", }, id="together", marks=pytest.mark.together, diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 035288cf8..c2beff7e2 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -10,12 +10,9 @@ import pytest_asyncio from 
llama_stack.apis.shields import ShieldType from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.inline.safety.meta_reference import ( - LlamaGuardShieldConfig, - SafetyConfig, -) +from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig -from llama_stack.providers.tests.env import get_env_or_fail + from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from ..conftest import ProviderFixture, remote_stack_fixture @@ -34,17 +31,13 @@ def safety_model(request): @pytest.fixture(scope="session") -def safety_meta_reference(safety_model) -> ProviderFixture: +def safety_llama_guard(safety_model) -> ProviderFixture: return ProviderFixture( providers=[ Provider( - provider_id="meta-reference", - provider_type="meta-reference", - config=SafetyConfig( - llama_guard_shield=LlamaGuardShieldConfig( - model=safety_model, - ), - ).model_dump(), + provider_id="inline::llama-guard", + provider_type="inline::llama-guard", + config=LlamaGuardConfig(model=safety_model).model_dump(), ) ], ) @@ -63,7 +56,7 @@ def safety_bedrock() -> ProviderFixture: ) -SAFETY_FIXTURES = ["meta_reference", "bedrock", "remote"] +SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"] @pytest_asyncio.fixture(scope="session")
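
Illustrative sketch (not part of the patch): the snippet below restates the resolution-time behavior introduced in resolver.py above, assuming a ProviderSpec-like object that carries the new `deprecation_error` field alongside the existing `deprecation_warning`. The helper name `check_provider_deprecation` is hypothetical and does not exist in llama_stack; it only shows the intended hard-vs-soft deprecation split.

    import sys
    from termcolor import cprint  # the resolver already uses cprint for these messages

    def check_provider_deprecation(p) -> None:
        # Hard deprecation: the provider no longer works, so print the
        # configured error in red and abort startup.
        if p.deprecation_error:
            cprint(p.deprecation_error, "red", attrs=["bold"])
            sys.exit(1)
        # Soft deprecation: the provider still works; warn in yellow and continue.
        elif p.deprecation_warning:
            cprint(
                f"Provider `{p.provider_type}` is deprecated and will be removed "
                f"in a future release: {p.deprecation_warning}",
                "yellow",
                attrs=["bold"],
            )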