From 52aa646e1b67b03391aaacf9691d8209ba67788b Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Wed, 9 Jul 2025 01:33:52 +0200 Subject: [PATCH] chore: add mypy code scanner Signed-off-by: Mustafa Elbehery --- .../safety/code_scanner/code_scanner.py | 24 +++++++++++++++---- pyproject.toml | 2 +- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py index be05ee436..ba389353c 100644 --- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -15,6 +15,7 @@ from llama_stack.apis.safety import ( ViolationLevel, ) from llama_stack.apis.shields import Shield +from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) @@ -29,9 +30,18 @@ ALLOWED_CODE_SCANNER_MODEL_IDS = [ ] -class MetaReferenceCodeScannerSafetyImpl(Safety): +class MetaReferenceCodeScannerSafetyImpl(Safety, ShieldsProtocolPrivate): def __init__(self, config: CodeScannerConfig, deps) -> None: self.config = config + self._shield_store = None + + @property + def shield_store(self): + return self._shield_store + + @shield_store.setter + def shield_store(self, value): + self._shield_store = value async def initialize(self) -> None: pass @@ -49,13 +59,19 @@ class MetaReferenceCodeScannerSafetyImpl(Safety): self, shield_id: str, messages: list[Message], - params: dict[str, Any] = None, + params: dict[str, Any], ) -> RunShieldResponse: - shield = await self.shield_store.get_shield(shield_id) + if self._shield_store is None: + raise RuntimeError("Shield store not initialized") + + shield = await self._shield_store.get_shield(shield_id) if not shield: raise ValueError(f"Shield {shield_id} not found") - from codeshield.cs import CodeShield + try: + from codeshield.cs import CodeShield # type: ignore + except ImportError: + raise ImportError("codeshield is not installed. Please install it to use the CodeScanner shield.") from None text = "\n".join([interleaved_content_as_str(m.content) for m in messages]) log.info(f"Running CodeScannerShield on {text[50:]}") diff --git a/pyproject.toml b/pyproject.toml index 30598e5e3..b92efc217 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -256,7 +256,7 @@ exclude = [ "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$", "^llama_stack/providers/inline/inference/vllm/", "^llama_stack/providers/inline/post_training/common/validator\\.py$", - "^llama_stack/providers/inline/safety/code_scanner/", + "^llama_stack/providers/inline/post_training/torchtune/post_training\\.py$", "^llama_stack/providers/inline/safety/llama_guard/", "^llama_stack/providers/inline/safety/prompt_guard/", "^llama_stack/providers/inline/scoring/basic/",