mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
chore: add mypy code scanner
Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
81109a0f72
commit
52aa646e1b
2 changed files with 21 additions and 5 deletions
|
@ -15,6 +15,7 @@ from llama_stack.apis.safety import (
|
|||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
@ -29,9 +30,18 @@ ALLOWED_CODE_SCANNER_MODEL_IDS = [
|
|||
]
|
||||
|
||||
|
||||
class MetaReferenceCodeScannerSafetyImpl(Safety):
|
||||
class MetaReferenceCodeScannerSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: CodeScannerConfig, deps) -> None:
|
||||
self.config = config
|
||||
self._shield_store = None
|
||||
|
||||
@property
|
||||
def shield_store(self):
|
||||
return self._shield_store
|
||||
|
||||
@shield_store.setter
|
||||
def shield_store(self, value):
|
||||
self._shield_store = value
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
@ -49,13 +59,19 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
|||
self,
|
||||
shield_id: str,
|
||||
messages: list[Message],
|
||||
params: dict[str, Any] = None,
|
||||
params: dict[str, Any],
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if self._shield_store is None:
|
||||
raise RuntimeError("Shield store not initialized")
|
||||
|
||||
shield = await self._shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
raise ValueError(f"Shield {shield_id} not found")
|
||||
|
||||
from codeshield.cs import CodeShield
|
||||
try:
|
||||
from codeshield.cs import CodeShield # type: ignore
|
||||
except ImportError:
|
||||
raise ImportError("codeshield is not installed. Please install it to use the CodeScanner shield.") from None
|
||||
|
||||
text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
|
||||
log.info(f"Running CodeScannerShield on {text[50:]}")
|
||||
|
|
|
@ -256,7 +256,7 @@ exclude = [
|
|||
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
|
||||
"^llama_stack/providers/inline/inference/vllm/",
|
||||
"^llama_stack/providers/inline/post_training/common/validator\\.py$",
|
||||
"^llama_stack/providers/inline/safety/code_scanner/",
|
||||
"^llama_stack/providers/inline/post_training/torchtune/post_training\\.py$",
|
||||
"^llama_stack/providers/inline/safety/llama_guard/",
|
||||
"^llama_stack/providers/inline/safety/prompt_guard/",
|
||||
"^llama_stack/providers/inline/scoring/basic/",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue