diff --git a/llama_stack/providers/impls/meta_reference/codeshield/__init__.py b/llama_stack/providers/impls/meta_reference/codeshield/__init__.py
new file mode 100644
index 000000000..665c5c637
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/codeshield/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import CodeShieldConfig
+
+
+async def get_provider_impl(config: CodeShieldConfig, deps):
+    """Create the code-scanner safety provider and initialize it.
+
+    Returns the initialized ``MetaReferenceCodeScannerSafetyImpl`` instance.
+    """
+    # Imported here rather than at module level, so importing this package
+    # does not import the implementation module.
+    from .code_scanner import MetaReferenceCodeScannerSafetyImpl
+
+    impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/impls/meta_reference/codeshield/code_scanner.py b/llama_stack/providers/impls/meta_reference/codeshield/code_scanner.py
new file mode 100644
index 000000000..37ea96270
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/codeshield/code_scanner.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List, Optional
+
+from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
+from termcolor import cprint
+
+from .config import CodeShieldConfig
+
+from llama_stack.apis.safety import *  # noqa: F403
+
+
+class MetaReferenceCodeScannerSafetyImpl(Safety):
+    """Safety provider that scans message text with CodeShield."""
+
+    def __init__(self, config: CodeShieldConfig, deps) -> None:
+        # `deps` is passed in by `get_provider_impl` but is not used here.
+        self.config = config
+
+    async def initialize(self) -> None:
+        """Do nothing; no setup is performed ahead of the first scan."""
+        pass
+
+    async def shutdown(self) -> None:
+        """Do nothing; no resources are held by this provider."""
+        pass
+
+    async def register_shield(self, shield: ShieldDef) -> None:
+        """Validate that ``shield`` is a code-scanner shield.
+
+        Raises:
+            ValueError: If ``shield.type`` is any other shield type.
+        """
+        if shield.type != ShieldType.code_scanner.value:
+            raise ValueError(f"Unsupported safety shield type: {shield.type}")
+
+    async def run_shield(
+        self,
+        shield_type: str,
+        messages: List[Message],
+        params: Optional[Dict[str, Any]] = None,
+    ) -> RunShieldResponse:
+        """Scan the joined message contents and report insecure code.
+
+        Args:
+            shield_type: Identifier used to look the shield up in the store.
+            messages: Messages whose contents are joined with newlines and
+                scanned as a single text.
+            params: Not used by this implementation.
+
+        Returns:
+            A response whose ``violation`` is ``None`` when the scan reports
+            nothing insecure, and an ERROR-level violation otherwise.
+
+        Raises:
+            ValueError: If the store has no shield for ``shield_type``.
+        """
+        # NOTE(review): `shield_store` is never assigned in this class;
+        # presumably it is injected from outside before use -- confirm.
+        shield_def = await self.shield_store.get_shield(shield_type)
+        if not shield_def:
+            raise ValueError(f"Unknown shield {shield_type}")
+
+        # Imported at call time, so loading this module does not import
+        # codeshield.
+        from codeshield.cs import CodeShield
+
+        text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
+        # Log only the first 50 characters as a preview of the scanned text.
+        cprint(f"Running CodeScannerShield on {text[:50]}", color="magenta")
+        result = await CodeShield.scan_code(text)
+
+        violation = None
+        if result.is_insecure:
+            violation = SafetyViolation(
+                violation_level=ViolationLevel.ERROR,
+                user_message="Sorry, I found security concerns in the code.",
+                metadata={
+                    # Comma-separated ids of every issue the scan reported.
+                    "violation_type": ",".join(
+                        [issue.pattern_id for issue in result.issues_found]
+                    )
+                },
+            )
+        return RunShieldResponse(violation=violation)
diff --git a/llama_stack/providers/impls/meta_reference/codeshield/config.py b/llama_stack/providers/impls/meta_reference/codeshield/config.py
new file mode 100644
index 000000000..583c2c95f
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/codeshield/config.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+
+class CodeShieldConfig(BaseModel):
+    pass  # Empty model: this provider defines no configuration fields.
diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/base.py b/llama_stack/providers/impls/meta_reference/safety/base.py
similarity index 88%
rename from llama_stack/providers/impls/meta_reference/safety/shields/base.py
rename to llama_stack/providers/impls/meta_reference/safety/base.py
index 6a03d1e61..3861a7c4a 100644
--- a/llama_stack/providers/impls/meta_reference/safety/shields/base.py
+++ b/llama_stack/providers/impls/meta_reference/safety/base.py
@@ -44,7 +44,6 @@ def message_content_as_str(message: Message) -> str:
     return interleaved_text_media_as_str(message.content)
 
 
-# For shields that operate on simple strings
 class TextShield(ShieldBase):
     def convert_messages_to_text(self, messages: List[Message]) -> str:
         return "\n".join([message_content_as_str(m) for m in messages])
@@ -56,9 +55,3 @@ class TextShield(ShieldBase):
     @abstractmethod
     async def run_impl(self, text: str) -> ShieldResponse:
         raise NotImplementedError()
-
-
-class DummyShield(TextShield):
-    async def run_impl(self, text: str) -> ShieldResponse:
-        # Dummy return LOW to test e2e
-        return ShieldResponse(is_violation=False)
diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py b/llama_stack/providers/impls/meta_reference/safety/llama_guard.py
similarity index 100%
rename from llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
rename to llama_stack/providers/impls/meta_reference/safety/llama_guard.py
diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/prompt_guard.py b/llama_stack/providers/impls/meta_reference/safety/prompt_guard.py
similarity index 100%
rename from llama_stack/providers/impls/meta_reference/safety/shields/prompt_guard.py
rename to
llama_stack/providers/impls/meta_reference/safety/prompt_guard.py diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py index 5d6747f9f..7457bf246 100644 --- a/llama_stack/providers/impls/meta_reference/safety/safety.py +++ b/llama_stack/providers/impls/meta_reference/safety/safety.py @@ -12,19 +12,11 @@ from llama_stack.apis.safety import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import Api -from llama_stack.providers.impls.meta_reference.safety.shields.base import ( - OnViolationAction, -) - +from .base import OnViolationAction, ShieldBase from .config import SafetyConfig +from .llama_guard import LlamaGuardShield +from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield -from .shields import ( - CodeScannerShield, - InjectionShield, - JailbreakShield, - LlamaGuardShield, - ShieldBase, -) PROMPT_GUARD_MODEL = "Prompt-Guard-86M" @@ -34,7 +26,7 @@ class MetaReferenceSafetyImpl(Safety): self.config = config self.inference_api = deps[Api.inference] - self.available_shields = [ShieldType.code_scanner.value] + self.available_shields = [] if config.llama_guard_shield: self.available_shields.append(ShieldType.llama_guard.value) if config.enable_prompt_guard: @@ -42,8 +34,6 @@ class MetaReferenceSafetyImpl(Safety): async def initialize(self) -> None: if self.config.enable_prompt_guard: - from .shields import PromptGuardShield - model_dir = model_local_dir(PROMPT_GUARD_MODEL) _ = PromptGuardShield.instance(model_dir) @@ -107,7 +97,5 @@ class MetaReferenceSafetyImpl(Safety): return JailbreakShield.instance(model_dir) else: raise ValueError(f"Unknown prompt guard type: {subtype}") - elif shield.type == ShieldType.code_scanner.value: - return CodeScannerShield.instance() else: raise ValueError(f"Unknown shield type: {shield.type}") diff --git 
a/llama_stack/providers/impls/meta_reference/safety/shields/__init__.py b/llama_stack/providers/impls/meta_reference/safety/shields/__init__.py deleted file mode 100644 index 9caf10883..000000000 --- a/llama_stack/providers/impls/meta_reference/safety/shields/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -# supress warnings and spew of logs from hugging face -import transformers - -from .base import ( # noqa: F401 - DummyShield, - OnViolationAction, - ShieldBase, - ShieldResponse, - TextShield, -) -from .code_scanner import CodeScannerShield # noqa: F401 -from .llama_guard import LlamaGuardShield # noqa: F401 -from .prompt_guard import ( # noqa: F401 - InjectionShield, - JailbreakShield, - PromptGuardShield, -) - -transformers.logging.set_verbosity_error() - -import os - -os.environ["TOKENIZERS_PARALLELISM"] = "false" - -import warnings - -warnings.filterwarnings("ignore") diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/code_scanner.py b/llama_stack/providers/impls/meta_reference/safety/shields/code_scanner.py deleted file mode 100644 index 9b043ff04..000000000 --- a/llama_stack/providers/impls/meta_reference/safety/shields/code_scanner.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from termcolor import cprint - -from .base import ShieldResponse, TextShield - - -class CodeScannerShield(TextShield): - async def run_impl(self, text: str) -> ShieldResponse: - from codeshield.cs import CodeShield - - cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta") - result = await CodeShield.scan_code(text) - if result.is_insecure: - return ShieldResponse( - is_violation=True, - violation_type=",".join( - [issue.pattern_id for issue in result.issues_found] - ), - violation_return_message="Sorry, I found security concerns in the code.", - ) - else: - return ShieldResponse(is_violation=False) diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 58307be11..3fa62479a 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -21,7 +21,6 @@ def available_providers() -> List[ProviderSpec]: api=Api.safety, provider_type="meta-reference", pip_packages=[ - "codeshield", "transformers", "torch --index-url https://download.pytorch.org/whl/cpu", ], @@ -61,4 +60,14 @@ def available_providers() -> List[ProviderSpec]: provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator", ), ), + InlineProviderSpec( + api=Api.safety, + provider_type="meta-reference/codeshield", + pip_packages=[ + "codeshield", + ], + module="llama_stack.providers.impls.meta_reference.codeshield", + config_class="llama_stack.providers.impls.meta_reference.codeshield.CodeShieldConfig", + api_dependencies=[], + ), ]