mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
move codeshield into an independent safety provider
This commit is contained in:
parent
380b9dab90
commit
4540d8bd87
10 changed files with 98 additions and 84 deletions
|
@ -0,0 +1,15 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .config import CodeShieldConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def get_provider_impl(config: CodeShieldConfig, deps):
|
||||||
|
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
|
||||||
|
|
||||||
|
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
|
@ -0,0 +1,58 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
|
||||||
|
from termcolor import cprint
|
||||||
|
|
||||||
|
from .config import CodeScannerConfig
|
||||||
|
|
||||||
|
from llama_stack.apis.safety import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
class MetaReferenceCodeScannerSafetyImpl(Safety):
|
||||||
|
def __init__(self, config: CodeScannerConfig, deps) -> None:
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def register_shield(self, shield: ShieldDef) -> None:
|
||||||
|
if shield.type != ShieldType.code_scanner.value:
|
||||||
|
raise ValueError(f"Unsupported safety shield type: {shield.type}")
|
||||||
|
|
||||||
|
async def run_shield(
|
||||||
|
self,
|
||||||
|
shield_type: str,
|
||||||
|
messages: List[Message],
|
||||||
|
params: Dict[str, Any] = None,
|
||||||
|
) -> RunShieldResponse:
|
||||||
|
shield_def = await self.shield_store.get_shield(shield_type)
|
||||||
|
if not shield_def:
|
||||||
|
raise ValueError(f"Unknown shield {shield_type}")
|
||||||
|
|
||||||
|
from codeshield.cs import CodeShield
|
||||||
|
|
||||||
|
text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
|
||||||
|
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
|
||||||
|
result = await CodeShield.scan_code(text)
|
||||||
|
|
||||||
|
violation = None
|
||||||
|
if result.is_insecure:
|
||||||
|
violation = SafetyViolation(
|
||||||
|
violation_level=(ViolationLevel.ERROR),
|
||||||
|
user_message="Sorry, I found security concerns in the code.",
|
||||||
|
metadata={
|
||||||
|
"violation_type": ",".join(
|
||||||
|
[issue.pattern_id for issue in result.issues_found]
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return RunShieldResponse(violation=violation)
|
|
@ -0,0 +1,11 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class CodeShieldConfig(BaseModel):
|
||||||
|
pass
|
|
@ -44,7 +44,6 @@ def message_content_as_str(message: Message) -> str:
|
||||||
return interleaved_text_media_as_str(message.content)
|
return interleaved_text_media_as_str(message.content)
|
||||||
|
|
||||||
|
|
||||||
# For shields that operate on simple strings
|
|
||||||
class TextShield(ShieldBase):
|
class TextShield(ShieldBase):
|
||||||
def convert_messages_to_text(self, messages: List[Message]) -> str:
|
def convert_messages_to_text(self, messages: List[Message]) -> str:
|
||||||
return "\n".join([message_content_as_str(m) for m in messages])
|
return "\n".join([message_content_as_str(m) for m in messages])
|
||||||
|
@ -56,9 +55,3 @@ class TextShield(ShieldBase):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def run_impl(self, text: str) -> ShieldResponse:
|
async def run_impl(self, text: str) -> ShieldResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class DummyShield(TextShield):
|
|
||||||
async def run_impl(self, text: str) -> ShieldResponse:
|
|
||||||
# Dummy return LOW to test e2e
|
|
||||||
return ShieldResponse(is_violation=False)
|
|
|
@ -12,19 +12,11 @@ from llama_stack.apis.safety import * # noqa: F403
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.distribution.datatypes import Api
|
from llama_stack.distribution.datatypes import Api
|
||||||
|
|
||||||
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
from .base import OnViolationAction, ShieldBase
|
||||||
OnViolationAction,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .config import SafetyConfig
|
from .config import SafetyConfig
|
||||||
|
from .llama_guard import LlamaGuardShield
|
||||||
|
from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield
|
||||||
|
|
||||||
from .shields import (
|
|
||||||
CodeScannerShield,
|
|
||||||
InjectionShield,
|
|
||||||
JailbreakShield,
|
|
||||||
LlamaGuardShield,
|
|
||||||
ShieldBase,
|
|
||||||
)
|
|
||||||
|
|
||||||
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
|
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
|
||||||
|
|
||||||
|
@ -34,7 +26,7 @@ class MetaReferenceSafetyImpl(Safety):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.inference_api = deps[Api.inference]
|
self.inference_api = deps[Api.inference]
|
||||||
|
|
||||||
self.available_shields = [ShieldType.code_scanner.value]
|
self.available_shields = []
|
||||||
if config.llama_guard_shield:
|
if config.llama_guard_shield:
|
||||||
self.available_shields.append(ShieldType.llama_guard.value)
|
self.available_shields.append(ShieldType.llama_guard.value)
|
||||||
if config.enable_prompt_guard:
|
if config.enable_prompt_guard:
|
||||||
|
@ -42,8 +34,6 @@ class MetaReferenceSafetyImpl(Safety):
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
if self.config.enable_prompt_guard:
|
if self.config.enable_prompt_guard:
|
||||||
from .shields import PromptGuardShield
|
|
||||||
|
|
||||||
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||||
_ = PromptGuardShield.instance(model_dir)
|
_ = PromptGuardShield.instance(model_dir)
|
||||||
|
|
||||||
|
@ -107,7 +97,5 @@ class MetaReferenceSafetyImpl(Safety):
|
||||||
return JailbreakShield.instance(model_dir)
|
return JailbreakShield.instance(model_dir)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown prompt guard type: {subtype}")
|
raise ValueError(f"Unknown prompt guard type: {subtype}")
|
||||||
elif shield.type == ShieldType.code_scanner.value:
|
|
||||||
return CodeScannerShield.instance()
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown shield type: {shield.type}")
|
raise ValueError(f"Unknown shield type: {shield.type}")
|
||||||
|
|
|
@ -1,33 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
# supress warnings and spew of logs from hugging face
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
from .base import ( # noqa: F401
|
|
||||||
DummyShield,
|
|
||||||
OnViolationAction,
|
|
||||||
ShieldBase,
|
|
||||||
ShieldResponse,
|
|
||||||
TextShield,
|
|
||||||
)
|
|
||||||
from .code_scanner import CodeScannerShield # noqa: F401
|
|
||||||
from .llama_guard import LlamaGuardShield # noqa: F401
|
|
||||||
from .prompt_guard import ( # noqa: F401
|
|
||||||
InjectionShield,
|
|
||||||
JailbreakShield,
|
|
||||||
PromptGuardShield,
|
|
||||||
)
|
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
|
|
@ -1,27 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from .base import ShieldResponse, TextShield
|
|
||||||
|
|
||||||
|
|
||||||
class CodeScannerShield(TextShield):
|
|
||||||
async def run_impl(self, text: str) -> ShieldResponse:
|
|
||||||
from codeshield.cs import CodeShield
|
|
||||||
|
|
||||||
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
|
|
||||||
result = await CodeShield.scan_code(text)
|
|
||||||
if result.is_insecure:
|
|
||||||
return ShieldResponse(
|
|
||||||
is_violation=True,
|
|
||||||
violation_type=",".join(
|
|
||||||
[issue.pattern_id for issue in result.issues_found]
|
|
||||||
),
|
|
||||||
violation_return_message="Sorry, I found security concerns in the code.",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return ShieldResponse(is_violation=False)
|
|
|
@ -21,7 +21,6 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
api=Api.safety,
|
api=Api.safety,
|
||||||
provider_type="meta-reference",
|
provider_type="meta-reference",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"codeshield",
|
|
||||||
"transformers",
|
"transformers",
|
||||||
"torch --index-url https://download.pytorch.org/whl/cpu",
|
"torch --index-url https://download.pytorch.org/whl/cpu",
|
||||||
],
|
],
|
||||||
|
@ -61,4 +60,14 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
|
provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
InlineProviderSpec(
|
||||||
|
api=Api.safety,
|
||||||
|
provider_type="meta-reference/codeshield",
|
||||||
|
pip_packages=[
|
||||||
|
"codeshield",
|
||||||
|
],
|
||||||
|
module="llama_stack.providers.impls.meta_reference.codeshield",
|
||||||
|
config_class="llama_stack.providers.impls.meta_reference.codeshield.CodeShieldConfig",
|
||||||
|
api_dependencies=[],
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue