mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Split safety into (llama-guard, prompt-guard, code-scanner) (#400)
Splits the meta-reference safety implementation into three distinct providers: - inline::llama-guard - inline::prompt-guard - inline::code-scanner Note that this PR is a backward incompatible change to the llama stack server. I have added deprecation_error field to ProviderSpec -- the server reads it and immediately barfs. This is used to direct the user with a specific message on what action to perform. An automagical "config upgrade" is a bit too much work to implement right now :/ (Note that we will be gradually prefixing all inline providers with inline:: -- I am only doing this for this set of new providers because otherwise existing configuration files will break even more badly.)
This commit is contained in:
parent
6d38b1690b
commit
c1f7ba3aed
47 changed files with 464 additions and 500 deletions
|
@ -19,15 +19,14 @@ providers:
|
|||
url: http://127.0.0.1:80
|
||||
safety:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
llama_guard_shield:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
disable_input_check: false
|
||||
disable_output_check: false
|
||||
prompt_guard_shield:
|
||||
model: Prompt-Guard-86M
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
- provider_id: meta1
|
||||
provider_type: inline::prompt-guard
|
||||
config:
|
||||
model: Prompt-Guard-86M
|
||||
memory:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
|
|
|
@ -19,16 +19,16 @@ providers:
|
|||
url: https://api.fireworks.ai/inference
|
||||
# api_key: <ENTER_YOUR_API_KEY>
|
||||
safety:
|
||||
safety:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
llama_guard_shield:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
disable_input_check: false
|
||||
disable_output_check: false
|
||||
prompt_guard_shield:
|
||||
model: Prompt-Guard-86M
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
- provider_id: meta1
|
||||
provider_type: inline::prompt-guard
|
||||
config:
|
||||
model: Prompt-Guard-86M
|
||||
memory:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
|
|
|
@ -21,7 +21,7 @@ providers:
|
|||
gpu_memory_utilization: 0.4
|
||||
enforce_eager: true
|
||||
max_tokens: 4096
|
||||
- provider_id: vllm-safety
|
||||
- provider_id: vllm-inference-safety
|
||||
provider_type: inline::vllm
|
||||
config:
|
||||
model: Llama-Guard-3-1B
|
||||
|
@ -31,14 +31,15 @@ providers:
|
|||
max_tokens: 4096
|
||||
safety:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
llama_guard_shield:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
# Uncomment to use prompt guard
|
||||
# prompt_guard_shield:
|
||||
# model: Prompt-Guard-86M
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
# Uncomment to use prompt guard
|
||||
# - provider_id: meta1
|
||||
# provider_type: inline::prompt-guard
|
||||
# config:
|
||||
# model: Prompt-Guard-86M
|
||||
memory:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
|
|
|
@ -13,7 +13,7 @@ apis:
|
|||
- safety
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: meta-reference-inference
|
||||
- provider_id: inference0
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
model: Llama3.2-3B-Instruct
|
||||
|
@ -21,7 +21,7 @@ providers:
|
|||
torch_seed: null
|
||||
max_seq_len: 4096
|
||||
max_batch_size: 1
|
||||
- provider_id: meta-reference-safety
|
||||
- provider_id: inference1
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
model: Llama-Guard-3-1B
|
||||
|
@ -31,11 +31,14 @@ providers:
|
|||
max_batch_size: 1
|
||||
safety:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
llama_guard_shield:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
- provider_id: meta1
|
||||
provider_type: inline::prompt-guard
|
||||
config:
|
||||
model: Prompt-Guard-86M
|
||||
# Uncomment to use prompt guard
|
||||
# prompt_guard_shield:
|
||||
# model: Prompt-Guard-86M
|
||||
|
|
|
@ -22,17 +22,25 @@ providers:
|
|||
torch_seed: null
|
||||
max_seq_len: 2048
|
||||
max_batch_size: 1
|
||||
- provider_id: meta1
|
||||
provider_type: meta-reference-quantized
|
||||
config:
|
||||
# not a quantized model !
|
||||
model: Llama-Guard-3-1B
|
||||
quantization: null
|
||||
torch_seed: null
|
||||
max_seq_len: 2048
|
||||
max_batch_size: 1
|
||||
safety:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
llama_guard_shield:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
disable_input_check: false
|
||||
disable_output_check: false
|
||||
prompt_guard_shield:
|
||||
model: Prompt-Guard-86M
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
- provider_id: meta1
|
||||
provider_type: inline::prompt-guard
|
||||
config:
|
||||
model: Prompt-Guard-86M
|
||||
memory:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
|
|
|
@ -19,15 +19,14 @@ providers:
|
|||
url: http://127.0.0.1:14343
|
||||
safety:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
llama_guard_shield:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
disable_input_check: false
|
||||
disable_output_check: false
|
||||
prompt_guard_shield:
|
||||
model: Prompt-Guard-86M
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
- provider_id: meta1
|
||||
provider_type: inline::prompt-guard
|
||||
config:
|
||||
model: Prompt-Guard-86M
|
||||
memory:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
|
|
|
@ -19,15 +19,14 @@ providers:
|
|||
url: http://127.0.0.1:14343
|
||||
safety:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
llama_guard_shield:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
disable_input_check: false
|
||||
disable_output_check: false
|
||||
prompt_guard_shield:
|
||||
model: Prompt-Guard-86M
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
- provider_id: meta1
|
||||
provider_type: inline::prompt-guard
|
||||
config:
|
||||
model: Prompt-Guard-86M
|
||||
memory:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
|
|
|
@ -19,15 +19,14 @@ providers:
|
|||
url: http://127.0.0.1:8000
|
||||
safety:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
llama_guard_shield:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
disable_input_check: false
|
||||
disable_output_check: false
|
||||
prompt_guard_shield:
|
||||
model: Prompt-Guard-86M
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
- provider_id: meta1
|
||||
provider_type: inline::prompt-guard
|
||||
config:
|
||||
model: Prompt-Guard-86M
|
||||
memory:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
|
|
|
@ -19,15 +19,14 @@ providers:
|
|||
url: http://127.0.0.1:5009
|
||||
safety:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
llama_guard_shield:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
disable_input_check: false
|
||||
disable_output_check: false
|
||||
prompt_guard_shield:
|
||||
model: Prompt-Guard-86M
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
- provider_id: meta1
|
||||
provider_type: inline::prompt-guard
|
||||
config:
|
||||
model: Prompt-Guard-86M
|
||||
memory:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
|
|
|
@ -20,15 +20,14 @@ providers:
|
|||
# api_key: <ENTER_YOUR_API_KEY>
|
||||
safety:
|
||||
- provider_id: meta0
|
||||
provider_type: meta-reference
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
llama_guard_shield:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
disable_input_check: false
|
||||
disable_output_check: false
|
||||
prompt_guard_shield:
|
||||
model: Prompt-Guard-86M
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
- provider_id: meta1
|
||||
provider_type: inline::prompt-guard
|
||||
config:
|
||||
model: Prompt-Guard-86M
|
||||
memory:
|
||||
- provider_id: meta0
|
||||
provider_type: remote::weaviate
|
||||
|
|
|
@ -36,9 +36,9 @@ the provider types (implementations) you want to use for these APIs.
|
|||
Tip: use <TAB> to see options for the providers.
|
||||
|
||||
> Enter provider for API inference: meta-reference
|
||||
> Enter provider for API safety: meta-reference
|
||||
> Enter provider for API safety: inline::llama-guard
|
||||
> Enter provider for API agents: meta-reference
|
||||
> Enter provider for API memory: meta-reference
|
||||
> Enter provider for API memory: inline::faiss
|
||||
> Enter provider for API datasetio: meta-reference
|
||||
> Enter provider for API scoring: meta-reference
|
||||
> Enter provider for API eval: meta-reference
|
||||
|
@ -203,8 +203,8 @@ distribution_spec:
|
|||
description: Like local, but use ollama for running LLM inference
|
||||
providers:
|
||||
inference: remote::ollama
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
memory: inline::faiss
|
||||
safety: inline::llama-guard
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
||||
|
|
|
@ -33,6 +33,10 @@ from llama_stack.distribution.store import DistributionRegistry
|
|||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
|
||||
class InvalidProviderError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def api_protocol_map() -> Dict[Api, Any]:
|
||||
return {
|
||||
Api.agents: Agents,
|
||||
|
@ -102,16 +106,20 @@ async def resolve_impls(
|
|||
)
|
||||
|
||||
p = provider_registry[api][provider.provider_type]
|
||||
if p.deprecation_warning:
|
||||
if p.deprecation_error:
|
||||
cprint(p.deprecation_error, "red", attrs=["bold"])
|
||||
raise InvalidProviderError(p.deprecation_error)
|
||||
|
||||
elif p.deprecation_warning:
|
||||
cprint(
|
||||
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
||||
"red",
|
||||
"yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
p.deps__ = [a.value for a in p.api_dependencies]
|
||||
spec = ProviderWithSpec(
|
||||
spec=p,
|
||||
**(provider.dict()),
|
||||
**(provider.model_dump()),
|
||||
)
|
||||
specs[provider.provider_id] = spec
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ import functools
|
|||
import inspect
|
||||
import json
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
@ -41,7 +42,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
)
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import resolve_impls
|
||||
from llama_stack.distribution.resolver import InvalidProviderError, resolve_impls
|
||||
|
||||
from .endpoints import get_all_api_endpoints
|
||||
|
||||
|
@ -282,7 +283,13 @@ def main(
|
|||
|
||||
dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))
|
||||
|
||||
impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
|
||||
try:
|
||||
impls = asyncio.run(
|
||||
resolve_impls(config, get_provider_registry(), dist_registry)
|
||||
)
|
||||
except InvalidProviderError:
|
||||
sys.exit(1)
|
||||
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
||||
|
|
|
@ -90,6 +90,10 @@ class ProviderSpec(BaseModel):
|
|||
default=None,
|
||||
description="If this provider is deprecated, specify the warning message here",
|
||||
)
|
||||
deprecation_error: Optional[str] = Field(
|
||||
default=None,
|
||||
description="If this provider is deprecated and does NOT work, specify the error message here",
|
||||
)
|
||||
|
||||
# used internally by the resolver; this is a hack for now
|
||||
deps__: List[str] = Field(default_factory=list)
|
||||
|
|
|
@ -25,7 +25,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
|||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
if shield.shield_type != ShieldType.code_scanner.value:
|
||||
if shield.shield_type != ShieldType.code_scanner:
|
||||
raise ValueError(f"Unsupported safety shield type: {shield.shield_type}")
|
||||
|
||||
async def run_shield(
|
|
@ -7,5 +7,5 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeShieldConfig(BaseModel):
|
||||
class CodeScannerConfig(BaseModel):
|
||||
pass
|
19
llama_stack/providers/inline/safety/llama_guard/__init__.py
Normal file
19
llama_stack/providers/inline/safety/llama_guard/__init__.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import LlamaGuardConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: LlamaGuardConfig, deps):
|
||||
from .llama_guard import LlamaGuardSafetyImpl
|
||||
|
||||
assert isinstance(
|
||||
config, LlamaGuardConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = LlamaGuardSafetyImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -4,20 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
from llama_models.sku_list import CoreModelId, safety_models
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class PromptGuardType(Enum):
|
||||
injection = "injection"
|
||||
jailbreak = "jailbreak"
|
||||
|
||||
|
||||
class LlamaGuardShieldConfig(BaseModel):
|
||||
class LlamaGuardConfig(BaseModel):
|
||||
model: str = "Llama-Guard-3-1B"
|
||||
excluded_categories: List[str] = []
|
||||
|
||||
|
@ -41,8 +35,3 @@ class LlamaGuardShieldConfig(BaseModel):
|
|||
f"Invalid model: {model}. Must be one of {permitted_models}"
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
class SafetyConfig(BaseModel):
|
||||
llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
|
||||
enable_prompt_guard: Optional[bool] = False
|
|
@ -7,16 +7,21 @@
|
|||
import re
|
||||
|
||||
from string import Template
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
|
||||
from .config import LlamaGuardConfig
|
||||
|
||||
|
||||
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||
|
||||
SAFE_RESPONSE = "safe"
|
||||
_INSTANCE = None
|
||||
|
||||
CAT_VIOLENT_CRIMES = "Violent Crimes"
|
||||
CAT_NON_VIOLENT_CRIMES = "Non-Violent Crimes"
|
||||
|
@ -107,16 +112,52 @@ PROMPT_TEMPLATE = Template(
|
|||
)
|
||||
|
||||
|
||||
class LlamaGuardShield(ShieldBase):
|
||||
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: LlamaGuardConfig, deps) -> None:
|
||||
self.config = config
|
||||
self.inference_api = deps[Api.inference]
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.shield = LlamaGuardShield(
|
||||
model=self.config.model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=self.config.excluded_categories,
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
print(f"Registering shield {shield}")
|
||||
if shield.shield_type != ShieldType.llama_guard:
|
||||
raise ValueError(f"Unsupported shield type: {shield.shield_type}")
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
raise ValueError(f"Unknown shield {shield_id}")
|
||||
|
||||
messages = messages.copy()
|
||||
# some shields like llama-guard require the first message to be a user message
|
||||
# since this might be a tool call, first role might not be user
|
||||
if len(messages) > 0 and messages[0].role != Role.user.value:
|
||||
messages[0] = UserMessage(content=messages[0].content)
|
||||
|
||||
return await self.shield.run(messages)
|
||||
|
||||
|
||||
class LlamaGuardShield:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
inference_api: Inference,
|
||||
excluded_categories: List[str] = None,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
excluded_categories: Optional[List[str]] = None,
|
||||
):
|
||||
super().__init__(on_violation_action)
|
||||
|
||||
if excluded_categories is None:
|
||||
excluded_categories = []
|
||||
|
||||
|
@ -174,7 +215,7 @@ class LlamaGuardShield(ShieldBase):
|
|||
)
|
||||
return messages
|
||||
|
||||
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||
messages = self.validate_messages(messages)
|
||||
|
||||
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
|
||||
|
@ -195,8 +236,7 @@ class LlamaGuardShield(ShieldBase):
|
|||
content += event.delta
|
||||
|
||||
content = content.strip()
|
||||
shield_response = self.get_shield_response(content)
|
||||
return shield_response
|
||||
return self.get_shield_response(content)
|
||||
|
||||
def build_text_shield_input(self, messages: List[Message]) -> UserMessage:
|
||||
return UserMessage(content=self.build_prompt(messages))
|
||||
|
@ -250,19 +290,23 @@ class LlamaGuardShield(ShieldBase):
|
|||
conversations=conversations_str,
|
||||
)
|
||||
|
||||
def get_shield_response(self, response: str) -> ShieldResponse:
|
||||
def get_shield_response(self, response: str) -> RunShieldResponse:
|
||||
response = response.strip()
|
||||
if response == SAFE_RESPONSE:
|
||||
return ShieldResponse(is_violation=False)
|
||||
return RunShieldResponse(violation=None)
|
||||
|
||||
unsafe_code = self.check_unsafe_response(response)
|
||||
if unsafe_code:
|
||||
unsafe_code_list = unsafe_code.split(",")
|
||||
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
|
||||
return ShieldResponse(is_violation=False)
|
||||
return ShieldResponse(
|
||||
is_violation=True,
|
||||
violation_type=unsafe_code,
|
||||
violation_return_message=CANNED_RESPONSE_TEXT,
|
||||
return RunShieldResponse(violation=None)
|
||||
|
||||
return RunShieldResponse(
|
||||
violation=SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
user_message=CANNED_RESPONSE_TEXT,
|
||||
metadata={"violation_type": unsafe_code},
|
||||
),
|
||||
)
|
||||
|
||||
raise ValueError(f"Unexpected response: {response}")
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import LlamaGuardShieldConfig, SafetyConfig # noqa: F401
|
||||
|
||||
|
||||
async def get_provider_impl(config: SafetyConfig, deps):
|
||||
from .safety import MetaReferenceSafetyImpl
|
||||
|
||||
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MetaReferenceSafetyImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,57 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
|
||||
from pydantic import BaseModel
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||
|
||||
|
||||
# TODO: clean this up; just remove this type completely
|
||||
class ShieldResponse(BaseModel):
|
||||
is_violation: bool
|
||||
violation_type: Optional[str] = None
|
||||
violation_return_message: Optional[str] = None
|
||||
|
||||
|
||||
# TODO: this is a caller / agent concern
|
||||
class OnViolationAction(Enum):
|
||||
IGNORE = 0
|
||||
WARN = 1
|
||||
RAISE = 2
|
||||
|
||||
|
||||
class ShieldBase(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
):
|
||||
self.on_violation_action = on_violation_action
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def message_content_as_str(message: Message) -> str:
|
||||
return interleaved_text_media_as_str(message.content)
|
||||
|
||||
|
||||
class TextShield(ShieldBase):
|
||||
def convert_messages_to_text(self, messages: List[Message]) -> str:
|
||||
return "\n".join([message_content_as_str(m) for m in messages])
|
||||
|
||||
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||
text = self.convert_messages_to_text(messages)
|
||||
return await self.run_impl(text)
|
||||
|
||||
@abstractmethod
|
||||
async def run_impl(self, text: str) -> ShieldResponse:
|
||||
raise NotImplementedError()
|
|
@ -1,145 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import auto, Enum
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from termcolor import cprint
|
||||
|
||||
from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
|
||||
|
||||
|
||||
class PromptGuardShield(TextShield):
|
||||
class Mode(Enum):
|
||||
INJECTION = auto()
|
||||
JAILBREAK = auto()
|
||||
|
||||
_instances = {}
|
||||
_model_cache = None
|
||||
|
||||
@staticmethod
|
||||
def instance(
|
||||
model_dir: str,
|
||||
threshold: float = 0.9,
|
||||
temperature: float = 1.0,
|
||||
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
|
||||
on_violation_action=OnViolationAction.RAISE,
|
||||
) -> "PromptGuardShield":
|
||||
action_value = on_violation_action.value
|
||||
key = (model_dir, threshold, temperature, mode, action_value)
|
||||
if key not in PromptGuardShield._instances:
|
||||
PromptGuardShield._instances[key] = PromptGuardShield(
|
||||
model_dir=model_dir,
|
||||
threshold=threshold,
|
||||
temperature=temperature,
|
||||
mode=mode,
|
||||
on_violation_action=on_violation_action,
|
||||
)
|
||||
return PromptGuardShield._instances[key]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_dir: str,
|
||||
threshold: float = 0.9,
|
||||
temperature: float = 1.0,
|
||||
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
):
|
||||
super().__init__(on_violation_action)
|
||||
assert (
|
||||
model_dir is not None
|
||||
), "Must provide a model directory for prompt injection shield"
|
||||
if temperature <= 0:
|
||||
raise ValueError("Temperature must be greater than 0")
|
||||
self.device = "cuda"
|
||||
if PromptGuardShield._model_cache is None:
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
# load model and tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_dir, device_map=self.device
|
||||
)
|
||||
PromptGuardShield._model_cache = (tokenizer, model)
|
||||
|
||||
self.tokenizer, self.model = PromptGuardShield._model_cache
|
||||
self.temperature = temperature
|
||||
self.threshold = threshold
|
||||
self.mode = mode
|
||||
|
||||
def convert_messages_to_text(self, messages: List[Message]) -> str:
|
||||
return message_content_as_str(messages[-1])
|
||||
|
||||
async def run_impl(self, text: str) -> ShieldResponse:
|
||||
# run model on messages and return response
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
logits = outputs[0]
|
||||
probabilities = torch.softmax(logits / self.temperature, dim=-1)
|
||||
score_embedded = probabilities[0, 1].item()
|
||||
score_malicious = probabilities[0, 2].item()
|
||||
cprint(
|
||||
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
|
||||
color="magenta",
|
||||
)
|
||||
|
||||
if self.mode == self.Mode.INJECTION and (
|
||||
score_embedded + score_malicious > self.threshold
|
||||
):
|
||||
return ShieldResponse(
|
||||
is_violation=True,
|
||||
violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
|
||||
violation_return_message="Sorry, I cannot do this.",
|
||||
)
|
||||
elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
|
||||
return ShieldResponse(
|
||||
is_violation=True,
|
||||
violation_type=f"prompt_injection:malicious={score_malicious}",
|
||||
violation_return_message="Sorry, I cannot do this.",
|
||||
)
|
||||
|
||||
return ShieldResponse(
|
||||
is_violation=False,
|
||||
)
|
||||
|
||||
|
||||
class JailbreakShield(PromptGuardShield):
|
||||
def __init__(
|
||||
self,
|
||||
model_dir: str,
|
||||
threshold: float = 0.9,
|
||||
temperature: float = 1.0,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
):
|
||||
super().__init__(
|
||||
model_dir=model_dir,
|
||||
threshold=threshold,
|
||||
temperature=temperature,
|
||||
mode=PromptGuardShield.Mode.JAILBREAK,
|
||||
on_violation_action=on_violation_action,
|
||||
)
|
||||
|
||||
|
||||
class InjectionShield(PromptGuardShield):
|
||||
def __init__(
|
||||
self,
|
||||
model_dir: str,
|
||||
threshold: float = 0.9,
|
||||
temperature: float = 1.0,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
):
|
||||
super().__init__(
|
||||
model_dir=model_dir,
|
||||
threshold=threshold,
|
||||
temperature=temperature,
|
||||
mode=PromptGuardShield.Mode.INJECTION,
|
||||
on_violation_action=on_violation_action,
|
||||
)
|
|
@ -1,107 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
|
||||
from .base import OnViolationAction, ShieldBase
|
||||
from .config import SafetyConfig
|
||||
from .llama_guard import LlamaGuardShield
|
||||
from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield
|
||||
|
||||
|
||||
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
|
||||
SUPPORTED_SHIELDS = [ShieldType.llama_guard, ShieldType.prompt_guard]
|
||||
|
||||
|
||||
class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: SafetyConfig, deps) -> None:
|
||||
self.config = config
|
||||
self.inference_api = deps[Api.inference]
|
||||
|
||||
self.available_shields = []
|
||||
if config.llama_guard_shield:
|
||||
self.available_shields.append(ShieldType.llama_guard)
|
||||
if config.enable_prompt_guard:
|
||||
self.available_shields.append(ShieldType.prompt_guard)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if self.config.enable_prompt_guard:
|
||||
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||
_ = PromptGuardShield.instance(model_dir)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
if shield.shield_type not in self.available_shields:
|
||||
raise ValueError(f"Shield type {shield.shield_type} not supported")
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
raise ValueError(f"Shield {shield_id} not found")
|
||||
|
||||
shield_impl = self.get_shield_impl(shield)
|
||||
|
||||
messages = messages.copy()
|
||||
# some shields like llama-guard require the first message to be a user message
|
||||
# since this might be a tool call, first role might not be user
|
||||
if len(messages) > 0 and messages[0].role != Role.user.value:
|
||||
messages[0] = UserMessage(content=messages[0].content)
|
||||
|
||||
# TODO: we can refactor ShieldBase, etc. to be inline with the API types
|
||||
res = await shield_impl.run(messages)
|
||||
violation = None
|
||||
if (
|
||||
res.is_violation
|
||||
and shield_impl.on_violation_action != OnViolationAction.IGNORE
|
||||
):
|
||||
violation = SafetyViolation(
|
||||
violation_level=(
|
||||
ViolationLevel.ERROR
|
||||
if shield_impl.on_violation_action == OnViolationAction.RAISE
|
||||
else ViolationLevel.WARN
|
||||
),
|
||||
user_message=res.violation_return_message,
|
||||
metadata={
|
||||
"violation_type": res.violation_type,
|
||||
},
|
||||
)
|
||||
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
def get_shield_impl(self, shield: Shield) -> ShieldBase:
|
||||
if shield.shield_type == ShieldType.llama_guard:
|
||||
cfg = self.config.llama_guard_shield
|
||||
return LlamaGuardShield(
|
||||
model=cfg.model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=cfg.excluded_categories,
|
||||
)
|
||||
elif shield.shield_type == ShieldType.prompt_guard:
|
||||
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||
subtype = shield.params.get("prompt_guard_type", "injection")
|
||||
if subtype == "injection":
|
||||
return InjectionShield.instance(model_dir)
|
||||
elif subtype == "jailbreak":
|
||||
return JailbreakShield.instance(model_dir)
|
||||
else:
|
||||
raise ValueError(f"Unknown prompt guard type: {subtype}")
|
||||
else:
|
||||
raise ValueError(f"Unknown shield type: {shield.shield_type}")
|
15
llama_stack/providers/inline/safety/prompt_guard/__init__.py
Normal file
15
llama_stack/providers/inline/safety/prompt_guard/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import PromptGuardConfig # noqa: F401
|
||||
|
||||
|
||||
async def get_provider_impl(config: PromptGuardConfig, deps):
|
||||
from .prompt_guard import PromptGuardSafetyImpl
|
||||
|
||||
impl = PromptGuardSafetyImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
25
llama_stack/providers/inline/safety/prompt_guard/config.py
Normal file
25
llama_stack/providers/inline/safety/prompt_guard/config.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class PromptGuardType(Enum):
|
||||
injection = "injection"
|
||||
jailbreak = "jailbreak"
|
||||
|
||||
|
||||
class PromptGuardConfig(BaseModel):
|
||||
guard_type: str = PromptGuardType.injection.value
|
||||
|
||||
@classmethod
|
||||
@field_validator("guard_type")
|
||||
def validate_guard_type(cls, v):
|
||||
if v not in [t.value for t in PromptGuardType]:
|
||||
raise ValueError(f"Unknown prompt guard type: {v}")
|
||||
return v
|
120
llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
Normal file
120
llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
from termcolor import cprint
|
||||
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
|
||||
from .config import PromptGuardConfig, PromptGuardType
|
||||
|
||||
|
||||
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
|
||||
|
||||
|
||||
class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: PromptGuardConfig, _deps) -> None:
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||
self.shield = PromptGuardShield(model_dir, self.config)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
if shield.shield_type != ShieldType.prompt_guard:
|
||||
raise ValueError(f"Unsupported shield type: {shield.shield_type}")
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
raise ValueError(f"Unknown shield {shield_id}")
|
||||
|
||||
return await self.shield.run(messages)
|
||||
|
||||
|
||||
class PromptGuardShield:
|
||||
def __init__(
|
||||
self,
|
||||
model_dir: str,
|
||||
config: PromptGuardConfig,
|
||||
threshold: float = 0.9,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
assert (
|
||||
model_dir is not None
|
||||
), "Must provide a model directory for prompt injection shield"
|
||||
if temperature <= 0:
|
||||
raise ValueError("Temperature must be greater than 0")
|
||||
|
||||
self.config = config
|
||||
self.temperature = temperature
|
||||
self.threshold = threshold
|
||||
|
||||
self.device = "cuda"
|
||||
|
||||
# load model and tokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_dir, device_map=self.device
|
||||
)
|
||||
|
||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||
message = messages[-1]
|
||||
text = interleaved_text_media_as_str(message.content)
|
||||
|
||||
# run model on messages and return response
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
logits = outputs[0]
|
||||
probabilities = torch.softmax(logits / self.temperature, dim=-1)
|
||||
score_embedded = probabilities[0, 1].item()
|
||||
score_malicious = probabilities[0, 2].item()
|
||||
cprint(
|
||||
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
|
||||
color="magenta",
|
||||
)
|
||||
|
||||
violation = None
|
||||
if self.config.guard_type == PromptGuardType.injection.value and (
|
||||
score_embedded + score_malicious > self.threshold
|
||||
):
|
||||
violation = SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
user_message="Sorry, I cannot do this.",
|
||||
metadata={
|
||||
"violation_type": f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
|
||||
},
|
||||
)
|
||||
elif (
|
||||
self.config.guard_type == PromptGuardType.jailbreak.value
|
||||
and score_malicious > self.threshold
|
||||
):
|
||||
violation = SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
violation_type=f"prompt_injection:malicious={score_malicious}",
|
||||
violation_return_message="Sorry, I cannot do this.",
|
||||
)
|
||||
|
||||
return RunShieldResponse(violation=violation)
|
|
@ -38,11 +38,11 @@ def available_providers() -> List[ProviderSpec]:
|
|||
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
|
||||
module="llama_stack.providers.inline.memory.faiss",
|
||||
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
|
||||
deprecation_warning="Please use the `faiss` provider instead.",
|
||||
deprecation_warning="Please use the `inline::faiss` provider instead.",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.memory,
|
||||
provider_type="faiss",
|
||||
provider_type="inline::faiss",
|
||||
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
|
||||
module="llama_stack.providers.inline.memory.faiss",
|
||||
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
|
||||
|
|
|
@ -29,6 +29,43 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api_dependencies=[
|
||||
Api.inference,
|
||||
],
|
||||
deprecation_error="""
|
||||
Provider `meta-reference` for API `safety` does not work with the latest Llama Stack.
|
||||
|
||||
- if you are using Llama Guard v3, please use the `inline::llama-guard` provider instead.
|
||||
- if you are using Prompt Guard, please use the `inline::prompt-guard` provider instead.
|
||||
- if you are using Code Scanner, please use the `inline::code-scanner` provider instead.
|
||||
|
||||
""",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="inline::llama-guard",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.inline.safety.llama_guard",
|
||||
config_class="llama_stack.providers.inline.safety.llama_guard.LlamaGuardConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="inline::prompt-guard",
|
||||
pip_packages=[
|
||||
"transformers",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu",
|
||||
],
|
||||
module="llama_stack.providers.inline.safety.prompt_guard",
|
||||
config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="inline::code-scanner",
|
||||
pip_packages=[
|
||||
"codeshield",
|
||||
],
|
||||
module="llama_stack.providers.inline.safety.code_scanner",
|
||||
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.safety,
|
||||
|
@ -48,14 +85,4 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
|
||||
),
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="meta-reference/codeshield",
|
||||
pip_packages=[
|
||||
"codeshield",
|
||||
],
|
||||
module="llama_stack.providers.inline.safety.meta_reference",
|
||||
config_class="llama_stack.providers.inline.safety.meta_reference.CodeShieldConfig",
|
||||
api_dependencies=[],
|
||||
),
|
||||
]
|
||||
|
|
|
@ -3,11 +3,12 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from .bedrock import BedrockInferenceAdapter
|
||||
from .config import BedrockConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: BedrockConfig, _deps):
|
||||
from .bedrock import BedrockInferenceAdapter
|
||||
|
||||
assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = BedrockInferenceAdapter(config)
|
||||
|
|
|
@ -80,6 +80,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
continue
|
||||
|
||||
llama_model = ollama_to_llama[r["model"]]
|
||||
print(f"Found model {llama_model} in Ollama")
|
||||
ret.append(
|
||||
Model(
|
||||
identifier=llama_model,
|
||||
|
|
|
@ -18,7 +18,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"inference": "meta_reference",
|
||||
"safety": "meta_reference",
|
||||
"safety": "llama_guard",
|
||||
"memory": "meta_reference",
|
||||
"agents": "meta_reference",
|
||||
},
|
||||
|
@ -28,7 +28,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"inference": "ollama",
|
||||
"safety": "meta_reference",
|
||||
"safety": "llama_guard",
|
||||
"memory": "meta_reference",
|
||||
"agents": "meta_reference",
|
||||
},
|
||||
|
@ -38,7 +38,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"inference": "together",
|
||||
"safety": "meta_reference",
|
||||
"safety": "llama_guard",
|
||||
# make this work with Weaviate which is what the together distro supports
|
||||
"memory": "meta_reference",
|
||||
"agents": "meta_reference",
|
||||
|
|
|
@ -65,7 +65,6 @@ def inference_ollama(inference_model) -> ProviderFixture:
|
|||
inference_model = (
|
||||
[inference_model] if isinstance(inference_model, str) else inference_model
|
||||
)
|
||||
print("!!!", inference_model)
|
||||
if "Llama3.1-8B-Instruct" in inference_model:
|
||||
pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
|
||||
|
||||
|
@ -162,9 +161,11 @@ async def inference_stack(request, inference_model):
|
|||
inference_fixture.provider_data,
|
||||
)
|
||||
|
||||
provider_id = inference_fixture.providers[0].provider_id
|
||||
print(f"Registering model {inference_model} with provider {provider_id}")
|
||||
await impls[Api.models].register_model(
|
||||
model_id=inference_model,
|
||||
provider_model_id=inference_fixture.providers[0].provider_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
return (impls[Api.inference], impls[Api.models])
|
||||
|
|
|
@ -16,7 +16,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"inference": "meta_reference",
|
||||
"safety": "meta_reference",
|
||||
"safety": "llama_guard",
|
||||
},
|
||||
id="meta_reference",
|
||||
marks=pytest.mark.meta_reference,
|
||||
|
@ -24,7 +24,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"inference": "ollama",
|
||||
"safety": "meta_reference",
|
||||
"safety": "llama_guard",
|
||||
},
|
||||
id="ollama",
|
||||
marks=pytest.mark.ollama,
|
||||
|
@ -32,7 +32,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"inference": "together",
|
||||
"safety": "meta_reference",
|
||||
"safety": "llama_guard",
|
||||
},
|
||||
id="together",
|
||||
marks=pytest.mark.together,
|
||||
|
|
|
@ -10,15 +10,14 @@ import pytest_asyncio
|
|||
from llama_stack.apis.shields import ShieldType
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.inline.safety.meta_reference import (
|
||||
LlamaGuardShieldConfig,
|
||||
SafetyConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
|
||||
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
|
||||
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
|
||||
from llama_stack.providers.tests.env import get_env_or_fail
|
||||
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
from ..env import get_env_or_fail
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -34,17 +33,29 @@ def safety_model(request):
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def safety_meta_reference(safety_model) -> ProviderFixture:
|
||||
def safety_llama_guard(safety_model) -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="meta-reference",
|
||||
config=SafetyConfig(
|
||||
llama_guard_shield=LlamaGuardShieldConfig(
|
||||
model=safety_model,
|
||||
),
|
||||
).model_dump(),
|
||||
provider_id="inline::llama-guard",
|
||||
provider_type="inline::llama-guard",
|
||||
config=LlamaGuardConfig(model=safety_model).model_dump(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# TODO: this is not tested yet; we would need to configure the run_shield() test
|
||||
# and parametrize it with the "prompt" for testing depending on the safety fixture
|
||||
# we are using.
|
||||
@pytest.fixture(scope="session")
|
||||
def safety_prompt_guard() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="inline::prompt-guard",
|
||||
provider_type="inline::prompt-guard",
|
||||
config=PromptGuardConfig().model_dump(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
@ -63,7 +74,7 @@ def safety_bedrock() -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
SAFETY_FIXTURES = ["meta_reference", "bedrock", "remote"]
|
||||
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
|
@ -96,7 +107,21 @@ async def safety_stack(inference_model, safety_model, request):
|
|||
|
||||
# Register the appropriate shield based on provider type
|
||||
provider_type = safety_fixture.providers[0].provider_type
|
||||
shield = await create_and_register_shield(provider_type, safety_model, shields_impl)
|
||||
|
||||
provider_id = inference_fixture.providers[0].provider_id
|
||||
print(f"Registering model {inference_model} with provider {provider_id}")
|
||||
await impls[Api.models].register_model(
|
||||
model_id=inference_model,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
return safety_impl, shields_impl, shield
|
||||
|
||||
|
||||
async def create_and_register_shield(
|
||||
provider_type: str, safety_model: str, shields_impl
|
||||
):
|
||||
shield_config = {}
|
||||
shield_type = ShieldType.llama_guard
|
||||
identifier = "llama_guard"
|
||||
|
@ -109,10 +134,8 @@ async def safety_stack(inference_model, safety_model, request):
|
|||
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
|
||||
shield_type = ShieldType.generic_content_shield
|
||||
|
||||
shield = await shields_impl.register_shield(
|
||||
return await shields_impl.register_shield(
|
||||
shield_id=identifier,
|
||||
shield_type=shield_type,
|
||||
params=shield_config,
|
||||
)
|
||||
|
||||
return safety_impl, shields_impl, shield
|
||||
|
|
|
@ -3,7 +3,7 @@ distribution_spec:
|
|||
description: Use Amazon Bedrock APIs.
|
||||
providers:
|
||||
inference: remote::bedrock
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
memory: inline::faiss
|
||||
safety: inline::llama-guard
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
|
|
|
@ -3,7 +3,7 @@ distribution_spec:
|
|||
description: Use Databricks for running LLM inference
|
||||
providers:
|
||||
inference: remote::databricks
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
memory: inline::faiss
|
||||
safety: inline::llama-guard
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
|
|
|
@ -6,6 +6,6 @@ distribution_spec:
|
|||
memory:
|
||||
- meta-reference
|
||||
- remote::weaviate
|
||||
safety: meta-reference
|
||||
safety: inline::llama-guard
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
|
|
|
@ -3,7 +3,7 @@ distribution_spec:
|
|||
description: "Like local, but use Hugging Face Inference Endpoints for running LLM inference.\nSee https://hf.co/docs/api-endpoints."
|
||||
providers:
|
||||
inference: remote::hf::endpoint
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
memory: inline::faiss
|
||||
safety: inline::llama-guard
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
|
|
|
@ -3,7 +3,7 @@ distribution_spec:
|
|||
description: "Like local, but use Hugging Face Inference API (serverless) for running LLM inference.\nSee https://hf.co/docs/api-inference."
|
||||
providers:
|
||||
inference: remote::hf::serverless
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
memory: inline::faiss
|
||||
safety: inline::llama-guard
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
|
|
|
@ -8,6 +8,6 @@ distribution_spec:
|
|||
- meta-reference
|
||||
- remote::chromadb
|
||||
- remote::pgvector
|
||||
safety: meta-reference
|
||||
safety: inline::llama-guard
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
|
|
|
@ -8,6 +8,6 @@ distribution_spec:
|
|||
- meta-reference
|
||||
- remote::chromadb
|
||||
- remote::pgvector
|
||||
safety: meta-reference
|
||||
safety: inline::llama-guard
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
|
|
|
@ -8,6 +8,6 @@ distribution_spec:
|
|||
- meta-reference
|
||||
- remote::chromadb
|
||||
- remote::pgvector
|
||||
safety: meta-reference
|
||||
safety: inline::llama-guard
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
|
|
|
@ -7,6 +7,6 @@ distribution_spec:
|
|||
- meta-reference
|
||||
- remote::chromadb
|
||||
- remote::pgvector
|
||||
safety: meta-reference
|
||||
safety: inline::llama-guard
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
|
|
|
@ -7,6 +7,6 @@ distribution_spec:
|
|||
- meta-reference
|
||||
- remote::chromadb
|
||||
- remote::pgvector
|
||||
safety: meta-reference
|
||||
safety: inline::llama-guard
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
|
|
|
@ -7,6 +7,6 @@ distribution_spec:
|
|||
- meta-reference
|
||||
- remote::chromadb
|
||||
- remote::pgvector
|
||||
safety: meta-reference
|
||||
safety: inline::llama-guard
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
|
|
|
@ -6,6 +6,6 @@ distribution_spec:
|
|||
memory:
|
||||
- meta-reference
|
||||
- remote::weaviate
|
||||
safety: meta-reference
|
||||
safety: inline::llama-guard
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue