From c1f7ba3aed141e095ba83db6d3df934f8df77eb0 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 11 Nov 2024 09:29:18 -0800
Subject: [PATCH] Split safety into (llama-guard, prompt-guard, code-scanner) (#400)

Splits the meta-reference safety implementation into three distinct providers:

- inline::llama-guard
- inline::prompt-guard
- inline::code-scanner

Note that this PR is a backward-incompatible change to the llama stack server.
I have added a deprecation_error field to ProviderSpec -- the server reads it
and immediately barfs. This is used to give the user a specific message about
what action to take. An automagical "config upgrade" is a bit too much work to
implement right now :/

(Note that we will be gradually prefixing all inline providers with inline:: --
I am only doing this for this set of new providers because otherwise existing
configuration files will break even more badly.)
---
 distributions/dell-tgi/run.yaml | 15 +-
 distributions/fireworks/run.yaml | 16 +-
 distributions/inline-vllm/run.yaml | 17 +-
 distributions/meta-reference-gpu/run.yaml | 15 +-
 .../meta-reference-quantized-gpu/run.yaml | 24 ++-
 distributions/ollama-gpu/run.yaml | 15 +-
 distributions/ollama/run.yaml | 15 +-
 distributions/remote-vllm/run.yaml | 15 +-
 distributions/tgi/run.yaml | 15 +-
 distributions/together/run.yaml | 15 +-
 .../distribution_dev/building_distro.md | 8 +-
 llama_stack/distribution/resolver.py | 14 +-
 llama_stack/distribution/server/server.py | 11 +-
 llama_stack/providers/datatypes.py | 4 +
 .../code_scanner}/__init__.py | 0
 .../code_scanner}/code_scanner.py | 2 +-
 .../code_scanner}/config.py | 2 +-
 .../inline/safety/llama_guard/__init__.py | 19 +++
 .../{meta_reference => llama_guard}/config.py | 15 +-
 .../llama_guard.py | 80 +++++++---
 .../inline/safety/meta_reference/__init__.py | 17 --
 .../inline/safety/meta_reference/base.py | 57 -------
 .../safety/meta_reference/prompt_guard.py | 145 ------------------
 .../inline/safety/meta_reference/safety.py | 107 -------------
 .../inline/safety/prompt_guard/__init__.py | 15 ++
 .../inline/safety/prompt_guard/config.py | 25 +++
 .../safety/prompt_guard/prompt_guard.py | 120 +++++++++++++++
 llama_stack/providers/registry/memory.py | 4 +-
 llama_stack/providers/registry/safety.py | 47 ++++--
 .../remote/inference/bedrock/__init__.py | 3 +-
 .../remote/inference/ollama/ollama.py | 1 +
 .../providers/tests/agents/conftest.py | 6 +-
 .../providers/tests/inference/fixtures.py | 5 +-
 .../providers/tests/safety/conftest.py | 6 +-
 .../providers/tests/safety/fixtures.py | 57 +++++--
 llama_stack/templates/bedrock/build.yaml | 4 +-
 llama_stack/templates/databricks/build.yaml | 4 +-
 llama_stack/templates/fireworks/build.yaml | 2 +-
 llama_stack/templates/hf-endpoint/build.yaml | 4 +-
 .../templates/hf-serverless/build.yaml | 4 +-
 llama_stack/templates/inline-vllm/build.yaml | 2 +-
 .../templates/meta-reference-gpu/build.yaml | 2 +-
 .../meta-reference-quantized-gpu/build.yaml | 2 +-
 llama_stack/templates/ollama/build.yaml | 2 +-
 llama_stack/templates/remote-vllm/build.yaml | 2 +-
 llama_stack/templates/tgi/build.yaml | 2 +-
 llama_stack/templates/together/build.yaml | 2 +-
 47 files changed, 464 insertions(+), 500 deletions(-)
 rename llama_stack/providers/inline/{meta_reference/codeshield => safety/code_scanner}/__init__.py (100%)
 rename llama_stack/providers/inline/{meta_reference/codeshield => safety/code_scanner}/code_scanner.py (96%)
 rename llama_stack/providers/inline/{meta_reference/codeshield => safety/code_scanner}/config.py (87%)
 create mode 100644
llama_stack/providers/inline/safety/llama_guard/__init__.py rename llama_stack/providers/inline/safety/{meta_reference => llama_guard}/config.py (75%) rename llama_stack/providers/inline/safety/{meta_reference => llama_guard}/llama_guard.py (77%) delete mode 100644 llama_stack/providers/inline/safety/meta_reference/__init__.py delete mode 100644 llama_stack/providers/inline/safety/meta_reference/base.py delete mode 100644 llama_stack/providers/inline/safety/meta_reference/prompt_guard.py delete mode 100644 llama_stack/providers/inline/safety/meta_reference/safety.py create mode 100644 llama_stack/providers/inline/safety/prompt_guard/__init__.py create mode 100644 llama_stack/providers/inline/safety/prompt_guard/config.py create mode 100644 llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py diff --git a/distributions/dell-tgi/run.yaml b/distributions/dell-tgi/run.yaml index c5f6d0aaa..779750c58 100644 --- a/distributions/dell-tgi/run.yaml +++ b/distributions/dell-tgi/run.yaml @@ -19,15 +19,14 @@ providers: url: http://127.0.0.1:80 safety: - provider_id: meta0 - provider_type: meta-reference + provider_type: inline::llama-guard config: - llama_guard_shield: - model: Llama-Guard-3-1B - excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M + model: Llama-Guard-3-1B + excluded_categories: [] + - provider_id: meta1 + provider_type: inline::prompt-guard + config: + model: Prompt-Guard-86M memory: - provider_id: meta0 provider_type: meta-reference diff --git a/distributions/fireworks/run.yaml b/distributions/fireworks/run.yaml index 4363d86f3..1259c9493 100644 --- a/distributions/fireworks/run.yaml +++ b/distributions/fireworks/run.yaml @@ -19,16 +19,16 @@ providers: url: https://api.fireworks.ai/inference # api_key: safety: + safety: - provider_id: meta0 - provider_type: meta-reference + provider_type: inline::llama-guard config: - llama_guard_shield: - model: Llama-Guard-3-1B - excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M + model: Llama-Guard-3-1B + excluded_categories: [] + - provider_id: meta1 + provider_type: inline::prompt-guard + config: + model: Prompt-Guard-86M memory: - provider_id: meta0 provider_type: meta-reference diff --git a/distributions/inline-vllm/run.yaml b/distributions/inline-vllm/run.yaml index aadf5c0ce..02499b49a 100644 --- a/distributions/inline-vllm/run.yaml +++ b/distributions/inline-vllm/run.yaml @@ -21,7 +21,7 @@ providers: gpu_memory_utilization: 0.4 enforce_eager: true max_tokens: 4096 - - provider_id: vllm-safety + - provider_id: vllm-inference-safety provider_type: inline::vllm config: model: Llama-Guard-3-1B @@ -31,14 +31,15 @@ providers: max_tokens: 4096 safety: - provider_id: meta0 - provider_type: meta-reference + provider_type: inline::llama-guard config: - llama_guard_shield: - model: Llama-Guard-3-1B - excluded_categories: [] -# Uncomment to use prompt guard -# prompt_guard_shield: -# model: Prompt-Guard-86M + model: Llama-Guard-3-1B + excluded_categories: [] + # Uncomment to use prompt guard + # - provider_id: meta1 + # provider_type: inline::prompt-guard + # config: + # model: Prompt-Guard-86M memory: - provider_id: meta0 provider_type: meta-reference diff --git a/distributions/meta-reference-gpu/run.yaml b/distributions/meta-reference-gpu/run.yaml index ad3187aa1..98a52bed1 100644 --- a/distributions/meta-reference-gpu/run.yaml +++ b/distributions/meta-reference-gpu/run.yaml @@ 
-13,7 +13,7 @@ apis: - safety providers: inference: - - provider_id: meta-reference-inference + - provider_id: inference0 provider_type: meta-reference config: model: Llama3.2-3B-Instruct @@ -21,7 +21,7 @@ providers: torch_seed: null max_seq_len: 4096 max_batch_size: 1 - - provider_id: meta-reference-safety + - provider_id: inference1 provider_type: meta-reference config: model: Llama-Guard-3-1B @@ -31,11 +31,14 @@ providers: max_batch_size: 1 safety: - provider_id: meta0 - provider_type: meta-reference + provider_type: inline::llama-guard config: - llama_guard_shield: - model: Llama-Guard-3-1B - excluded_categories: [] + model: Llama-Guard-3-1B + excluded_categories: [] + - provider_id: meta1 + provider_type: inline::prompt-guard + config: + model: Prompt-Guard-86M # Uncomment to use prompt guard # prompt_guard_shield: # model: Prompt-Guard-86M diff --git a/distributions/meta-reference-quantized-gpu/run.yaml b/distributions/meta-reference-quantized-gpu/run.yaml index f162502c5..fa8be277d 100644 --- a/distributions/meta-reference-quantized-gpu/run.yaml +++ b/distributions/meta-reference-quantized-gpu/run.yaml @@ -22,17 +22,25 @@ providers: torch_seed: null max_seq_len: 2048 max_batch_size: 1 + - provider_id: meta1 + provider_type: meta-reference-quantized + config: + # not a quantized model ! + model: Llama-Guard-3-1B + quantization: null + torch_seed: null + max_seq_len: 2048 + max_batch_size: 1 safety: - provider_id: meta0 - provider_type: meta-reference + provider_type: inline::llama-guard config: - llama_guard_shield: - model: Llama-Guard-3-1B - excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M + model: Llama-Guard-3-1B + excluded_categories: [] + - provider_id: meta1 + provider_type: inline::prompt-guard + config: + model: Prompt-Guard-86M memory: - provider_id: meta0 provider_type: meta-reference diff --git a/distributions/ollama-gpu/run.yaml b/distributions/ollama-gpu/run.yaml index 798dabc0b..46c67a1e5 100644 --- a/distributions/ollama-gpu/run.yaml +++ b/distributions/ollama-gpu/run.yaml @@ -19,15 +19,14 @@ providers: url: http://127.0.0.1:14343 safety: - provider_id: meta0 - provider_type: meta-reference + provider_type: inline::llama-guard config: - llama_guard_shield: - model: Llama-Guard-3-1B - excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M + model: Llama-Guard-3-1B + excluded_categories: [] + - provider_id: meta1 + provider_type: inline::prompt-guard + config: + model: Prompt-Guard-86M memory: - provider_id: meta0 provider_type: meta-reference diff --git a/distributions/ollama/run.yaml b/distributions/ollama/run.yaml index 798dabc0b..46c67a1e5 100644 --- a/distributions/ollama/run.yaml +++ b/distributions/ollama/run.yaml @@ -19,15 +19,14 @@ providers: url: http://127.0.0.1:14343 safety: - provider_id: meta0 - provider_type: meta-reference + provider_type: inline::llama-guard config: - llama_guard_shield: - model: Llama-Guard-3-1B - excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M + model: Llama-Guard-3-1B + excluded_categories: [] + - provider_id: meta1 + provider_type: inline::prompt-guard + config: + model: Prompt-Guard-86M memory: - provider_id: meta0 provider_type: meta-reference diff --git a/distributions/remote-vllm/run.yaml b/distributions/remote-vllm/run.yaml index 2d0d36370..27d60bd6c 100644 --- 
a/distributions/remote-vllm/run.yaml +++ b/distributions/remote-vllm/run.yaml @@ -19,15 +19,14 @@ providers: url: http://127.0.0.1:8000 safety: - provider_id: meta0 - provider_type: meta-reference + provider_type: inline::llama-guard config: - llama_guard_shield: - model: Llama-Guard-3-1B - excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M + model: Llama-Guard-3-1B + excluded_categories: [] + - provider_id: meta1 + provider_type: inline::prompt-guard + config: + model: Prompt-Guard-86M memory: - provider_id: meta0 provider_type: meta-reference diff --git a/distributions/tgi/run.yaml b/distributions/tgi/run.yaml index dc8cb2d2d..dcbb69027 100644 --- a/distributions/tgi/run.yaml +++ b/distributions/tgi/run.yaml @@ -19,15 +19,14 @@ providers: url: http://127.0.0.1:5009 safety: - provider_id: meta0 - provider_type: meta-reference + provider_type: inline::llama-guard config: - llama_guard_shield: - model: Llama-Guard-3-1B - excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M + model: Llama-Guard-3-1B + excluded_categories: [] + - provider_id: meta1 + provider_type: inline::prompt-guard + config: + model: Prompt-Guard-86M memory: - provider_id: meta0 provider_type: meta-reference diff --git a/distributions/together/run.yaml b/distributions/together/run.yaml index 87fd4dcd7..36ef86056 100644 --- a/distributions/together/run.yaml +++ b/distributions/together/run.yaml @@ -20,15 +20,14 @@ providers: # api_key: safety: - provider_id: meta0 - provider_type: meta-reference + provider_type: inline::llama-guard config: - llama_guard_shield: - model: Llama-Guard-3-1B - excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M + model: Llama-Guard-3-1B + excluded_categories: [] + - provider_id: meta1 + provider_type: inline::prompt-guard + config: + model: Prompt-Guard-86M memory: - provider_id: meta0 provider_type: remote::weaviate diff --git a/docs/source/distribution_dev/building_distro.md b/docs/source/distribution_dev/building_distro.md index 314792e41..36c504b1b 100644 --- a/docs/source/distribution_dev/building_distro.md +++ b/docs/source/distribution_dev/building_distro.md @@ -36,9 +36,9 @@ the provider types (implementations) you want to use for these APIs. Tip: use to see options for the providers. 
> Enter provider for API inference: meta-reference -> Enter provider for API safety: meta-reference +> Enter provider for API safety: inline::llama-guard > Enter provider for API agents: meta-reference -> Enter provider for API memory: meta-reference +> Enter provider for API memory: inline::faiss > Enter provider for API datasetio: meta-reference > Enter provider for API scoring: meta-reference > Enter provider for API eval: meta-reference @@ -203,8 +203,8 @@ distribution_spec: description: Like local, but use ollama for running LLM inference providers: inference: remote::ollama - memory: meta-reference - safety: meta-reference + memory: inline::faiss + safety: inline::llama-guard agents: meta-reference telemetry: meta-reference image_type: conda diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index aac7ae5b6..4e7fa0102 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -33,6 +33,10 @@ from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type +class InvalidProviderError(Exception): + pass + + def api_protocol_map() -> Dict[Api, Any]: return { Api.agents: Agents, @@ -102,16 +106,20 @@ async def resolve_impls( ) p = provider_registry[api][provider.provider_type] - if p.deprecation_warning: + if p.deprecation_error: + cprint(p.deprecation_error, "red", attrs=["bold"]) + raise InvalidProviderError(p.deprecation_error) + + elif p.deprecation_warning: cprint( f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}", - "red", + "yellow", attrs=["bold"], ) p.deps__ = [a.value for a in p.api_dependencies] spec = ProviderWithSpec( spec=p, - **(provider.dict()), + **(provider.model_dump()), ) specs[provider.provider_id] = spec diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 143813780..9193583e1 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -9,6 +9,7 @@ import functools import inspect import json import signal +import sys import traceback from contextlib import asynccontextmanager @@ -41,7 +42,7 @@ from llama_stack.providers.utils.telemetry.tracing import ( ) from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.request_headers import set_request_provider_data -from llama_stack.distribution.resolver import resolve_impls +from llama_stack.distribution.resolver import InvalidProviderError, resolve_impls from .endpoints import get_all_api_endpoints @@ -282,7 +283,13 @@ def main( dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config)) - impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry)) + try: + impls = asyncio.run( + resolve_impls(config, get_provider_registry(), dist_registry) + ) + except InvalidProviderError: + sys.exit(1) + if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index cacfa39d1..7aa2b976f 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -90,6 +90,10 @@ class ProviderSpec(BaseModel): default=None, description="If this provider is deprecated, specify the warning message here", ) + deprecation_error: Optional[str] = Field( + default=None, + description="If this provider is deprecated and does NOT work, specify the 
error message here", + ) # used internally by the resolver; this is a hack for now deps__: List[str] = Field(default_factory=list) diff --git a/llama_stack/providers/inline/meta_reference/codeshield/__init__.py b/llama_stack/providers/inline/safety/code_scanner/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/codeshield/__init__.py rename to llama_stack/providers/inline/safety/code_scanner/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/codeshield/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py similarity index 96% rename from llama_stack/providers/inline/meta_reference/codeshield/code_scanner.py rename to llama_stack/providers/inline/safety/code_scanner/code_scanner.py index 36ad60b8e..1ca65c9bb 100644 --- a/llama_stack/providers/inline/meta_reference/codeshield/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -25,7 +25,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety): pass async def register_shield(self, shield: Shield) -> None: - if shield.shield_type != ShieldType.code_scanner.value: + if shield.shield_type != ShieldType.code_scanner: raise ValueError(f"Unsupported safety shield type: {shield.shield_type}") async def run_shield( diff --git a/llama_stack/providers/inline/meta_reference/codeshield/config.py b/llama_stack/providers/inline/safety/code_scanner/config.py similarity index 87% rename from llama_stack/providers/inline/meta_reference/codeshield/config.py rename to llama_stack/providers/inline/safety/code_scanner/config.py index 583c2c95f..75c90d69a 100644 --- a/llama_stack/providers/inline/meta_reference/codeshield/config.py +++ b/llama_stack/providers/inline/safety/code_scanner/config.py @@ -7,5 +7,5 @@ from pydantic import BaseModel -class CodeShieldConfig(BaseModel): +class CodeScannerConfig(BaseModel): pass diff --git a/llama_stack/providers/inline/safety/llama_guard/__init__.py b/llama_stack/providers/inline/safety/llama_guard/__init__.py new file mode 100644 index 000000000..6024f840c --- /dev/null +++ b/llama_stack/providers/inline/safety/llama_guard/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import LlamaGuardConfig + + +async def get_provider_impl(config: LlamaGuardConfig, deps): + from .llama_guard import LlamaGuardSafetyImpl + + assert isinstance( + config, LlamaGuardConfig + ), f"Unexpected config type: {type(config)}" + + impl = LlamaGuardSafetyImpl(config, deps) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/safety/meta_reference/config.py b/llama_stack/providers/inline/safety/llama_guard/config.py similarity index 75% rename from llama_stack/providers/inline/safety/meta_reference/config.py rename to llama_stack/providers/inline/safety/llama_guard/config.py index 14233ad0c..aec856bce 100644 --- a/llama_stack/providers/inline/safety/meta_reference/config.py +++ b/llama_stack/providers/inline/safety/llama_guard/config.py @@ -4,20 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from enum import Enum -from typing import List, Optional +from typing import List from llama_models.sku_list import CoreModelId, safety_models from pydantic import BaseModel, field_validator -class PromptGuardType(Enum): - injection = "injection" - jailbreak = "jailbreak" - - -class LlamaGuardShieldConfig(BaseModel): +class LlamaGuardConfig(BaseModel): model: str = "Llama-Guard-3-1B" excluded_categories: List[str] = [] @@ -41,8 +35,3 @@ class LlamaGuardShieldConfig(BaseModel): f"Invalid model: {model}. Must be one of {permitted_models}" ) return model - - -class SafetyConfig(BaseModel): - llama_guard_shield: Optional[LlamaGuardShieldConfig] = None - enable_prompt_guard: Optional[bool] = False diff --git a/llama_stack/providers/inline/safety/meta_reference/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py similarity index 77% rename from llama_stack/providers/inline/safety/meta_reference/llama_guard.py rename to llama_stack/providers/inline/safety/llama_guard/llama_guard.py index 99b1c29be..9c3ec7750 100644 --- a/llama_stack/providers/inline/safety/meta_reference/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -7,16 +7,21 @@ import re from string import Template -from typing import List, Optional +from typing import Any, Dict, List, Optional from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.safety import * # noqa: F403 +from llama_stack.distribution.datatypes import Api -from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse +from llama_stack.providers.datatypes import ShieldsProtocolPrivate +from .config import LlamaGuardConfig + + +CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" 
SAFE_RESPONSE = "safe" -_INSTANCE = None CAT_VIOLENT_CRIMES = "Violent Crimes" CAT_NON_VIOLENT_CRIMES = "Non-Violent Crimes" @@ -107,16 +112,52 @@ PROMPT_TEMPLATE = Template( ) -class LlamaGuardShield(ShieldBase): +class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): + def __init__(self, config: LlamaGuardConfig, deps) -> None: + self.config = config + self.inference_api = deps[Api.inference] + + async def initialize(self) -> None: + self.shield = LlamaGuardShield( + model=self.config.model, + inference_api=self.inference_api, + excluded_categories=self.config.excluded_categories, + ) + + async def shutdown(self) -> None: + pass + + async def register_shield(self, shield: Shield) -> None: + print(f"Registering shield {shield}") + if shield.shield_type != ShieldType.llama_guard: + raise ValueError(f"Unsupported shield type: {shield.shield_type}") + + async def run_shield( + self, + shield_id: str, + messages: List[Message], + params: Dict[str, Any] = None, + ) -> RunShieldResponse: + shield = await self.shield_store.get_shield(shield_id) + if not shield: + raise ValueError(f"Unknown shield {shield_id}") + + messages = messages.copy() + # some shields like llama-guard require the first message to be a user message + # since this might be a tool call, first role might not be user + if len(messages) > 0 and messages[0].role != Role.user.value: + messages[0] = UserMessage(content=messages[0].content) + + return await self.shield.run(messages) + + +class LlamaGuardShield: def __init__( self, model: str, inference_api: Inference, - excluded_categories: List[str] = None, - on_violation_action: OnViolationAction = OnViolationAction.RAISE, + excluded_categories: Optional[List[str]] = None, ): - super().__init__(on_violation_action) - if excluded_categories is None: excluded_categories = [] @@ -174,7 +215,7 @@ class LlamaGuardShield(ShieldBase): ) return messages - async def run(self, messages: List[Message]) -> ShieldResponse: + async def run(self, messages: List[Message]) -> RunShieldResponse: messages = self.validate_messages(messages) if self.model == CoreModelId.llama_guard_3_11b_vision.value: @@ -195,8 +236,7 @@ class LlamaGuardShield(ShieldBase): content += event.delta content = content.strip() - shield_response = self.get_shield_response(content) - return shield_response + return self.get_shield_response(content) def build_text_shield_input(self, messages: List[Message]) -> UserMessage: return UserMessage(content=self.build_prompt(messages)) @@ -250,19 +290,23 @@ class LlamaGuardShield(ShieldBase): conversations=conversations_str, ) - def get_shield_response(self, response: str) -> ShieldResponse: + def get_shield_response(self, response: str) -> RunShieldResponse: response = response.strip() if response == SAFE_RESPONSE: - return ShieldResponse(is_violation=False) + return RunShieldResponse(violation=None) + unsafe_code = self.check_unsafe_response(response) if unsafe_code: unsafe_code_list = unsafe_code.split(",") if set(unsafe_code_list).issubset(set(self.excluded_categories)): - return ShieldResponse(is_violation=False) - return ShieldResponse( - is_violation=True, - violation_type=unsafe_code, - violation_return_message=CANNED_RESPONSE_TEXT, + return RunShieldResponse(violation=None) + + return RunShieldResponse( + violation=SafetyViolation( + violation_level=ViolationLevel.ERROR, + user_message=CANNED_RESPONSE_TEXT, + metadata={"violation_type": unsafe_code}, + ), ) raise ValueError(f"Unexpected response: {response}") diff --git 
a/llama_stack/providers/inline/safety/meta_reference/__init__.py b/llama_stack/providers/inline/safety/meta_reference/__init__.py deleted file mode 100644 index 5e0888de6..000000000 --- a/llama_stack/providers/inline/safety/meta_reference/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .config import LlamaGuardShieldConfig, SafetyConfig # noqa: F401 - - -async def get_provider_impl(config: SafetyConfig, deps): - from .safety import MetaReferenceSafetyImpl - - assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}" - - impl = MetaReferenceSafetyImpl(config, deps) - await impl.initialize() - return impl diff --git a/llama_stack/providers/inline/safety/meta_reference/base.py b/llama_stack/providers/inline/safety/meta_reference/base.py deleted file mode 100644 index 3861a7c4a..000000000 --- a/llama_stack/providers/inline/safety/meta_reference/base.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from abc import ABC, abstractmethod -from typing import List - -from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message -from pydantic import BaseModel -from llama_stack.apis.safety import * # noqa: F403 - -CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" - - -# TODO: clean this up; just remove this type completely -class ShieldResponse(BaseModel): - is_violation: bool - violation_type: Optional[str] = None - violation_return_message: Optional[str] = None - - -# TODO: this is a caller / agent concern -class OnViolationAction(Enum): - IGNORE = 0 - WARN = 1 - RAISE = 2 - - -class ShieldBase(ABC): - def __init__( - self, - on_violation_action: OnViolationAction = OnViolationAction.RAISE, - ): - self.on_violation_action = on_violation_action - - @abstractmethod - async def run(self, messages: List[Message]) -> ShieldResponse: - raise NotImplementedError() - - -def message_content_as_str(message: Message) -> str: - return interleaved_text_media_as_str(message.content) - - -class TextShield(ShieldBase): - def convert_messages_to_text(self, messages: List[Message]) -> str: - return "\n".join([message_content_as_str(m) for m in messages]) - - async def run(self, messages: List[Message]) -> ShieldResponse: - text = self.convert_messages_to_text(messages) - return await self.run_impl(text) - - @abstractmethod - async def run_impl(self, text: str) -> ShieldResponse: - raise NotImplementedError() diff --git a/llama_stack/providers/inline/safety/meta_reference/prompt_guard.py b/llama_stack/providers/inline/safety/meta_reference/prompt_guard.py deleted file mode 100644 index 54e911418..000000000 --- a/llama_stack/providers/inline/safety/meta_reference/prompt_guard.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from enum import auto, Enum -from typing import List - -import torch - -from llama_models.llama3.api.datatypes import Message -from termcolor import cprint - -from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield - - -class PromptGuardShield(TextShield): - class Mode(Enum): - INJECTION = auto() - JAILBREAK = auto() - - _instances = {} - _model_cache = None - - @staticmethod - def instance( - model_dir: str, - threshold: float = 0.9, - temperature: float = 1.0, - mode: "PromptGuardShield.Mode" = Mode.JAILBREAK, - on_violation_action=OnViolationAction.RAISE, - ) -> "PromptGuardShield": - action_value = on_violation_action.value - key = (model_dir, threshold, temperature, mode, action_value) - if key not in PromptGuardShield._instances: - PromptGuardShield._instances[key] = PromptGuardShield( - model_dir=model_dir, - threshold=threshold, - temperature=temperature, - mode=mode, - on_violation_action=on_violation_action, - ) - return PromptGuardShield._instances[key] - - def __init__( - self, - model_dir: str, - threshold: float = 0.9, - temperature: float = 1.0, - mode: "PromptGuardShield.Mode" = Mode.JAILBREAK, - on_violation_action: OnViolationAction = OnViolationAction.RAISE, - ): - super().__init__(on_violation_action) - assert ( - model_dir is not None - ), "Must provide a model directory for prompt injection shield" - if temperature <= 0: - raise ValueError("Temperature must be greater than 0") - self.device = "cuda" - if PromptGuardShield._model_cache is None: - from transformers import AutoModelForSequenceClassification, AutoTokenizer - - # load model and tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_dir) - model = AutoModelForSequenceClassification.from_pretrained( - model_dir, device_map=self.device - ) - PromptGuardShield._model_cache = (tokenizer, model) - - self.tokenizer, self.model = PromptGuardShield._model_cache - self.temperature = temperature - self.threshold = threshold - self.mode = mode - - def convert_messages_to_text(self, messages: List[Message]) -> str: - return message_content_as_str(messages[-1]) - - async def run_impl(self, text: str) -> ShieldResponse: - # run model on messages and return response - inputs = self.tokenizer(text, return_tensors="pt") - inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()} - with torch.no_grad(): - outputs = self.model(**inputs) - logits = outputs[0] - probabilities = torch.softmax(logits / self.temperature, dim=-1) - score_embedded = probabilities[0, 1].item() - score_malicious = probabilities[0, 2].item() - cprint( - f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}", - color="magenta", - ) - - if self.mode == self.Mode.INJECTION and ( - score_embedded + score_malicious > self.threshold - ): - return ShieldResponse( - is_violation=True, - violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}", - violation_return_message="Sorry, I cannot do this.", - ) - elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold: - return ShieldResponse( - is_violation=True, - violation_type=f"prompt_injection:malicious={score_malicious}", - violation_return_message="Sorry, I cannot do this.", - ) - - return ShieldResponse( - is_violation=False, - ) - - -class JailbreakShield(PromptGuardShield): - def __init__( - self, - model_dir: str, - threshold: float = 0.9, - temperature: float = 1.0, - on_violation_action: OnViolationAction = OnViolationAction.RAISE, - ): - 
super().__init__( - model_dir=model_dir, - threshold=threshold, - temperature=temperature, - mode=PromptGuardShield.Mode.JAILBREAK, - on_violation_action=on_violation_action, - ) - - -class InjectionShield(PromptGuardShield): - def __init__( - self, - model_dir: str, - threshold: float = 0.9, - temperature: float = 1.0, - on_violation_action: OnViolationAction = OnViolationAction.RAISE, - ): - super().__init__( - model_dir=model_dir, - threshold=threshold, - temperature=temperature, - mode=PromptGuardShield.Mode.INJECTION, - on_violation_action=on_violation_action, - ) diff --git a/llama_stack/providers/inline/safety/meta_reference/safety.py b/llama_stack/providers/inline/safety/meta_reference/safety.py deleted file mode 100644 index 824a7cd7e..000000000 --- a/llama_stack/providers/inline/safety/meta_reference/safety.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, Dict, List - -from llama_stack.distribution.utils.model_utils import model_local_dir -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.distribution.datatypes import Api - -from llama_stack.providers.datatypes import ShieldsProtocolPrivate - -from .base import OnViolationAction, ShieldBase -from .config import SafetyConfig -from .llama_guard import LlamaGuardShield -from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield - - -PROMPT_GUARD_MODEL = "Prompt-Guard-86M" -SUPPORTED_SHIELDS = [ShieldType.llama_guard, ShieldType.prompt_guard] - - -class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate): - def __init__(self, config: SafetyConfig, deps) -> None: - self.config = config - self.inference_api = deps[Api.inference] - - self.available_shields = [] - if config.llama_guard_shield: - self.available_shields.append(ShieldType.llama_guard) - if config.enable_prompt_guard: - self.available_shields.append(ShieldType.prompt_guard) - - async def initialize(self) -> None: - if self.config.enable_prompt_guard: - model_dir = model_local_dir(PROMPT_GUARD_MODEL) - _ = PromptGuardShield.instance(model_dir) - - async def shutdown(self) -> None: - pass - - async def register_shield(self, shield: Shield) -> None: - if shield.shield_type not in self.available_shields: - raise ValueError(f"Shield type {shield.shield_type} not supported") - - async def run_shield( - self, - shield_id: str, - messages: List[Message], - params: Dict[str, Any] = None, - ) -> RunShieldResponse: - shield = await self.shield_store.get_shield(shield_id) - if not shield: - raise ValueError(f"Shield {shield_id} not found") - - shield_impl = self.get_shield_impl(shield) - - messages = messages.copy() - # some shields like llama-guard require the first message to be a user message - # since this might be a tool call, first role might not be user - if len(messages) > 0 and messages[0].role != Role.user.value: - messages[0] = UserMessage(content=messages[0].content) - - # TODO: we can refactor ShieldBase, etc. 
to be inline with the API types - res = await shield_impl.run(messages) - violation = None - if ( - res.is_violation - and shield_impl.on_violation_action != OnViolationAction.IGNORE - ): - violation = SafetyViolation( - violation_level=( - ViolationLevel.ERROR - if shield_impl.on_violation_action == OnViolationAction.RAISE - else ViolationLevel.WARN - ), - user_message=res.violation_return_message, - metadata={ - "violation_type": res.violation_type, - }, - ) - - return RunShieldResponse(violation=violation) - - def get_shield_impl(self, shield: Shield) -> ShieldBase: - if shield.shield_type == ShieldType.llama_guard: - cfg = self.config.llama_guard_shield - return LlamaGuardShield( - model=cfg.model, - inference_api=self.inference_api, - excluded_categories=cfg.excluded_categories, - ) - elif shield.shield_type == ShieldType.prompt_guard: - model_dir = model_local_dir(PROMPT_GUARD_MODEL) - subtype = shield.params.get("prompt_guard_type", "injection") - if subtype == "injection": - return InjectionShield.instance(model_dir) - elif subtype == "jailbreak": - return JailbreakShield.instance(model_dir) - else: - raise ValueError(f"Unknown prompt guard type: {subtype}") - else: - raise ValueError(f"Unknown shield type: {shield.shield_type}") diff --git a/llama_stack/providers/inline/safety/prompt_guard/__init__.py b/llama_stack/providers/inline/safety/prompt_guard/__init__.py new file mode 100644 index 000000000..087aca6d9 --- /dev/null +++ b/llama_stack/providers/inline/safety/prompt_guard/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import PromptGuardConfig # noqa: F401 + + +async def get_provider_impl(config: PromptGuardConfig, deps): + from .prompt_guard import PromptGuardSafetyImpl + + impl = PromptGuardSafetyImpl(config, deps) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/safety/prompt_guard/config.py b/llama_stack/providers/inline/safety/prompt_guard/config.py new file mode 100644 index 000000000..bddd28452 --- /dev/null +++ b/llama_stack/providers/inline/safety/prompt_guard/config.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum + +from pydantic import BaseModel, field_validator + + +class PromptGuardType(Enum): + injection = "injection" + jailbreak = "jailbreak" + + +class PromptGuardConfig(BaseModel): + guard_type: str = PromptGuardType.injection.value + + @classmethod + @field_validator("guard_type") + def validate_guard_type(cls, v): + if v not in [t.value for t in PromptGuardType]: + raise ValueError(f"Unknown prompt guard type: {v}") + return v diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py new file mode 100644 index 000000000..20bfdd241 --- /dev/null +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any, Dict, List + +import torch +from termcolor import cprint + +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +from llama_stack.distribution.utils.model_utils import model_local_dir +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.safety import * # noqa: F403 +from llama_models.llama3.api.datatypes import * # noqa: F403 + +from llama_stack.providers.datatypes import ShieldsProtocolPrivate + +from .config import PromptGuardConfig, PromptGuardType + + +PROMPT_GUARD_MODEL = "Prompt-Guard-86M" + + +class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): + def __init__(self, config: PromptGuardConfig, _deps) -> None: + self.config = config + + async def initialize(self) -> None: + model_dir = model_local_dir(PROMPT_GUARD_MODEL) + self.shield = PromptGuardShield(model_dir, self.config) + + async def shutdown(self) -> None: + pass + + async def register_shield(self, shield: Shield) -> None: + if shield.shield_type != ShieldType.prompt_guard: + raise ValueError(f"Unsupported shield type: {shield.shield_type}") + + async def run_shield( + self, + shield_id: str, + messages: List[Message], + params: Dict[str, Any] = None, + ) -> RunShieldResponse: + shield = await self.shield_store.get_shield(shield_id) + if not shield: + raise ValueError(f"Unknown shield {shield_id}") + + return await self.shield.run(messages) + + +class PromptGuardShield: + def __init__( + self, + model_dir: str, + config: PromptGuardConfig, + threshold: float = 0.9, + temperature: float = 1.0, + ): + assert ( + model_dir is not None + ), "Must provide a model directory for prompt injection shield" + if temperature <= 0: + raise ValueError("Temperature must be greater than 0") + + self.config = config + self.temperature = temperature + self.threshold = threshold + + self.device = "cuda" + + # load model and tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(model_dir) + self.model = AutoModelForSequenceClassification.from_pretrained( + model_dir, device_map=self.device + ) + + async def run(self, messages: List[Message]) -> RunShieldResponse: + message = messages[-1] + text = interleaved_text_media_as_str(message.content) + + # run model on messages and return response + inputs = self.tokenizer(text, return_tensors="pt") + inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()} + with torch.no_grad(): + outputs = self.model(**inputs) + logits = outputs[0] + probabilities = torch.softmax(logits / self.temperature, dim=-1) + score_embedded = probabilities[0, 1].item() + score_malicious = probabilities[0, 2].item() + cprint( + f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}", + color="magenta", + ) + + violation = None + if self.config.guard_type == PromptGuardType.injection.value and ( + score_embedded + score_malicious > self.threshold + ): + violation = SafetyViolation( + violation_level=ViolationLevel.ERROR, + user_message="Sorry, I cannot do this.", + metadata={ + "violation_type": f"prompt_injection:embedded={score_embedded},malicious={score_malicious}", + }, + ) + elif ( + self.config.guard_type == PromptGuardType.jailbreak.value + and score_malicious > self.threshold + ): + violation = SafetyViolation( + violation_level=ViolationLevel.ERROR, + user_message="Sorry, I cannot do this.", + metadata={"violation_type": f"prompt_injection:malicious={score_malicious}"}, + ) + + return RunShieldResponse(violation=violation) diff --git
a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py index 93ecb7c13..50fd64d7b 100644 --- a/llama_stack/providers/registry/memory.py +++ b/llama_stack/providers/registry/memory.py @@ -38,11 +38,11 @@ def available_providers() -> List[ProviderSpec]: pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], module="llama_stack.providers.inline.memory.faiss", config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", - deprecation_warning="Please use the `faiss` provider instead.", + deprecation_warning="Please use the `inline::faiss` provider instead.", ), InlineProviderSpec( api=Api.memory, - provider_type="faiss", + provider_type="inline::faiss", pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], module="llama_stack.providers.inline.memory.faiss", config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index fb5b6695a..63676c4f1 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -29,6 +29,43 @@ def available_providers() -> List[ProviderSpec]: api_dependencies=[ Api.inference, ], + deprecation_error=""" +Provider `meta-reference` for API `safety` does not work with the latest Llama Stack. + +- if you are using Llama Guard v3, please use the `inline::llama-guard` provider instead. +- if you are using Prompt Guard, please use the `inline::prompt-guard` provider instead. +- if you are using Code Scanner, please use the `inline::code-scanner` provider instead. + + """, + ), + InlineProviderSpec( + api=Api.safety, + provider_type="inline::llama-guard", + pip_packages=[], + module="llama_stack.providers.inline.safety.llama_guard", + config_class="llama_stack.providers.inline.safety.llama_guard.LlamaGuardConfig", + api_dependencies=[ + Api.inference, + ], + ), + InlineProviderSpec( + api=Api.safety, + provider_type="inline::prompt-guard", + pip_packages=[ + "transformers", + "torch --index-url https://download.pytorch.org/whl/cpu", + ], + module="llama_stack.providers.inline.safety.prompt_guard", + config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig", + ), + InlineProviderSpec( + api=Api.safety, + provider_type="inline::code-scanner", + pip_packages=[ + "codeshield", + ], + module="llama_stack.providers.inline.safety.code_scanner", + config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig", ), remote_provider_spec( api=Api.safety, @@ -48,14 +85,4 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig", ), ), - InlineProviderSpec( - api=Api.safety, - provider_type="meta-reference/codeshield", - pip_packages=[ - "codeshield", - ], - module="llama_stack.providers.inline.safety.meta_reference", - config_class="llama_stack.providers.inline.safety.meta_reference.CodeShieldConfig", - api_dependencies=[], - ), ] diff --git a/llama_stack/providers/remote/inference/bedrock/__init__.py b/llama_stack/providers/remote/inference/bedrock/__init__.py index a38af374a..e72c6ada9 100644 --- a/llama_stack/providers/remote/inference/bedrock/__init__.py +++ b/llama_stack/providers/remote/inference/bedrock/__init__.py @@ -3,11 +3,12 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from .bedrock import BedrockInferenceAdapter from .config import BedrockConfig async def get_adapter_impl(config: BedrockConfig, _deps): + from .bedrock import BedrockInferenceAdapter + assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}" impl = BedrockInferenceAdapter(config) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 18cfef50d..938d05c08 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -80,6 +80,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): continue llama_model = ollama_to_llama[r["model"]] + print(f"Found model {llama_model} in Ollama") ret.append( Model( identifier=llama_model, diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index 7b16242cf..c2e1261f7 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -18,7 +18,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( { "inference": "meta_reference", - "safety": "meta_reference", + "safety": "llama_guard", "memory": "meta_reference", "agents": "meta_reference", }, @@ -28,7 +28,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( { "inference": "ollama", - "safety": "meta_reference", + "safety": "llama_guard", "memory": "meta_reference", "agents": "meta_reference", }, @@ -38,7 +38,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( { "inference": "together", - "safety": "meta_reference", + "safety": "llama_guard", # make this work with Weaviate which is what the together distro supports "memory": "meta_reference", "agents": "meta_reference", diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index b2c6d3a5e..d91337998 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -65,7 +65,6 @@ def inference_ollama(inference_model) -> ProviderFixture: inference_model = ( [inference_model] if isinstance(inference_model, str) else inference_model ) - print("!!!", inference_model) if "Llama3.1-8B-Instruct" in inference_model: pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing") @@ -162,9 +161,11 @@ async def inference_stack(request, inference_model): inference_fixture.provider_data, ) + provider_id = inference_fixture.providers[0].provider_id + print(f"Registering model {inference_model} with provider {provider_id}") await impls[Api.models].register_model( model_id=inference_model, - provider_model_id=inference_fixture.providers[0].provider_id, + provider_id=provider_id, ) return (impls[Api.inference], impls[Api.models]) diff --git a/llama_stack/providers/tests/safety/conftest.py b/llama_stack/providers/tests/safety/conftest.py index daf16aefc..cb380ce57 100644 --- a/llama_stack/providers/tests/safety/conftest.py +++ b/llama_stack/providers/tests/safety/conftest.py @@ -16,7 +16,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( { "inference": "meta_reference", - "safety": "meta_reference", + "safety": "llama_guard", }, id="meta_reference", marks=pytest.mark.meta_reference, @@ -24,7 +24,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( { "inference": "ollama", - "safety": "meta_reference", + "safety": "llama_guard", }, id="ollama", marks=pytest.mark.ollama, @@ -32,7 +32,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( { "inference": "together", - 
"safety": "meta_reference", + "safety": "llama_guard", }, id="together", marks=pytest.mark.together, diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 035288cf8..10a6460cb 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -10,15 +10,14 @@ import pytest_asyncio from llama_stack.apis.shields import ShieldType from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.inline.safety.meta_reference import ( - LlamaGuardShieldConfig, - SafetyConfig, -) +from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig +from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig -from llama_stack.providers.tests.env import get_env_or_fail + from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from ..conftest import ProviderFixture, remote_stack_fixture +from ..env import get_env_or_fail @pytest.fixture(scope="session") @@ -34,17 +33,29 @@ def safety_model(request): @pytest.fixture(scope="session") -def safety_meta_reference(safety_model) -> ProviderFixture: +def safety_llama_guard(safety_model) -> ProviderFixture: return ProviderFixture( providers=[ Provider( - provider_id="meta-reference", - provider_type="meta-reference", - config=SafetyConfig( - llama_guard_shield=LlamaGuardShieldConfig( - model=safety_model, - ), - ).model_dump(), + provider_id="inline::llama-guard", + provider_type="inline::llama-guard", + config=LlamaGuardConfig(model=safety_model).model_dump(), + ) + ], + ) + + +# TODO: this is not tested yet; we would need to configure the run_shield() test +# and parametrize it with the "prompt" for testing depending on the safety fixture +# we are using. 
+@pytest.fixture(scope="session") +def safety_prompt_guard() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="inline::prompt-guard", + provider_type="inline::prompt-guard", + config=PromptGuardConfig().model_dump(), ) ], ) @@ -63,7 +74,7 @@ def safety_bedrock() -> ProviderFixture: ) -SAFETY_FIXTURES = ["meta_reference", "bedrock", "remote"] +SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"] @pytest_asyncio.fixture(scope="session") @@ -96,7 +107,21 @@ async def safety_stack(inference_model, safety_model, request): # Register the appropriate shield based on provider type provider_type = safety_fixture.providers[0].provider_type + shield = await create_and_register_shield(provider_type, safety_model, shields_impl) + provider_id = inference_fixture.providers[0].provider_id + print(f"Registering model {inference_model} with provider {provider_id}") + await impls[Api.models].register_model( + model_id=inference_model, + provider_id=provider_id, + ) + + return safety_impl, shields_impl, shield + + +async def create_and_register_shield( + provider_type: str, safety_model: str, shields_impl +): shield_config = {} shield_type = ShieldType.llama_guard identifier = "llama_guard" @@ -109,10 +134,8 @@ async def safety_stack(inference_model, safety_model, request): shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION") shield_type = ShieldType.generic_content_shield - shield = await shields_impl.register_shield( + return await shields_impl.register_shield( shield_id=identifier, shield_type=shield_type, params=shield_config, ) - - return safety_impl, shields_impl, shield diff --git a/llama_stack/templates/bedrock/build.yaml b/llama_stack/templates/bedrock/build.yaml index a3ff27949..44cc813ae 100644 --- a/llama_stack/templates/bedrock/build.yaml +++ b/llama_stack/templates/bedrock/build.yaml @@ -3,7 +3,7 @@ distribution_spec: description: Use Amazon Bedrock APIs. providers: inference: remote::bedrock - memory: meta-reference - safety: meta-reference + memory: inline::faiss + safety: inline::llama-guard agents: meta-reference telemetry: meta-reference diff --git a/llama_stack/templates/databricks/build.yaml b/llama_stack/templates/databricks/build.yaml index f6c8b50a1..aa22f54b2 100644 --- a/llama_stack/templates/databricks/build.yaml +++ b/llama_stack/templates/databricks/build.yaml @@ -3,7 +3,7 @@ distribution_spec: description: Use Databricks for running LLM inference providers: inference: remote::databricks - memory: meta-reference - safety: meta-reference + memory: inline::faiss + safety: inline::llama-guard agents: meta-reference telemetry: meta-reference diff --git a/llama_stack/templates/fireworks/build.yaml b/llama_stack/templates/fireworks/build.yaml index 5b662c213..833ce4ee2 100644 --- a/llama_stack/templates/fireworks/build.yaml +++ b/llama_stack/templates/fireworks/build.yaml @@ -6,6 +6,6 @@ distribution_spec: memory: - meta-reference - remote::weaviate - safety: meta-reference + safety: inline::llama-guard agents: meta-reference telemetry: meta-reference diff --git a/llama_stack/templates/hf-endpoint/build.yaml b/llama_stack/templates/hf-endpoint/build.yaml index 6c84e5ccf..b06ee2eb0 100644 --- a/llama_stack/templates/hf-endpoint/build.yaml +++ b/llama_stack/templates/hf-endpoint/build.yaml @@ -3,7 +3,7 @@ distribution_spec: description: "Like local, but use Hugging Face Inference Endpoints for running LLM inference.\nSee https://hf.co/docs/api-endpoints." 
providers: inference: remote::hf::endpoint - memory: meta-reference - safety: meta-reference + memory: inline::faiss + safety: inline::llama-guard agents: meta-reference telemetry: meta-reference diff --git a/llama_stack/templates/hf-serverless/build.yaml b/llama_stack/templates/hf-serverless/build.yaml index 32561c1fa..62ff2c953 100644 --- a/llama_stack/templates/hf-serverless/build.yaml +++ b/llama_stack/templates/hf-serverless/build.yaml @@ -3,7 +3,7 @@ distribution_spec: description: "Like local, but use Hugging Face Inference API (serverless) for running LLM inference.\nSee https://hf.co/docs/api-inference." providers: inference: remote::hf::serverless - memory: meta-reference - safety: meta-reference + memory: inline::faiss + safety: inline::llama-guard agents: meta-reference telemetry: meta-reference diff --git a/llama_stack/templates/inline-vllm/build.yaml b/llama_stack/templates/inline-vllm/build.yaml index d0fe93aa3..2e4b34bc6 100644 --- a/llama_stack/templates/inline-vllm/build.yaml +++ b/llama_stack/templates/inline-vllm/build.yaml @@ -8,6 +8,6 @@ distribution_spec: - meta-reference - remote::chromadb - remote::pgvector - safety: meta-reference + safety: inline::llama-guard agents: meta-reference telemetry: meta-reference diff --git a/llama_stack/templates/meta-reference-gpu/build.yaml b/llama_stack/templates/meta-reference-gpu/build.yaml index d0fe93aa3..2e4b34bc6 100644 --- a/llama_stack/templates/meta-reference-gpu/build.yaml +++ b/llama_stack/templates/meta-reference-gpu/build.yaml @@ -8,6 +8,6 @@ distribution_spec: - meta-reference - remote::chromadb - remote::pgvector - safety: meta-reference + safety: inline::llama-guard agents: meta-reference telemetry: meta-reference diff --git a/llama_stack/templates/meta-reference-quantized-gpu/build.yaml b/llama_stack/templates/meta-reference-quantized-gpu/build.yaml index 20500ea5a..8768bd430 100644 --- a/llama_stack/templates/meta-reference-quantized-gpu/build.yaml +++ b/llama_stack/templates/meta-reference-quantized-gpu/build.yaml @@ -8,6 +8,6 @@ distribution_spec: - meta-reference - remote::chromadb - remote::pgvector - safety: meta-reference + safety: inline::llama-guard agents: meta-reference telemetry: meta-reference diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml index 06de2fc3c..410ae37cd 100644 --- a/llama_stack/templates/ollama/build.yaml +++ b/llama_stack/templates/ollama/build.yaml @@ -7,6 +7,6 @@ distribution_spec: - meta-reference - remote::chromadb - remote::pgvector - safety: meta-reference + safety: inline::llama-guard agents: meta-reference telemetry: meta-reference diff --git a/llama_stack/templates/remote-vllm/build.yaml b/llama_stack/templates/remote-vllm/build.yaml index ea95992f3..967b64413 100644 --- a/llama_stack/templates/remote-vllm/build.yaml +++ b/llama_stack/templates/remote-vllm/build.yaml @@ -7,6 +7,6 @@ distribution_spec: - meta-reference - remote::chromadb - remote::pgvector - safety: meta-reference + safety: inline::llama-guard agents: meta-reference telemetry: meta-reference diff --git a/llama_stack/templates/tgi/build.yaml b/llama_stack/templates/tgi/build.yaml index c5e618bb6..70c860001 100644 --- a/llama_stack/templates/tgi/build.yaml +++ b/llama_stack/templates/tgi/build.yaml @@ -7,6 +7,6 @@ distribution_spec: - meta-reference - remote::chromadb - remote::pgvector - safety: meta-reference + safety: inline::llama-guard agents: meta-reference telemetry: meta-reference diff --git a/llama_stack/templates/together/build.yaml 
b/llama_stack/templates/together/build.yaml index 05e59f677..614e31093 100644 --- a/llama_stack/templates/together/build.yaml +++ b/llama_stack/templates/together/build.yaml @@ -6,6 +6,6 @@ distribution_spec: memory: - meta-reference - remote::weaviate - safety: meta-reference + safety: inline::llama-guard agents: meta-reference telemetry: meta-reference
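
For reference, a minimal safety section of a run.yaml after this split might look like the sketch below. This is illustrative only, assembled from the diffs above: the provider IDs are arbitrary, the prompt-guard guard_type value follows the new PromptGuardConfig (which defaults to "injection"), and the empty code-scanner config reflects CodeScannerConfig having no fields.

  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config:
      model: Llama-Guard-3-1B
      excluded_categories: []
  - provider_id: prompt-guard
    provider_type: inline::prompt-guard
    config:
      guard_type: injection
  - provider_id: code-scanner
    provider_type: inline::code-scanner
    config: {}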