mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
update starter file
This commit is contained in:
parent
a0eb092b52
commit
0b429da496
8 changed files with 29 additions and 19 deletions
|
@ -6,9 +6,7 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import Message
|
||||||
Message,
|
|
||||||
)
|
|
||||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||||
from llama_stack.apis.safety.safety import ModerationObject
|
from llama_stack.apis.safety.safety import ModerationObject
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
|
@ -68,6 +66,7 @@ class SafetyRouter(Safety):
|
||||||
list_shields_response = await self.routing_table.list_shields()
|
list_shields_response = await self.routing_table.list_shields()
|
||||||
|
|
||||||
matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
|
matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
|
||||||
|
|
||||||
if not matches:
|
if not matches:
|
||||||
raise ValueError(f"No shield associated with provider_resource id {model}")
|
raise ValueError(f"No shield associated with provider_resource id {model}")
|
||||||
if len(matches) > 1:
|
if len(matches) > 1:
|
||||||
|
|
|
@ -28,6 +28,7 @@ distribution_spec:
|
||||||
- provider_type: inline::localfs
|
- provider_type: inline::localfs
|
||||||
safety:
|
safety:
|
||||||
- provider_type: inline::llama-guard
|
- provider_type: inline::llama-guard
|
||||||
|
- provider_type: inline::code-scanner
|
||||||
agents:
|
agents:
|
||||||
- provider_type: inline::meta-reference
|
- provider_type: inline::meta-reference
|
||||||
telemetry:
|
telemetry:
|
||||||
|
|
|
@ -134,6 +134,8 @@ providers:
|
||||||
provider_type: inline::llama-guard
|
provider_type: inline::llama-guard
|
||||||
config:
|
config:
|
||||||
excluded_categories: []
|
excluded_categories: []
|
||||||
|
- provider_id: code-scanner
|
||||||
|
provider_type: inline::code-scanner
|
||||||
agents:
|
agents:
|
||||||
- provider_id: meta-reference
|
- provider_id: meta-reference
|
||||||
provider_type: inline::meta-reference
|
provider_type: inline::meta-reference
|
||||||
|
@ -215,6 +217,9 @@ shields:
|
||||||
- shield_id: llama-guard
|
- shield_id: llama-guard
|
||||||
provider_id: ${env.SAFETY_MODEL:+llama-guard}
|
provider_id: ${env.SAFETY_MODEL:+llama-guard}
|
||||||
provider_shield_id: ${env.SAFETY_MODEL:=}
|
provider_shield_id: ${env.SAFETY_MODEL:=}
|
||||||
|
- shield_id: code-scanner
|
||||||
|
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
|
||||||
|
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
|
||||||
vector_dbs: []
|
vector_dbs: []
|
||||||
datasets: []
|
datasets: []
|
||||||
scoring_fns: []
|
scoring_fns: []
|
||||||
|
|
|
@ -28,6 +28,7 @@ distribution_spec:
|
||||||
- provider_type: inline::localfs
|
- provider_type: inline::localfs
|
||||||
safety:
|
safety:
|
||||||
- provider_type: inline::llama-guard
|
- provider_type: inline::llama-guard
|
||||||
|
- provider_type: inline::code-scanner
|
||||||
agents:
|
agents:
|
||||||
- provider_type: inline::meta-reference
|
- provider_type: inline::meta-reference
|
||||||
telemetry:
|
telemetry:
|
||||||
|
|
|
@ -134,10 +134,8 @@ providers:
|
||||||
provider_type: inline::llama-guard
|
provider_type: inline::llama-guard
|
||||||
config:
|
config:
|
||||||
excluded_categories: []
|
excluded_categories: []
|
||||||
- provider_id: CodeScanner
|
- provider_id: code-scanner
|
||||||
provider_type: inline::code-scanner
|
provider_type: inline::code-scanner
|
||||||
config:
|
|
||||||
excluded_categories: []
|
|
||||||
agents:
|
agents:
|
||||||
- provider_id: meta-reference
|
- provider_id: meta-reference
|
||||||
provider_type: inline::meta-reference
|
provider_type: inline::meta-reference
|
||||||
|
@ -219,6 +217,9 @@ shields:
|
||||||
- shield_id: llama-guard
|
- shield_id: llama-guard
|
||||||
provider_id: ${env.SAFETY_MODEL:+llama-guard}
|
provider_id: ${env.SAFETY_MODEL:+llama-guard}
|
||||||
provider_shield_id: ${env.SAFETY_MODEL:=}
|
provider_shield_id: ${env.SAFETY_MODEL:=}
|
||||||
|
- shield_id: code-scanner
|
||||||
|
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
|
||||||
|
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
|
||||||
vector_dbs: []
|
vector_dbs: []
|
||||||
datasets: []
|
datasets: []
|
||||||
scoring_fns: []
|
scoring_fns: []
|
||||||
|
|
|
@ -15,19 +15,14 @@ from llama_stack.core.datatypes import (
|
||||||
ToolGroupInput,
|
ToolGroupInput,
|
||||||
)
|
)
|
||||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||||
from llama_stack.distributions.template import (
|
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
|
||||||
DistributionTemplate,
|
|
||||||
RunConfigSettings,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.datatypes import RemoteProviderSpec
|
from llama_stack.providers.datatypes import RemoteProviderSpec
|
||||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||||
SentenceTransformersInferenceConfig,
|
SentenceTransformersInferenceConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||||
from llama_stack.providers.inline.vector_io.milvus.config import (
|
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig
|
||||||
MilvusVectorIOConfig,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
|
from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
|
||||||
SQLiteVectorIOConfig,
|
SQLiteVectorIOConfig,
|
||||||
)
|
)
|
||||||
|
@ -119,7 +114,10 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
BuildProvider(provider_type="remote::pgvector"),
|
BuildProvider(provider_type="remote::pgvector"),
|
||||||
],
|
],
|
||||||
"files": [BuildProvider(provider_type="inline::localfs")],
|
"files": [BuildProvider(provider_type="inline::localfs")],
|
||||||
"safety": [BuildProvider(provider_type="inline::llama-guard")],
|
"safety": [
|
||||||
|
BuildProvider(provider_type="inline::llama-guard"),
|
||||||
|
BuildProvider(provider_type="inline::code-scanner"),
|
||||||
|
],
|
||||||
"agents": [BuildProvider(provider_type="inline::meta-reference")],
|
"agents": [BuildProvider(provider_type="inline::meta-reference")],
|
||||||
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
|
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
|
||||||
"post_training": [BuildProvider(provider_type="inline::huggingface")],
|
"post_training": [BuildProvider(provider_type="inline::huggingface")],
|
||||||
|
@ -167,6 +165,11 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
provider_id="${env.SAFETY_MODEL:+llama-guard}",
|
provider_id="${env.SAFETY_MODEL:+llama-guard}",
|
||||||
provider_shield_id="${env.SAFETY_MODEL:=}",
|
provider_shield_id="${env.SAFETY_MODEL:=}",
|
||||||
),
|
),
|
||||||
|
ShieldInput(
|
||||||
|
shield_id="code-scanner",
|
||||||
|
provider_id="${env.CODE_SCANNER_MODEL:+code-scanner}",
|
||||||
|
provider_shield_id="${env.CODE_SCANNER_MODEL:=}",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
return DistributionTemplate(
|
return DistributionTemplate(
|
||||||
|
|
|
@ -28,8 +28,8 @@ from .config import CodeScannerConfig
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALLOWED_CODE_SCANNER_MODEL_IDS = [
|
ALLOWED_CODE_SCANNER_MODEL_IDS = [
|
||||||
"CodeScanner",
|
"code-scanner",
|
||||||
"CodeShield",
|
"code-shield",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -27,10 +27,10 @@ def data_url_from_image(file_path):
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def code_scanner_shield_id(available_shields):
|
def code_scanner_shield_id(available_shields):
|
||||||
if "CodeScanner" in available_shields:
|
if "code-scanner" in available_shields:
|
||||||
return "CodeScanner"
|
return "code-scanner"
|
||||||
|
|
||||||
pytest.skip("CodeScanner shield is not available. Skipping.")
|
pytest.skip("code-scanner shield is not available. Skipping.")
|
||||||
|
|
||||||
|
|
||||||
def test_unsafe_examples(client_with_models, shield_id):
|
def test_unsafe_examples(client_with_models, shield_id):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue