diff --git a/llama_stack/core/routers/safety.py b/llama_stack/core/routers/safety.py index c76673d2a..738ecded3 100644 --- a/llama_stack/core/routers/safety.py +++ b/llama_stack/core/routers/safety.py @@ -6,9 +6,7 @@ from typing import Any -from llama_stack.apis.inference import ( - Message, -) +from llama_stack.apis.inference import Message from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.shields import Shield @@ -68,6 +66,7 @@ class SafetyRouter(Safety): list_shields_response = await self.routing_table.list_shields() matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id] + if not matches: raise ValueError(f"No shield associated with provider_resource id {model}") if len(matches) > 1: diff --git a/llama_stack/distributions/ci-tests/build.yaml b/llama_stack/distributions/ci-tests/build.yaml index e6e699b62..900371d95 100644 --- a/llama_stack/distributions/ci-tests/build.yaml +++ b/llama_stack/distributions/ci-tests/build.yaml @@ -28,6 +28,7 @@ distribution_spec: - provider_type: inline::localfs safety: - provider_type: inline::llama-guard + - provider_type: inline::code-scanner agents: - provider_type: inline::meta-reference telemetry: diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml index 05e1b4576..0da15857c 100644 --- a/llama_stack/distributions/ci-tests/run.yaml +++ b/llama_stack/distributions/ci-tests/run.yaml @@ -134,6 +134,8 @@ providers: provider_type: inline::llama-guard config: excluded_categories: [] + - provider_id: code-scanner + provider_type: inline::code-scanner agents: - provider_id: meta-reference provider_type: inline::meta-reference @@ -215,6 +217,9 @@ shields: - shield_id: llama-guard provider_id: ${env.SAFETY_MODEL:+llama-guard} provider_shield_id: ${env.SAFETY_MODEL:=} +- shield_id: code-scanner + provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner} + provider_shield_id: ${env.CODE_SCANNER_MODEL:=} vector_dbs: [] datasets: [] scoring_fns: [] diff --git a/llama_stack/distributions/starter/build.yaml b/llama_stack/distributions/starter/build.yaml index 1a4f81d49..b5f2b9694 100644 --- a/llama_stack/distributions/starter/build.yaml +++ b/llama_stack/distributions/starter/build.yaml @@ -28,6 +28,7 @@ distribution_spec: - provider_type: inline::localfs safety: - provider_type: inline::llama-guard + - provider_type: inline::code-scanner agents: - provider_type: inline::meta-reference telemetry: diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml index 0892b43b3..d1771580b 100644 --- a/llama_stack/distributions/starter/run.yaml +++ b/llama_stack/distributions/starter/run.yaml @@ -134,10 +134,8 @@ providers: provider_type: inline::llama-guard config: excluded_categories: [] - - provider_id: CodeScanner + - provider_id: code-scanner provider_type: inline::code-scanner - config: - excluded_categories: [] agents: - provider_id: meta-reference provider_type: inline::meta-reference @@ -219,6 +217,9 @@ shields: - shield_id: llama-guard provider_id: ${env.SAFETY_MODEL:+llama-guard} provider_shield_id: ${env.SAFETY_MODEL:=} +- shield_id: code-scanner + provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner} + provider_shield_id: ${env.CODE_SCANNER_MODEL:=} vector_dbs: [] datasets: [] scoring_fns: [] diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py index 0270b68ad..90fca335a 100644 --- a/llama_stack/distributions/starter/starter.py +++ b/llama_stack/distributions/starter/starter.py @@ -15,19 +15,14 @@ from llama_stack.core.datatypes import ( ToolGroupInput, ) from llama_stack.core.utils.dynamic import instantiate_class_type -from llama_stack.distributions.template import ( - DistributionTemplate, - RunConfigSettings, -) +from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings from llama_stack.providers.datatypes import RemoteProviderSpec from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig -from llama_stack.providers.inline.vector_io.milvus.config import ( - MilvusVectorIOConfig, -) +from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig from llama_stack.providers.inline.vector_io.sqlite_vec.config import ( SQLiteVectorIOConfig, ) @@ -119,7 +114,10 @@ def get_distribution_template() -> DistributionTemplate: BuildProvider(provider_type="remote::pgvector"), ], "files": [BuildProvider(provider_type="inline::localfs")], - "safety": [BuildProvider(provider_type="inline::llama-guard")], + "safety": [ + BuildProvider(provider_type="inline::llama-guard"), + BuildProvider(provider_type="inline::code-scanner"), + ], "agents": [BuildProvider(provider_type="inline::meta-reference")], "telemetry": [BuildProvider(provider_type="inline::meta-reference")], "post_training": [BuildProvider(provider_type="inline::huggingface")], @@ -167,6 +165,11 @@ def get_distribution_template() -> DistributionTemplate: provider_id="${env.SAFETY_MODEL:+llama-guard}", provider_shield_id="${env.SAFETY_MODEL:=}", ), + ShieldInput( + shield_id="code-scanner", + provider_id="${env.CODE_SCANNER_MODEL:+code-scanner}", + provider_shield_id="${env.CODE_SCANNER_MODEL:=}", + ), ] return DistributionTemplate( diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py index 79655e5f4..0669b65bb 100644 --- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -28,8 +28,8 @@ from .config import CodeScannerConfig log = logging.getLogger(__name__) ALLOWED_CODE_SCANNER_MODEL_IDS = [ - "CodeScanner", - "CodeShield", + "code-scanner", + "code-shield", ] diff --git a/tests/integration/safety/test_safety.py b/tests/integration/safety/test_safety.py index 8419dcd94..324480848 100644 --- a/tests/integration/safety/test_safety.py +++ b/tests/integration/safety/test_safety.py @@ -27,10 +27,10 @@ def data_url_from_image(file_path): @pytest.fixture(scope="session") def code_scanner_shield_id(available_shields): - if "CodeScanner" in available_shields: - return "CodeScanner" + if "code-scanner" in available_shields: + return "code-scanner" - pytest.skip("CodeScanner shield is not available. Skipping.") + pytest.skip("code-scanner shield is not available. Skipping.") def test_unsafe_examples(client_with_models, shield_id):