mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-28 15:02:37 +00:00
move codeshield into an independent safety provider
This commit is contained in:
parent
380b9dab90
commit
4540d8bd87
10 changed files with 98 additions and 84 deletions
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import CodeShieldConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeShieldConfig, deps):
    """Build and initialize the code-scanner safety provider.

    The implementation class is imported inside the function, constructed
    from ``config`` and ``deps``, and awaited through ``initialize()``
    before it is handed back to the caller.
    """
    from .code_scanner import MetaReferenceCodeScannerSafetyImpl

    provider = MetaReferenceCodeScannerSafetyImpl(config, deps)
    await provider.initialize()
    return provider
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
|
||||
from termcolor import cprint
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
|
||||
class MetaReferenceCodeScannerSafetyImpl(Safety):
    """Safety provider that scans message text with CodeShield."""

    def __init__(self, config: CodeScannerConfig, deps) -> None:
        """Store the provider configuration; ``deps`` is accepted but unused."""
        # NOTE(review): the sibling config module defines ``CodeShieldConfig``,
        # while this file imports ``CodeScannerConfig`` -- confirm the import
        # at the top of the file resolves.
        self.config = config

    async def initialize(self) -> None:
        """No start-up work is required."""
        pass

    async def shutdown(self) -> None:
        """No tear-down work is required."""
        pass

    async def register_shield(self, shield: ShieldDef) -> None:
        """Accept only shields of the code-scanner type.

        Raises:
            ValueError: if ``shield.type`` is not the code-scanner type.
        """
        if shield.type != ShieldType.code_scanner.value:
            raise ValueError(f"Unsupported safety shield type: {shield.type}")

    async def run_shield(
        self,
        shield_type: str,
        messages: List[Message],
        params: Dict[str, Any] = None,
    ) -> RunShieldResponse:
        """Scan the concatenated message contents and report any violation.

        Args:
            shield_type: identifier looked up in ``self.shield_store``.
            messages: messages whose contents are joined with newlines and
                scanned as one text.
            params: accepted for interface compatibility; not used here.

        Returns:
            A ``RunShieldResponse`` whose ``violation`` is ``None`` when the
            scan finds nothing insecure.

        Raises:
            ValueError: if the shield store has no entry for ``shield_type``.
        """
        # NOTE(review): ``shield_store`` is not assigned in ``__init__``;
        # presumably it is injected from outside before this is called.
        shield_def = await self.shield_store.get_shield(shield_type)
        if not shield_def:
            raise ValueError(f"Unknown shield {shield_type}")

        from codeshield.cs import CodeShield

        text = "\n".join(interleaved_text_media_as_str(m.content) for m in messages)
        # Log only a short prefix of the text. The previous slice ``[50:]``
        # dropped the first 50 characters and printed everything after them.
        cprint(f"Running CodeScannerShield on {text[:50]}", color="magenta")
        result = await CodeShield.scan_code(text)

        violation = None
        if result.is_insecure:
            violation = SafetyViolation(
                violation_level=(ViolationLevel.ERROR),
                user_message="Sorry, I found security concerns in the code.",
                metadata={
                    # Comma-separated pattern ids of every issue found.
                    "violation_type": ",".join(
                        [issue.pattern_id for issue in result.issues_found]
                    )
                },
            )
        return RunShieldResponse(violation=violation)
|
|
@ -0,0 +1,11 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeShieldConfig(BaseModel):
    # Configuration model for the codeshield safety provider; it currently
    # declares no fields of its own.
    pass
|
|
@ -44,7 +44,6 @@ def message_content_as_str(message: Message) -> str:
|
|||
return interleaved_text_media_as_str(message.content)
|
||||
|
||||
|
||||
# For shields that operate on simple strings
|
||||
class TextShield(ShieldBase):
|
||||
def convert_messages_to_text(self, messages: List[Message]) -> str:
|
||||
return "\n".join([message_content_as_str(m) for m in messages])
|
||||
|
@ -56,9 +55,3 @@ class TextShield(ShieldBase):
|
|||
@abstractmethod
|
||||
async def run_impl(self, text: str) -> ShieldResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DummyShield(TextShield):
    # Placeholder shield used for end-to-end testing.

    async def run_impl(self, text: str) -> ShieldResponse:
        # The text is ignored: the result is always a non-violation.
        no_violation = ShieldResponse(is_violation=False)
        return no_violation
|
|
@ -12,19 +12,11 @@ from llama_stack.apis.safety import * # noqa: F403
|
|||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
||||
OnViolationAction,
|
||||
)
|
||||
|
||||
from .base import OnViolationAction, ShieldBase
|
||||
from .config import SafetyConfig
|
||||
from .llama_guard import LlamaGuardShield
|
||||
from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield
|
||||
|
||||
from .shields import (
|
||||
CodeScannerShield,
|
||||
InjectionShield,
|
||||
JailbreakShield,
|
||||
LlamaGuardShield,
|
||||
ShieldBase,
|
||||
)
|
||||
|
||||
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
|
||||
|
||||
|
@ -34,7 +26,7 @@ class MetaReferenceSafetyImpl(Safety):
|
|||
self.config = config
|
||||
self.inference_api = deps[Api.inference]
|
||||
|
||||
self.available_shields = [ShieldType.code_scanner.value]
|
||||
self.available_shields = []
|
||||
if config.llama_guard_shield:
|
||||
self.available_shields.append(ShieldType.llama_guard.value)
|
||||
if config.enable_prompt_guard:
|
||||
|
@ -42,8 +34,6 @@ class MetaReferenceSafetyImpl(Safety):
|
|||
|
||||
async def initialize(self) -> None:
|
||||
if self.config.enable_prompt_guard:
|
||||
from .shields import PromptGuardShield
|
||||
|
||||
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||
_ = PromptGuardShield.instance(model_dir)
|
||||
|
||||
|
@ -107,7 +97,5 @@ class MetaReferenceSafetyImpl(Safety):
|
|||
return JailbreakShield.instance(model_dir)
|
||||
else:
|
||||
raise ValueError(f"Unknown prompt guard type: {subtype}")
|
||||
elif shield.type == ShieldType.code_scanner.value:
|
||||
return CodeScannerShield.instance()
|
||||
else:
|
||||
raise ValueError(f"Unknown shield type: {shield.type}")
|
||||
|
|
|
@ -1,33 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# suppress warnings and spew of logs from hugging face
|
||||
import transformers
|
||||
|
||||
from .base import ( # noqa: F401
|
||||
DummyShield,
|
||||
OnViolationAction,
|
||||
ShieldBase,
|
||||
ShieldResponse,
|
||||
TextShield,
|
||||
)
|
||||
from .code_scanner import CodeScannerShield # noqa: F401
|
||||
from .llama_guard import LlamaGuardShield # noqa: F401
|
||||
from .prompt_guard import ( # noqa: F401
|
||||
InjectionShield,
|
||||
JailbreakShield,
|
||||
PromptGuardShield,
|
||||
)
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
import os
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore")
|
|
@ -1,27 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from .base import ShieldResponse, TextShield
|
||||
|
||||
|
||||
class CodeScannerShield(TextShield):
    """Text shield that runs CodeShield over the supplied text."""

    async def run_impl(self, text: str) -> ShieldResponse:
        """Scan ``text`` and return a violation response if it is insecure.

        When the scan reports insecure code, the response carries the
        comma-separated pattern ids of the issues found; otherwise a
        non-violation response is returned.
        """
        from codeshield.cs import CodeShield

        # Log only a short prefix of the text. The previous slice ``[50:]``
        # dropped the first 50 characters and printed everything after them.
        cprint(f"Running CodeScannerShield on {text[:50]}", color="magenta")
        result = await CodeShield.scan_code(text)
        if not result.is_insecure:
            return ShieldResponse(is_violation=False)

        return ShieldResponse(
            is_violation=True,
            violation_type=",".join(
                [issue.pattern_id for issue in result.issues_found]
            ),
            violation_return_message="Sorry, I found security concerns in the code.",
        )
|
|
@ -21,7 +21,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.safety,
|
||||
provider_type="meta-reference",
|
||||
pip_packages=[
|
||||
"codeshield",
|
||||
"transformers",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu",
|
||||
],
|
||||
|
@ -61,4 +60,14 @@ def available_providers() -> List[ProviderSpec]:
|
|||
provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
|
||||
),
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="meta-reference/codeshield",
|
||||
pip_packages=[
|
||||
"codeshield",
|
||||
],
|
||||
module="llama_stack.providers.impls.meta_reference.codeshield",
|
||||
config_class="llama_stack.providers.impls.meta_reference.codeshield.CodeShieldConfig",
|
||||
api_dependencies=[],
|
||||
),
|
||||
]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue