update starter file

This commit is contained in:
Swapna Lekkala 2025-08-14 00:04:46 -07:00
parent a0eb092b52
commit 0b429da496
8 changed files with 29 additions and 19 deletions

View file

@ -6,9 +6,7 @@
from typing import Any from typing import Any
from llama_stack.apis.inference import ( from llama_stack.apis.inference import Message
Message,
)
from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
@ -68,6 +66,7 @@ class SafetyRouter(Safety):
list_shields_response = await self.routing_table.list_shields() list_shields_response = await self.routing_table.list_shields()
matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id] matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
if not matches: if not matches:
raise ValueError(f"No shield associated with provider_resource id {model}") raise ValueError(f"No shield associated with provider_resource id {model}")
if len(matches) > 1: if len(matches) > 1:

View file

@ -28,6 +28,7 @@ distribution_spec:
- provider_type: inline::localfs - provider_type: inline::localfs
safety: safety:
- provider_type: inline::llama-guard - provider_type: inline::llama-guard
- provider_type: inline::code-scanner
agents: agents:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
telemetry: telemetry:

View file

@ -134,6 +134,8 @@ providers:
provider_type: inline::llama-guard provider_type: inline::llama-guard
config: config:
excluded_categories: [] excluded_categories: []
- provider_id: code-scanner
provider_type: inline::code-scanner
agents: agents:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference
@ -215,6 +217,9 @@ shields:
- shield_id: llama-guard - shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard} provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=} provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []

View file

@ -28,6 +28,7 @@ distribution_spec:
- provider_type: inline::localfs - provider_type: inline::localfs
safety: safety:
- provider_type: inline::llama-guard - provider_type: inline::llama-guard
- provider_type: inline::code-scanner
agents: agents:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
telemetry: telemetry:

View file

@ -134,10 +134,8 @@ providers:
provider_type: inline::llama-guard provider_type: inline::llama-guard
config: config:
excluded_categories: [] excluded_categories: []
- provider_id: CodeScanner - provider_id: code-scanner
provider_type: inline::code-scanner provider_type: inline::code-scanner
config:
excluded_categories: []
agents: agents:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference
@ -219,6 +217,9 @@ shields:
- shield_id: llama-guard - shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard} provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=} provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []

View file

@ -15,19 +15,14 @@ from llama_stack.core.datatypes import (
ToolGroupInput, ToolGroupInput,
) )
from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.distributions.template import ( from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
DistributionTemplate,
RunConfigSettings,
)
from llama_stack.providers.datatypes import RemoteProviderSpec from llama_stack.providers.datatypes import RemoteProviderSpec
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.inference.sentence_transformers import ( from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig, SentenceTransformersInferenceConfig,
) )
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.milvus.config import ( from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig
MilvusVectorIOConfig,
)
from llama_stack.providers.inline.vector_io.sqlite_vec.config import ( from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
SQLiteVectorIOConfig, SQLiteVectorIOConfig,
) )
@ -119,7 +114,10 @@ def get_distribution_template() -> DistributionTemplate:
BuildProvider(provider_type="remote::pgvector"), BuildProvider(provider_type="remote::pgvector"),
], ],
"files": [BuildProvider(provider_type="inline::localfs")], "files": [BuildProvider(provider_type="inline::localfs")],
"safety": [BuildProvider(provider_type="inline::llama-guard")], "safety": [
BuildProvider(provider_type="inline::llama-guard"),
BuildProvider(provider_type="inline::code-scanner"),
],
"agents": [BuildProvider(provider_type="inline::meta-reference")], "agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")], "telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"post_training": [BuildProvider(provider_type="inline::huggingface")], "post_training": [BuildProvider(provider_type="inline::huggingface")],
@ -167,6 +165,11 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="${env.SAFETY_MODEL:+llama-guard}", provider_id="${env.SAFETY_MODEL:+llama-guard}",
provider_shield_id="${env.SAFETY_MODEL:=}", provider_shield_id="${env.SAFETY_MODEL:=}",
), ),
ShieldInput(
shield_id="code-scanner",
provider_id="${env.CODE_SCANNER_MODEL:+code-scanner}",
provider_shield_id="${env.CODE_SCANNER_MODEL:=}",
),
] ]
return DistributionTemplate( return DistributionTemplate(

View file

@ -28,8 +28,8 @@ from .config import CodeScannerConfig
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
ALLOWED_CODE_SCANNER_MODEL_IDS = [ ALLOWED_CODE_SCANNER_MODEL_IDS = [
"CodeScanner", "code-scanner",
"CodeShield", "code-shield",
] ]

View file

@ -27,10 +27,10 @@ def data_url_from_image(file_path):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def code_scanner_shield_id(available_shields): def code_scanner_shield_id(available_shields):
if "CodeScanner" in available_shields: if "code-scanner" in available_shields:
return "CodeScanner" return "code-scanner"
pytest.skip("CodeScanner shield is not available. Skipping.") pytest.skip("code-scanner shield is not available. Skipping.")
def test_unsafe_examples(client_with_models, shield_id): def test_unsafe_examples(client_with_models, shield_id):