update starter file

This commit is contained in:
Swapna Lekkala 2025-08-14 00:04:46 -07:00
parent a0eb092b52
commit 0b429da496
8 changed files with 29 additions and 19 deletions

View file

@ -6,9 +6,7 @@
from typing import Any
from llama_stack.apis.inference import (
Message,
)
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
@ -68,6 +66,7 @@ class SafetyRouter(Safety):
list_shields_response = await self.routing_table.list_shields()
matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
if not matches:
raise ValueError(f"No shield associated with provider_resource id {model}")
if len(matches) > 1:

View file

@ -28,6 +28,7 @@ distribution_spec:
- provider_type: inline::localfs
safety:
- provider_type: inline::llama-guard
- provider_type: inline::code-scanner
agents:
- provider_type: inline::meta-reference
telemetry:

View file

@ -134,6 +134,8 @@ providers:
provider_type: inline::llama-guard
config:
excluded_categories: []
- provider_id: code-scanner
provider_type: inline::code-scanner
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
@ -215,6 +217,9 @@ shields:
- shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: []
datasets: []
scoring_fns: []

View file

@ -28,6 +28,7 @@ distribution_spec:
- provider_type: inline::localfs
safety:
- provider_type: inline::llama-guard
- provider_type: inline::code-scanner
agents:
- provider_type: inline::meta-reference
telemetry:

View file

@ -134,10 +134,8 @@ providers:
provider_type: inline::llama-guard
config:
excluded_categories: []
- provider_id: CodeScanner
- provider_id: code-scanner
provider_type: inline::code-scanner
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
@ -219,6 +217,9 @@ shields:
- shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: []
datasets: []
scoring_fns: []

View file

@ -15,19 +15,14 @@ from llama_stack.core.datatypes import (
ToolGroupInput,
)
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.distributions.template import (
DistributionTemplate,
RunConfigSettings,
)
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.datatypes import RemoteProviderSpec
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.milvus.config import (
MilvusVectorIOConfig,
)
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
SQLiteVectorIOConfig,
)
@ -119,7 +114,10 @@ def get_distribution_template() -> DistributionTemplate:
BuildProvider(provider_type="remote::pgvector"),
],
"files": [BuildProvider(provider_type="inline::localfs")],
"safety": [BuildProvider(provider_type="inline::llama-guard")],
"safety": [
BuildProvider(provider_type="inline::llama-guard"),
BuildProvider(provider_type="inline::code-scanner"),
],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"post_training": [BuildProvider(provider_type="inline::huggingface")],
@ -167,6 +165,11 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="${env.SAFETY_MODEL:+llama-guard}",
provider_shield_id="${env.SAFETY_MODEL:=}",
),
ShieldInput(
shield_id="code-scanner",
provider_id="${env.CODE_SCANNER_MODEL:+code-scanner}",
provider_shield_id="${env.CODE_SCANNER_MODEL:=}",
),
]
return DistributionTemplate(

View file

@ -28,8 +28,8 @@ from .config import CodeScannerConfig
log = logging.getLogger(__name__)
ALLOWED_CODE_SCANNER_MODEL_IDS = [
"CodeScanner",
"CodeShield",
"code-scanner",
"code-shield",
]

View file

@ -27,10 +27,10 @@ def data_url_from_image(file_path):
@pytest.fixture(scope="session")
def code_scanner_shield_id(available_shields):
if "CodeScanner" in available_shields:
return "CodeScanner"
if "code-scanner" in available_shields:
return "code-scanner"
pytest.skip("CodeScanner shield is not available. Skipping.")
pytest.skip("code-scanner shield is not available. Skipping.")
def test_unsafe_examples(client_with_models, shield_id):